检查树¶
设置¶
In [ ]
已复制!
pip install ydf -U
pip install ydf -U
In [4]
已复制!
import ydf
import numpy as np
import ydf import numpy as np
In [5]
已复制!
dataset = {
"x1": np.array([0, 0, 0, 1, 1, 1]),
"x2": np.array([1, 1, 0, 0, 1, 1]),
"y": np.array([0, 0, 0, 0, 1, 1]),
}
dataset
dataset = { "x1": np.array([0, 0, 0, 1, 1, 1]), "x2": np.array([1, 1, 0, 0, 1, 1]), "y": np.array([0, 0, 0, 0, 1, 1]), } dataset
Out[5]
{'x1': array([0, 0, 0, 1, 1, 1]), 'x2': array([1, 1, 0, 0, 1, 1]), 'y': array([0, 0, 0, 0, 1, 1])}
训练模型¶
In [8]
已复制!
model = ydf.CartLearner(label="y", min_examples=1, task=ydf.Task.REGRESSION).train(dataset)
model.describe()
model = ydf.CartLearner(label="y", min_examples=1, task=ydf.Task.REGRESSION).train(dataset) model.describe()
Train model on 6 examples Model trained in 0:00:00.000728
Out[8]
名称 : RANDOM_FOREST
任务 : 回归
标签 : y
特征 (2) : x1 x2
权重 : 无
使用调优器训练 : 否
模型大小 : 3 kB
任务 : 回归
标签 : y
特征 (2) : x1 x2
权重 : 无
使用调优器训练 : 否
模型大小 : 3 kB
Number of records: 6 Number of columns: 3 Number of columns by type: NUMERICAL: 3 (100%) Columns: NUMERICAL: 3 (100%) 0: "y" NUMERICAL mean:0.333333 min:0 max:1 sd:0.471405 1: "x1" NUMERICAL mean:0.5 min:0 max:1 sd:0.5 2: "x2" NUMERICAL mean:0.666667 min:0 max:1 sd:0.471405 Terminology: nas: Number of non-available (i.e. missing) values. ood: Out of dictionary. manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred. tokenized: The attribute value is obtained through tokenization. has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string. vocab-size: Number of unique values.
以下评估是在验证集或袋外数据集上计算的。
随机森林没有袋外评估训练日志。请使用 compute_oob_performances=True 参数训练模型以计算训练日志。确保训练日志没有因 pure_serving_model=True 参数而被移除。
变量重要性衡量输入特征对模型的重要性。
1. "x1" 1.000000 ################ 2. "x2" 0.500000
1. "x1" 1.000000
1. "x1" 1.000000 2. "x2" 1.000000
1. "x1" 0.666667 2. "x2" 0.666667
这些变量重要性是在训练期间计算的。分析测试数据集上的模型时,可以获得更多,且可能更具信息量的变量重要性。
树的数量 : 1
Tree #0: "x1">=0.5 [s:0.111111 n:6 np:3 miss:1] ; pred:0.333333 ├─(pos)─ "x2">=0.5 [s:0.222222 n:3 np:2 miss:1] ; pred:0.666667 | ├─(pos)─ pred:1 | └─(neg)─ pred:0 └─(neg)─ pred:0
绘制模型图¶
模型的树结构可以在 model.describe()
的“结构”标签页中看到。您也可以使用 print_tree
方法打印树。
In [9]
已复制!
model.print_tree()
model.print_tree()
'x1' >= 0.5 [score=0.11111 missing=True] ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True] │ ├─(pos)─ value=1 sd=0 │ └─(neg)─ value=0 sd=0 └─(neg)─ value=0 sd=0
访问树结构¶
get_tree
和 get_all_trees
方法允许以编程方式访问树的结构。
注意: CART 模型只有一个树,因此 tree_idx
参数设置为 0
。对于具有多个树的模型,可以使用 model.num_trees()
获取树的数量。
In [11]
已复制!
tree = model.get_tree(tree_idx=0)
tree
tree = model.get_tree(tree_idx=0) tree
Out[11]
Tree(root=NonLeaf(value=RegressionValue(num_examples=6.0, value=0.3333333432674408, standard_deviation=0.4714045207910317), condition=NumericalHigherThanCondition(missing=True, score=0.1111111119389534, attribute=1, threshold=0.5), pos_child=NonLeaf(value=RegressionValue(num_examples=3.0, value=0.6666666865348816, standard_deviation=0.4714045207910317), condition=NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5), pos_child=Leaf(value=RegressionValue(num_examples=2.0, value=1.0, standard_deviation=0.0)), neg_child=Leaf(value=RegressionValue(num_examples=1.0, value=0.0, standard_deviation=0.0))), neg_child=Leaf(value=RegressionValue(num_examples=3.0, value=0.0, standard_deviation=0.0))))
您能识别出上面打印的树的结构吗?您可以访问树的部分内容。例如,您可以访问关于 x2
的条件
In [12]
已复制!
tree.root.pos_child.condition
tree.root.pos_child.condition
Out[12]
NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5)
为了以更易读的形式显示树,可以使用 pretty
函数。
In [14]
已复制!
print(tree.pretty(model.data_spec()))
print(tree.pretty(model.data_spec()))
'x1' >= 0.5 [score=0.11111 missing=True] ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True] │ ├─(pos)─ value=1 sd=0 │ └─(neg)─ value=0 sd=0 └─(neg)─ value=0 sd=0