编辑树¶
设置¶
In [ ]
已复制!
pip install ydf -U
pip install ydf -U
In [1]
已复制!
import ydf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import ydf import numpy as np import matplotlib.pyplot as plt import pandas as pd
编辑树是什么意思?¶
在检查树 notebook 中,您学习了如何访问树结构。在本 notebook 中,我们将展示如何修改已训练模型的树结构以及如何手动创建一个新树。
数据集¶
为了使示例更容易理解,我们使用一个遵循二维网格的合成数据集。
In [2]
已复制!
def build_grid_dataset(resolution = 20):
"""Creates a 2d grid."""
vs = np.linspace(0, 1, resolution)
xs, ys = np.meshgrid(vs, vs)
return pd.DataFrame({
"x": xs.flatten(),
"y": ys.flatten()
})
dataset = build_grid_dataset()
dataset
def build_grid_dataset(resolution = 20): """创建一个二维网格。""" vs = np.linspace(0, 1, resolution) xs, ys = np.meshgrid(vs, vs) return pd.DataFrame({ "x": xs.flatten(), "y": ys.flatten() }) dataset = build_grid_dataset() dataset
Out[2]
x | y | |
---|---|---|
0 | 0.000000 | 0.0 |
1 | 0.052632 | 0.0 |
2 | 0.105263 | 0.0 |
3 | 0.157895 | 0.0 |
4 | 0.210526 | 0.0 |
... | ... | ... |
395 | 0.789474 | 1.0 |
396 | 0.842105 | 1.0 |
397 | 0.894737 | 1.0 |
398 | 0.947368 | 1.0 |
399 | 1.000000 | 1.0 |
400 行 × 2 列
plot_predictions
方法在由 build_grid_dataset
定义的二维网格上绘制预测结果。
让我们定义并绘制一个合成标签 x>=0.5 and y>=0.5
。
In [7]
已复制!
def plot_predictions(values, resolution = 20):
plt.imshow(np.reshape(values,[resolution,resolution]), interpolation="none")
dataset["label"] = (dataset.x >= 0.5) & (dataset.y >= 0.5)
plot_predictions(dataset["label"])
def plot_predictions(values, resolution = 20): plt.imshow(np.reshape(values,[resolution,resolution]), interpolation="none") dataset["label"] = (dataset.x >= 0.5) & (dataset.y >= 0.5) plot_predictions(dataset["label"])
编辑现有模型¶
我们训练一个单树模型,然后编辑其结构并检查其预测结果是否如预期般改变。
In [32]
已复制!
model = ydf.CartLearner(label="label", task=ydf.Task.REGRESSION).train(dataset)
model = ydf.CartLearner(label="label", task=ydf.Task.REGRESSION).train(dataset)
Train model on 400 examples Model trained in 0:00:00.002200
我们绘制模型树结构。我们期望看到一棵等效于 x>=0.5 and y>=0.5
的树。
In [33]
已复制!
model.print_tree(0)
model.print_tree(0)
'x' >= 0.5 [score=0.066036 missing=True] ├─(pos)─ 'y' >= 0.5 [score=0.2498 missing=True] │ ├─(pos)─ value=1 sd=0 │ └─(neg)─ value=0 sd=0 └─(neg)─ value=0 sd=0
看起来不错。
模型的预测结果与标签相似。
In [36]
已复制!
plot_predictions(model.predict(dataset))
plot_predictions(model.predict(dataset))
让我们更改第一个条件的阈值。我们将 x >= 0.5 替换为 x >= 0.8。
In [40]
已复制!
# Extract the tree
tree = model.get_tree(0)
tree
# 提取树 tree = model.get_tree(0) tree
Out[40]
Tree(root=NonLeaf(value=RegressionValue(num_examples=361.0, value=0.2548476457595825, standard_deviation=0.4357755420494278), condition=NumericalHigherThanCondition(missing=True, score=0.06603582203388214, attribute=1, threshold=0.5), pos_child=NonLeaf(value=RegressionValue(num_examples=179.0, value=0.5139665007591248, standard_deviation=0.49980489935966577), condition=NumericalHigherThanCondition(missing=True, score=0.24980494379997253, attribute=2, threshold=0.5), pos_child=Leaf(value=RegressionValue(num_examples=92.0, value=1.0, standard_deviation=0.0)), neg_child=Leaf(value=RegressionValue(num_examples=87.0, value=0.0, standard_deviation=0.0))), neg_child=Leaf(value=RegressionValue(num_examples=182.0, value=0.0, standard_deviation=0.0))))
In [41]
已复制!
# Change the tree
tree.root.condition.threshold = 0.8
# 更改树 tree.root.condition.threshold = 0.8
In [42]
已复制!
# Update the model
model.set_tree(0, tree)
# 更新模型 model.set_tree(0, tree)
让我们检查模型是否已更改
In [43]
已复制!
model.print_tree(0)
model.print_tree(0)
'x' >= 0.8 [score=0.066036 missing=True] ├─(pos)─ 'y' >= 0.5 [score=0.2498 missing=True] │ ├─(pos)─ value=1 sd=0 │ └─(neg)─ value=0 sd=0 └─(neg)─ value=0 sd=0
最后,让我们检查新的预测结果
In [44]
已复制!
plot_predictions(model.predict(dataset))
plot_predictions(model.predict(dataset))
创建模型¶
YDF 也支持创建新树。为此,我们将定义一个包含零棵树的模型,然后手动添加一棵树。
In [58]
已复制!
model = ydf.RandomForestLearner(label="label", num_trees=0, task=ydf.Task.REGRESSION).train(dataset)
model = ydf.RandomForestLearner(label="label", num_trees=0, task=ydf.Task.REGRESSION).train(dataset)
Train model on 400 examples Model trained in 0:00:00.003442
正如预期,模型不包含任何树。
In [59]
已复制!
model.num_trees()
model.num_trees()
Out[59]
0
让我们添加一棵新树。
在树中,输入特征由称为“列索引”的整数索引。model.input_features()
列出可用的输入特征及其对应的列索引。
In [60]
已复制!
model.input_features()
model.input_features()
Out[60]
[InputFeature(name='x', semantic=<Semantic.NUMERICAL: 1>, column_idx=1), InputFeature(name='y', semantic=<Semantic.NUMERICAL: 1>, column_idx=2)]
特征 "x" 的列索引是 1。我们定义了一棵具有单个条件测试 x >= 0.6
的树。
In [61]
已复制!
tree = ydf.tree.Tree(
root=ydf.tree.NonLeaf(
condition=ydf.tree.NumericalHigherThanCondition(attribute=1, # Feature "x"
threshold=0.6,
missing=False, # Value of the condition when the feature is missing. Not used here.
score=1, # How good is the condition. Not used here.
),
pos_child=ydf.tree.Leaf(value=ydf.tree.RegressionValue(num_examples=1.0, value=1.0)),
neg_child=ydf.tree.Leaf(value=ydf.tree.RegressionValue(num_examples=1.0, value=0.0)),
)
)
tree = ydf.tree.Tree( root=ydf.tree.NonLeaf( condition=ydf.tree.NumericalHigherThanCondition(attribute=1, # 特征 "x" threshold=0.6, missing=False, # 特征缺失时条件的值。此处未使用。 score=1, # 条件有多好。此处未使用。 ), pos_child=ydf.tree.Leaf(value=ydf.tree.RegressionValue(num_examples=1.0, value=1.0)), neg_child=ydf.tree.Leaf(value=ydf.tree.RegressionValue(num_examples=1.0, value=0.0)), ) )
我们绘制新创建的树。
注意:模型的 dataspec(可通过 model.data_spec()
获取)定义了模型期望的列(例如输入特征、标签、权重)。它还定义了列的名称及其字典(对于文本类别特征)。
In [62]
已复制!
print(tree.pretty(model.data_spec()))
print(tree.pretty(model.data_spec()))
'x' >= 0.6 [score=1 missing=False] ├─(pos)─ value=1 └─(neg)─ value=0
树被添加到模型中
In [63]
已复制!
model.add_tree(tree)
model.add_tree(tree)
最后,我们绘制模型预测结果。
In [64]
已复制!
plot_predictions(model.predict(dataset))
plot_predictions(model.predict(dataset))