
CLI Quick Start

This page shows how to use the CLI API to train, evaluate, and analyze a binary classification model, generate predictions with it, and benchmark its inference speed.

An end-to-end example is available here.

Install the YDF CLI

1. Go to the YDF GitHub release page.

2. Download the latest CLI release for your operating system. For example, to download the Linux CLI release, click the "Download" button next to the "cli_linux.zip" file.

3. Extract the ZIP file into a directory of your choice, e.g. unzip cli_linux.zip.

4. Open a terminal window and navigate to the directory where you extracted the ZIP file.

Each executable (e.g. train, evaluate) performs a different task. For example, the train command trains a model.

Each command is documented on the commands page, and explained by the --help flag:

# Print the help of the 'train' command.
./train --help

Download a dataset

In this example, we use the UCI Adult dataset. It is a binary classification dataset where the goal is to predict whether a person's income exceeds $50,000. Its features are a mix of numerical and categorical columns.

First, we download a copy of the dataset from the UCI Machine Learning Repository:

DATASET_SRC=https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset
wget -q ${DATASET_SRC}/adult_train.csv -O adult_train.csv
wget -q ${DATASET_SRC}/adult_test.csv -O adult_test.csv

The first 3 examples of the training dataset are:

$ head -n 4 adult_train.csv

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
44,Private,228057,7th-8th,4,Married-civ-spouse,Machine-op-inspct,Wife,White,Female,0,0,40,Dominican-Republic,<=50K
20,Private,299047,Some-college,10,Never-married,Other-service,Not-in-family,White,Female,0,0,20,United-States,<=50K
40,Private,342164,HS-grad,9,Separated,Adm-clerical,Unmarried,White,Female,0,0,37,United-States,<=50K

The dataset is stored in two CSV files, one for training and one for testing. YDF can load CSV files directly, which makes them a convenient way to work with this dataset.

When passing a dataset path to a command, the format of the dataset is always specified with a prefix. For example, the csv: prefix in the path csv:/path/to/my/file indicates that the file is a CSV file. The list of supported dataset formats is available here.

Create a dataspec

A dataspec (short for dataset specification) is a description of a dataset. It contains the list of available columns, the semantic (or type) of each column, and extra metadata such as dictionaries or the fraction of missing values.

A dataspec can be computed automatically and stored in a dataspec file with the infer_dataspec command.

# Create the dataspec
./infer_dataspec --dataset=csv:adult_train.csv --output=dataspec.pbtxt

Looking at the dataspec before training a model is a great way to detect issues in a dataset, such as missing values or incorrect data types.

# Display the dataspec
./show_dataspec --dataspec=dataspec.pbtxt

The result is:

Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    3: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    14: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
    5: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    13: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)
    6: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:1 (0.00464425%) most-frequent:"Prof-specialty" 2870 (13.329%)
    8: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    7: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    9: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    1: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:1 (0.0046436%) most-frequent:"Private" 15879 (73.7358%)

NUMERICAL: 6 (40%)
    0: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    10: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    11: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    4: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    2: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    12: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

This example dataset contains 22,792 examples and 15 columns: 9 categorical columns and 6 numerical columns. The semantic of a column refers to the type of data it contains.

For example, the education column is a categorical column with 17 unique possible values. Its most frequent value is HS-grad (32% of all values).
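As a quick sanity check (not a YDF command), the 32% figure can be recomputed with awk from the counts shown in the dataspec:

```shell
# "HS-grad" appears 7340 times out of 22792 records (from the dataspec above).
awk 'BEGIN { printf "%.1f%%\n", 100 * 7340 / 22792 }'
```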

(Optional) Create a dataspec with a guide

In this example, the column semantics were detected correctly. However, this is not always the case when the value representation is ambiguous. For example, the semantic of enum values (i.e. categorical values represented as integers) cannot be detected automatically in a .csv file.

In that case, we can re-run the infer_dataspec command with an extra flag that specifies the true semantic of the mis-detected columns. For example, to force age to be detected as a numerical column, we would run:

# Force the detection of 'age' as numerical.
cat <<EOF > guide.pbtxt
column_guides {
    column_name_pattern: "^age$"
    type: NUMERICAL
}
EOF

./infer_dataspec --dataset=csv:adult_train.csv --guide=guide.pbtxt --output=dataspec.pbtxt
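A guide can also force the opposite correction, which covers the enum case described above: an integer-encoded categorical column would otherwise be inferred as NUMERICAL. This is a hypothetical sketch: my_enum_column is a placeholder name, not a column of the Adult dataset.

```shell
# Force a hypothetical integer-encoded enum column to be categorical.
# "my_enum_column" is a placeholder; replace it with your column's name.
cat <<EOF > guide_enum.pbtxt
column_guides {
    column_name_pattern: "^my_enum_column$"
    type: CATEGORICAL
}
EOF
```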

Train a model

A model is trained with the train command. The label, features, hyper-parameters, and other training settings are specified in a training configuration file.

# Create a training configuration file
cat <<EOF > train_config.pbtxt
task: CLASSIFICATION
label: "income"
learner: "GRADIENT_BOOSTED_TREES"
# Change learner-specific hyper-parameters.
[yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
  num_trees: 500
}
EOF

# Train the model
./train \
  --dataset=csv:adult_train.csv \
  --dataspec=dataspec.pbtxt \
  --config=train_config.pbtxt \
  --output=model

Result:

[INFO train.cc:96] Start training model.
[INFO abstract_learner.cc:119] No input feature specified. Using all the available input features as input signal.
[INFO abstract_learner.cc:133] The label "income" was removed from the input feature set.
[INFO vertical_dataset_io.cc:74] 100 examples scanned.
[INFO vertical_dataset_io.cc:80] 22792 examples read. Memory: usage:1MB allocated:1MB. 0 (0%) examples have been skipped.
[INFO abstract_learner.cc:119] No input feature specified. Using all the available input features as input signal.
[INFO abstract_learner.cc:133] The label "income" was removed from the input feature set.
[INFO gradient_boosted_trees.cc:405] Default loss set to BINOMIAL_LOG_LIKELIHOOD
[INFO gradient_boosted_trees.cc:1008] Training gradient boosted tree on 22792 example(s) and 14 feature(s).
[INFO gradient_boosted_trees.cc:1051] 20533 examples used for training and 2259 examples used for validation
[INFO gradient_boosted_trees.cc:1434]   num-trees:1 train-loss:1.015975 train-accuracy:0.761895 valid-loss:1.071430 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:1436]   num-trees:2 train-loss:0.955303 train-accuracy:0.761895 valid-loss:1.007908 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:2871] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.579583
[INFO gradient_boosted_trees.cc:230] Truncates the model to 136 tree(s) i.e. 136  iteration(s).
[INFO gradient_boosted_trees.cc:264] Final model num-trees:136 valid-loss:0.579583 valid-accuracy:0.870297

A few remarks:

  • Since no input features were specified, all columns except the label are used as input features.

  • Decision forests natively support numerical, categorical, and categorical-set features, as well as missing values. Numerical features do not need to be normalized, and categorical string values do not need to be encoded into dictionaries.

  • Except for the num_trees hyper-parameter, no training hyper-parameters were specified. The default values of all hyper-parameters are set so that they give reasonable results in most situations. Alternative sets of default values (called hyper-parameter templates) and automated hyper-parameter tuning are discussed later. The list of all hyper-parameters and their default values is available on the hyper-parameters page.

  • No validation dataset was provided for training. Not all learners require one. However, the GRADIENT_BOOSTED_TREES learner used in this example requires a validation dataset when early stopping is enabled (the default). In that case, 10% of the training dataset is used for validation. This ratio can be changed with the validation_ratio parameter; alternatively, a validation dataset can be provided with the --valid_dataset flag. The final model contains 136 trees and has a validation accuracy of about 0.8703.

Display model information

Detailed information about a model is displayed with the show_model command.

# Show information about the model.
./show_model --model=model

Example result:

Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "income"

Input Features (14):
    age
    workclass
    fnlwgt
    education
    education_num
    marital_status
    occupation
    relationship
    race
    sex
    capital_gain
    capital_loss
    hours_per_week
    native_country

No weights

Variable Importance: MEAN_MIN_DEPTH:
    1.         "income"  4.868164 ################
    2.            "sex"  4.625136 #############
    3.           "race"  4.590606 #############
    ...
   13.     "occupation"  3.640103 ####
   14. "marital_status"  3.626898 ###
   15.            "age"  3.219872

Variable Importance: NUM_AS_ROOT:
    1.            "age" 28.000000 ################
    2. "marital_status" 22.000000 ############
    3.   "capital_gain" 19.000000 ##########
    ...
   11.  "education_num"  3.000000
   12.     "occupation"  2.000000
   13. "native_country"  2.000000

Variable Importance: NUM_NODES:
    1.     "occupation" 516.000000 ################
    2.            "age" 431.000000 #############
    3.      "education" 424.000000 ############
    ...
   12.  "education_num" 73.000000 #
   13.            "sex" 39.000000
   14.           "race" 26.000000

Variable Importance: SUM_SCORE:
    1.   "relationship" 3103.387636 ################
    2.   "capital_gain" 2041.557944 ##########
    3.      "education" 1090.544247 #####
    ...
   12.      "workclass" 176.876787
   13.            "sex" 49.287215
   14.           "race" 13.923084



Loss: BINOMIAL_LOG_LIKELIHOOD
Validation loss value: 0.579583
Number of trees per iteration: 1
Node format: BLOB_SEQUENCE
Number of trees: 136
Total number of nodes: 7384

Number of nodes by tree:
Count: 136 Average: 54.2941 StdDev: 5.7779
Min: 33 Max: 63 Ignored: 0
----------------------------------------------
[ 33, 34)  2   1.47%   1.47% #
...
[ 60, 62) 16  11.76%  96.32% ########
[ 62, 63]  5   3.68% 100.00% ##

Depth by leafs:
Count: 3760 Average: 4.87739 StdDev: 0.412078
Min: 2 Max: 5 Ignored: 0
----------------------------------------------
[ 2, 3)   14   0.37%   0.37%
[ 3, 4)   75   1.99%   2.37%
[ 4, 5)  269   7.15%   9.52% #
[ 5, 5] 3402  90.48% 100.00% ##########

Number of training obs by leaf:
Count: 3760 Average: 742.683 StdDev: 2419.64
Min: 5 Max: 19713 Ignored: 0
----------------------------------------------
[     5,   990) 3270  86.97%  86.97% ##########
[   990,  1975)  163   4.34%  91.30%
...
[ 17743, 18728)   10   0.27%  99.55%
[ 18728, 19713]   17   0.45% 100.00%

Attribute in nodes:
    516 : occupation [CATEGORICAL]
    431 : age [NUMERICAL]
    424 : education [CATEGORICAL]
    420 : fnlwgt [NUMERICAL]
    297 : capital_gain [NUMERICAL]
    291 : hours_per_week [NUMERICAL]
    266 : capital_loss [NUMERICAL]
    245 : native_country [CATEGORICAL]
    224 : relationship [CATEGORICAL]
    206 : workclass [CATEGORICAL]
    166 : marital_status [CATEGORICAL]
    73 : education_num [NUMERICAL]
    39 : sex [CATEGORICAL]
    26 : race [CATEGORICAL]

Attribute in nodes with depth <= 0:
    28 : age [NUMERICAL]
    22 : marital_status [CATEGORICAL]
    19 : capital_gain [NUMERICAL]
    12 : capital_loss [NUMERICAL]
    11 : hours_per_week [NUMERICAL]
    11 : fnlwgt [NUMERICAL]
    8 : relationship [CATEGORICAL]
    8 : education [CATEGORICAL]
    6 : race [CATEGORICAL]
    4 : sex [CATEGORICAL]
    3 : education_num [NUMERICAL]
    2 : native_country [CATEGORICAL]
    2 : occupation [CATEGORICAL]

...

Condition type in nodes:
    1844 : ContainsBitmapCondition
    1778 : HigherCondition
    2 : ContainsCondition
Condition type in nodes with depth <= 0:
    84 : HigherCondition
    52 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
    243 : HigherCondition
    165 : ContainsBitmapCondition
...

The tree structure of the model can be printed with the --full_definition flag.

Evaluate the model

An evaluation is computed with the evaluate command and printed either as text (--format=text, the default) or as HTML with plots (--format=html).

# Evaluate the model and print the result in the console.
./evaluate --dataset=csv:adult_test.csv --model=model

Result:

Evaluation:
Number of predictions (without weights): 9769
Number of predictions (with weights): 9769
Task: CLASSIFICATION
Label: income

Accuracy: 0.874399  CI95[W][0.86875 0.879882]
LogLoss: 0.27768
ErrorRate: 0.125601

Default Accuracy: 0.758727
Default LogLoss: 0.552543
Default ErrorRate: 0.241273

Confusion Table:
truth\prediction
       <OOD>  <=50K  >50K
<OOD>      0      0     0
<=50K      0   6971   441
 >50K      0    786  1571
Total: 9769

One vs other classes:
  "<=50K" vs. the others
    auc: 0.929207  CI95[H][0.924358 0.934056] CI95[B][0.924076 0.934662]
    p/r-auc: 0.975657  CI95[L][0.971891 0.97893] CI95[B][0.973397 0.977947]
    ap: 0.975656   CI95[B][0.973393 0.977944]

  ">50K" vs. the others
    auc: 0.929207  CI95[H][0.921866 0.936549] CI95[B][0.923642 0.934566]
    p/r-auc: 0.830708  CI95[L][0.815025 0.845313] CI95[B][0.817588 0.843956]
    ap: 0.830674   CI95[B][0.817513 0.843892]

Observations:

  • The test dataset contains 9,769 examples.
  • The test accuracy is 0.874399, with 95% confidence interval bounds of [0.86875; 0.879882].
  • The test AUC is 0.929207, with 95% confidence interval bounds of [0.924358; 0.934056] computed in closed form, and [0.924076; 0.934662] computed with bootstrapping.
  • The PR-AUC and AP metrics are also available.
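The accuracy can be recomputed by hand from the confusion table above: the correct predictions are the ones on the diagonal. A quick check with awk:

```shell
# Diagonal of the confusion table (6971 + 1571) over the total (9769).
awk 'BEGIN { printf "%.6f\n", (6971 + 1571) / 9769 }'
```

This matches the reported accuracy of 0.874399.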

The following command evaluates the model and exports the evaluation report to an HTML file.

# Evaluate the model and print the result in an Html file.
./evaluate --dataset=csv:adult_test.csv --model=model --format=html > evaluation.html

Evaluation plot on the adult dataset

Generate predictions

Predictions are computed and exported to a file with the predict command.

# Exports the prediction of the model to a csv file
./predict --dataset=csv:adult_test.csv --model=model --output=csv:predictions.csv

# Show the predictions for the first 3 examples
head -n 4 predictions.csv

Result:

<=50K,>50K
0.978384,0.0216162
0.641894,0.358106
0.180569,0.819431
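The output contains one probability per class for each example, with the header giving the class order. To turn these probabilities into class labels, one can threshold the second column at 0.5. A minimal sketch, run here on a copy of the three predictions shown above:

```shell
# Recreate the first three predictions shown above (for illustration only).
cat <<EOF > predictions_sample.csv
<=50K,>50K
0.978384,0.0216162
0.641894,0.358106
0.180569,0.819431
EOF

# Predict ">50K" when its probability (second column) exceeds 0.5.
tail -n +2 predictions_sample.csv |
  awk -F, '{ if ($2 > 0.5) print ">50K"; else print "<=50K" }'
```

For these three examples, this prints <=50K, <=50K, and >50K.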

Benchmark the model's speed

In time-critical applications, the inference speed of a model can be crucial. The benchmark_inference command measures the average inference time of a model.

YDF has several algorithms to compute the predictions of a model. These algorithms differ in speed and coverage. When generating predictions, YDF automatically uses the fastest compatible algorithm.

benchmark_inference shows the speed of all the compatible algorithms.

The inference algorithms are single-threaded, i.e. they run on a single CPU core. Users can parallelize inference themselves with multi-threading.

# Benchmark the inference speed of the model
./benchmark_inference --dataset=csv:adult_test.csv --model=model

Result:

batch_size : 100  num_runs : 20
time/example(us)  time/batch(us)  method
----------------------------------------
            0.89              89  GradientBoostedTreesQuickScorerExtended [virtual interface]
          5.8475          584.75  GradientBoostedTreesGeneric [virtual interface]
          12.485          1248.5  Generic slow engine
----------------------------------------

We see that, on average, the model runs in 0.89 µs (microseconds) per example.
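Put differently, the latency of the fastest engine translates into a single-threaded throughput of roughly one million examples per second:

```shell
# Throughput of the fastest engine: 1 second divided by 0.89 µs per example.
awk 'BEGIN { printf "%.0f examples/s\n", 1e6 / 0.89 }'
```

This prints 1123596 examples/s, i.e. about 1.1 million predictions per second on a single core.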