CLI Quickstart¶
This page shows how to train, evaluate, analyze, generate predictions with, and benchmark the inference speed of a binary classification model using the CLI API.
An end-to-end example is available here.
Install the YDF CLI¶
1. Go to the YDF GitHub releases page.
2. Download the latest CLI release for your operating system. For example, to download the Linux CLI release, click the "download" button next to the "cli_linux.zip" file.
3. Extract the ZIP file to a directory of your choice, e.g. unzip cli_linux.zip.
4. Open a terminal window and navigate to the directory where you extracted the ZIP file.
Each executable (e.g. train, evaluate) performs a different task. For example, the train command trains a model.
Each command is documented on the commands page and with the --help flag.
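For example, the following prints the available flags of the train command directly in the terminal (a minimal illustration; the exact help output depends on the release you downloaded):
# Print the documentation and flags of the "train" command.
./train --help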
Download the dataset¶
In this example, we use the UCI Adult dataset. It is a binary classification dataset where the objective is to predict whether a person's income exceeds $50,000. Its features are a mix of numerical and categorical columns.
First, we download a copy of the dataset from the UCI Machine Learning Repository:
DATASET_SRC=https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset
wget -q ${DATASET_SRC}/adult_train.csv -O adult_train.csv
wget -q ${DATASET_SRC}/adult_test.csv -O adult_test.csv
The first 3 examples of the training dataset are:
$ head -n 4 adult_train.csv
age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
44,Private,228057,7th-8th,4,Married-civ-spouse,Machine-op-inspct,Wife,White,Female,0,0,40,Dominican-Republic,<=50K
20,Private,299047,Some-college,10,Never-married,Other-service,Not-in-family,White,Female,0,0,20,United-States,<=50K
40,Private,342164,HS-grad,9,Separated,Adm-clerical,Unmarried,White,Female,0,0,37,United-States,<=50K
The dataset is stored in two CSV files, one for training and one for testing. YDF can read CSV files directly, which makes this a convenient format for the dataset.
When a dataset path is passed to a command, the dataset format is always specified with a prefix. For example, the csv: prefix in the path csv:/path/to/my/file indicates that the file is a CSV file. The list of supported dataset formats is available here.
Create a dataspec¶
A dataspec (short for dataset specification) is a description of a dataset. It contains the list of available columns, the semantic (or type) of each column, and extra metadata such as dictionaries or the ratio of missing values.
A dataspec is computed automatically with the infer_dataspec command and stored in a dataspec file.
Inspecting the dataspec before training a model is a great way to detect issues in the dataset, such as missing values or incorrect data types.
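The commands below are a sketch of this step based on the flags used elsewhere in this guide: infer_dataspec computes the dataspec from the training CSV, and show_dataspec (assumed here to be the companion tool shipped in the same CLI release) prints it.
# Infer the dataspec of the training dataset and store it in dataspec.pbtxt.
./infer_dataspec --dataset=csv:adult_train.csv --output=dataspec.pbtxt
# Print the inferred dataspec (assumed companion tool in the CLI release).
./show_dataspec --dataspec=dataspec.pbtxt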
The result is:
Number of records: 22792
Number of columns: 15
Number of columns by type:
CATEGORICAL: 9 (60%)
NUMERICAL: 6 (40%)
Columns:
CATEGORICAL: 9 (60%)
3: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
14: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
5: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
13: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)
6: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:1 (0.00464425%) most-frequent:"Prof-specialty" 2870 (13.329%)
8: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
7: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
9: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
1: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:1 (0.0046436%) most-frequent:"Private" 15879 (73.7358%)
NUMERICAL: 6 (40%)
0: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
10: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
11: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
4: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
2: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
12: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249
Terminology:
nas: Number of non-available (i.e. missing) values.
ood: Out of dictionary.
manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
tokenized: The attribute value is obtained through tokenization.
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
vocab-size: Number of unique values.
This example dataset contains 22,792 examples and 15 columns: 9 categorical columns and 6 numerical columns. The semantic of a column refers to the type of data it contains.
For example, the education column is a categorical column with 17 unique possible values. Its most frequent value is HS-grad (32% of all values).
(Optional) Create a dataspec with a guide¶
In this example, the semantics of the columns were detected correctly. However, this is not always the case when the value representation is ambiguous. For example, the semantic of enum values (i.e. categorical values represented as integers) cannot be detected automatically in a .csv file.
In that case, we can re-run the infer_dataspec command with an extra flag that specifies the true semantic of the misdetected columns. For example, to force the detection of age as a numerical column, we would run:
# Force the detection of 'age' as numerical.
cat <<EOF > guide.pbtxt
column_guides {
  column_name_pattern: "^age$"
  type: NUMERICAL
}
EOF
./infer_dataspec --dataset=csv:adult_train.csv --guide=guide.pbtxt --output=dataspec.pbtxt
Train a model¶
A model is trained with the train command. The label, features, hyperparameters and other training settings are specified in a training configuration file.
# Create a training configuration file
cat <<EOF > train_config.pbtxt
task: CLASSIFICATION
label: "income"
learner: "GRADIENT_BOOSTED_TREES"
# Change learner-specific hyper-parameters.
[yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
  num_trees: 500
}
EOF
# Train the model
./train \
  --dataset=csv:adult_train.csv \
  --dataspec=dataspec.pbtxt \
  --config=train_config.pbtxt \
  --output=model
Result:
[INFO train.cc:96] Start training model.
[INFO abstract_learner.cc:119] No input feature specified. Using all the available input features as input signal.
[INFO abstract_learner.cc:133] The label "income" was removed from the input feature set.
[INFO vertical_dataset_io.cc:74] 100 examples scanned.
[INFO vertical_dataset_io.cc:80] 22792 examples read. Memory: usage:1MB allocated:1MB. 0 (0%) examples have been skipped.
[INFO abstract_learner.cc:119] No input feature specified. Using all the available input features as input signal.
[INFO abstract_learner.cc:133] The label "income" was removed from the input feature set.
[INFO gradient_boosted_trees.cc:405] Default loss set to BINOMIAL_LOG_LIKELIHOOD
[INFO gradient_boosted_trees.cc:1008] Training gradient boosted tree on 22792 example(s) and 14 feature(s).
[INFO gradient_boosted_trees.cc:1051] 20533 examples used for training and 2259 examples used for validation
[INFO gradient_boosted_trees.cc:1434] num-trees:1 train-loss:1.015975 train-accuracy:0.761895 valid-loss:1.071430 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:1436] num-trees:2 train-loss:0.955303 train-accuracy:0.761895 valid-loss:1.007908 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:2871] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.579583
[INFO gradient_boosted_trees.cc:230] Truncates the model to 136 tree(s) i.e. 136 iteration(s).
[INFO gradient_boosted_trees.cc:264] Final model num-trees:136 valid-loss:0.579583 valid-accuracy:0.870297
A few remarks:
- Since no input features were specified, all columns except the label are used as input features.
- Decision forests natively consume numerical, categorical and categorical-set features, as well as missing values. Numerical features do not need to be normalized, and categorical string values do not need to be encoded in a dictionary.
- No training hyperparameter was specified other than num_trees. The default values of all hyperparameters are set so that they give reasonable results in most situations. Alternative default values (called hyperparameter templates) and automated hyperparameter tuning are discussed later. The list of all hyperparameters and their default values is available on the hyperparameters page.
- No validation dataset was provided for training. Not all learners require one. However, the GRADIENT_BOOSTED_TREES learner used in this example requires a validation dataset when early stopping is enabled (the default). In that case, 10% of the training dataset is used for validation; this ratio can be changed with the validation_ratio parameter, or a validation dataset can be provided with the --valid_dataset flag (see the sketch below). The final model contains 136 trees and has a validation accuracy of approximately 0.870.
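As a sketch of the last remark, an explicit validation dataset could be passed instead of relying on the automatic 10% split. The file adult_valid.csv is a hypothetical placeholder used only for illustration:
# Hypothetical: train with an explicit validation dataset instead of the
# automatic 10% hold-out (adult_valid.csv is a placeholder file name).
./train \
  --dataset=csv:adult_train.csv \
  --valid_dataset=csv:adult_valid.csv \
  --dataspec=dataspec.pbtxt \
  --config=train_config.pbtxt \
  --output=model_with_valid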
Show model information¶
The details of a model are displayed with the show_model command.
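A likely invocation, assuming show_model accepts the same --model flag as the evaluate and predict commands used later in this guide:
# Display the structure and metadata of the trained model.
./show_model --model=model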
Example of result:
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "income"
Input Features (14):
age
workclass
fnlwgt
education
education_num
marital_status
occupation
relationship
race
sex
capital_gain
capital_loss
hours_per_week
native_country
No weights
Variable Importance: MEAN_MIN_DEPTH:
1. "income" 4.868164 ################
2. "sex" 4.625136 #############
3. "race" 4.590606 #############
...
13. "occupation" 3.640103 ####
14. "marital_status" 3.626898 ###
15. "age" 3.219872
Variable Importance: NUM_AS_ROOT:
1. "age" 28.000000 ################
2. "marital_status" 22.000000 ############
3. "capital_gain" 19.000000 ##########
...
11. "education_num" 3.000000
12. "occupation" 2.000000
13. "native_country" 2.000000
Variable Importance: NUM_NODES:
1. "occupation" 516.000000 ################
2. "age" 431.000000 #############
3. "education" 424.000000 ############
...
12. "education_num" 73.000000 #
13. "sex" 39.000000
14. "race" 26.000000
Variable Importance: SUM_SCORE:
1. "relationship" 3103.387636 ################
2. "capital_gain" 2041.557944 ##########
3. "education" 1090.544247 #####
...
12. "workclass" 176.876787
13. "sex" 49.287215
14. "race" 13.923084
Loss: BINOMIAL_LOG_LIKELIHOOD
Validation loss value: 0.579583
Number of trees per iteration: 1
Node format: BLOB_SEQUENCE
Number of trees: 136
Total number of nodes: 7384
Number of nodes by tree:
Count: 136 Average: 54.2941 StdDev: 5.7779
Min: 33 Max: 63 Ignored: 0
----------------------------------------------
[ 33, 34) 2 1.47% 1.47% #
...
[ 60, 62) 16 11.76% 96.32% ########
[ 62, 63] 5 3.68% 100.00% ##
Depth by leafs:
Count: 3760 Average: 4.87739 StdDev: 0.412078
Min: 2 Max: 5 Ignored: 0
----------------------------------------------
[ 2, 3) 14 0.37% 0.37%
[ 3, 4) 75 1.99% 2.37%
[ 4, 5) 269 7.15% 9.52% #
[ 5, 5] 3402 90.48% 100.00% ##########
Number of training obs by leaf:
Count: 3760 Average: 742.683 StdDev: 2419.64
Min: 5 Max: 19713 Ignored: 0
----------------------------------------------
[ 5, 990) 3270 86.97% 86.97% ##########
[ 990, 1975) 163 4.34% 91.30%
...
[ 17743, 18728) 10 0.27% 99.55%
[ 18728, 19713] 17 0.45% 100.00%
Attribute in nodes:
516 : occupation [CATEGORICAL]
431 : age [NUMERICAL]
424 : education [CATEGORICAL]
420 : fnlwgt [NUMERICAL]
297 : capital_gain [NUMERICAL]
291 : hours_per_week [NUMERICAL]
266 : capital_loss [NUMERICAL]
245 : native_country [CATEGORICAL]
224 : relationship [CATEGORICAL]
206 : workclass [CATEGORICAL]
166 : marital_status [CATEGORICAL]
73 : education_num [NUMERICAL]
39 : sex [CATEGORICAL]
26 : race [CATEGORICAL]
Attribute in nodes with depth <= 0:
28 : age [NUMERICAL]
22 : marital_status [CATEGORICAL]
19 : capital_gain [NUMERICAL]
12 : capital_loss [NUMERICAL]
11 : hours_per_week [NUMERICAL]
11 : fnlwgt [NUMERICAL]
8 : relationship [CATEGORICAL]
8 : education [CATEGORICAL]
6 : race [CATEGORICAL]
4 : sex [CATEGORICAL]
3 : education_num [NUMERICAL]
2 : native_country [CATEGORICAL]
2 : occupation [CATEGORICAL]
...
Condition type in nodes:
1844 : ContainsBitmapCondition
1778 : HigherCondition
2 : ContainsCondition
Condition type in nodes with depth <= 0:
84 : HigherCondition
52 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
243 : HigherCondition
165 : ContainsBitmapCondition
...
The tree structure of the model can be printed with the --full_definition flag.
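For example, assuming show_model accepts the --model flag used by the evaluate and predict commands:
# Print the full tree structure in addition to the model summary.
./show_model --model=model --full_definition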
Evaluate a model¶
The evaluation of a model is computed with the evaluate command and printed either as text (--format=text, the default) or as HTML with plots (--format=html).
# Evaluate the model and print the result in the console.
./evaluate --dataset=csv:adult_test.csv --model=model
Result:
Evaluation:
Number of predictions (without weights): 9769
Number of predictions (with weights): 9769
Task: CLASSIFICATION
Label: income
Accuracy: 0.874399 CI95[W][0.86875 0.879882]
LogLoss: 0.27768
ErrorRate: 0.125601
Default Accuracy: 0.758727
Default LogLoss: 0.552543
Default ErrorRate: 0.241273
Confusion Table:
truth\prediction
<OOD> <=50K >50K
<OOD> 0 0 0
<=50K 0 6971 441
>50K 0 786 1571
Total: 9769
One vs other classes:
"<=50K" vs. the others
auc: 0.929207 CI95[H][0.924358 0.934056] CI95[B][0.924076 0.934662]
p/r-auc: 0.975657 CI95[L][0.971891 0.97893] CI95[B][0.973397 0.977947]
ap: 0.975656 CI95[B][0.973393 0.977944]
">50K" vs. the others
auc: 0.929207 CI95[H][0.921866 0.936549] CI95[B][0.923642 0.934566]
p/r-auc: 0.830708 CI95[L][0.815025 0.845313] CI95[B][0.817588 0.843956]
ap: 0.830674 CI95[B][0.817513 0.843892]
Observations:
- The test dataset contains 9769 examples.
- The test accuracy is 0.874399, with a 95% confidence interval of [0.86875; 0.879882].
- The test AUC is 0.929207, with a 95% confidence interval of [0.924358; 0.934056] computed with a closed-form formula and [0.924076; 0.934662] computed with bootstrapping.
- The PR-AUC and AP metrics are also available.
The following command evaluates the model and exports the evaluation report to an HTML file.
# Evaluate the model and print the result in an Html file.
./evaluate --dataset=csv:adult_test.csv --model=model --format=html > evaluation.html
Generate predictions¶
Predictions are computed and exported to a file with the predict command.
# Exports the prediction of the model to a csv file
./predict --dataset=csv:adult_test.csv --model=model --output=csv:predictions.csv
# Show the predictions for the first 3 examples
head -n 4 predictions.csv
The resulting predictions.csv file contains the model's predicted probability of each label class for every example.
Benchmark the model speed¶
In time-critical applications, the inference speed of a model can be crucial. The benchmark_inference command measures the average inference time of a model.
YDF has several algorithms to compute the predictions of a model. These algorithms differ in speed and coverage. When generating predictions, YDF automatically uses the fastest compatible algorithm. benchmark_inference reports the speed of all compatible algorithms.
The inference algorithms are single-threaded, i.e. they run on a single CPU core. Users can parallelize inference themselves using multi-threading.
# Benchmark the inference speed of the model
./benchmark_inference --dataset=csv:adult_test.csv --model=model
Result:
batch_size : 100 num_runs : 20
time/example(us) time/batch(us) method
----------------------------------------
0.89 89 GradientBoostedTreesQuickScorerExtended [virtual interface]
5.8475 584.75 GradientBoostedTreesGeneric [virtual interface]
12.485 1248.5 Generic slow engine
----------------------------------------
We see that, with the fastest algorithm, the model runs in 0.89 µs (microseconds) per example on average.
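The report header indicates a batch size of 100 and 20 runs. Assuming the tool exposes matching --batch_size and --num_runs flags (not shown in this guide), the benchmark settings could be adjusted as follows:
# Hypothetical flags: re-run the benchmark with larger batches and more runs
# to reduce measurement noise.
./benchmark_inference \
  --dataset=csv:adult_test.csv \
  --model=model \
  --batch_size=1000 \
  --num_runs=50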