Setup
# Install dependencies
!pip install ydf -U -q
!pip install tensorflow -U -q
!pip install optax pandas numpy -U -q
!pip install jax[cpu] -U
# OR
# !pip install jax[cuda12] -U -q
# See https://jax.net.cn/en/stable/installation.html for JAX variations.
import tempfile
import jax
from jax.experimental import jax2tf # To export JAX model to SavedModel
import optax # To finetune YDF+JAX models
import pandas as pd # We use Pandas to load small datasets
import tensorflow as tf # To create SavedModels
import ydf # Yggdrasil Decision Forests
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download and load the dataset as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
label = "income"
# Print the first 5 training examples
train_ds.head(5)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
First, we train a YDF model on the dataset.
learner = ydf.GradientBoostedTreesLearner(label=label)
model = learner.train(train_ds)
Train model on 22792 examples
Model trained in 0:00:02.277830
We convert the YDF model into a JAX function.
jax_model = model.to_jax_function()
jax_model
The `jax_model` object contains three fields:

- `predict`: A JAX function that computes the model predictions.
- `encoder`: A callable class that prepares examples for `predict`. Since JAX does not support string values, categorical string input features must be encoded before calling `predict`.
- `params`: An optional dictionary of JAX arrays defining the differentiable parameters of the model. By default, `params` is None and `predict` does not take any parameters. We show how to use `params` in the second section.
We generate predictions for the first 5 examples of the test set.

First, we select these examples and encode them.
# Select the first 5 examples from the Pandas Dataframe and remove the labels.
selected_examples = test_ds[:5].drop(model.label(), axis=1)
# Encode the examples into a dictionary of JAX arrays.
jax_selected_examples = jax_model.encoder(selected_examples)
jax_selected_examples
{'age': Array([39, 40, 40, 35, 23], dtype=int32), 'workclass': Array([4, 1, 1, 6, 3], dtype=int32), 'fnlwgt': Array([ 77516, 121772, 193524, 76845, 190709], dtype=int32), 'education': Array([ 3, 5, 13, 11, 7], dtype=int32), 'education_num': Array([13, 11, 16, 5, 12], dtype=int32), 'marital_status': Array([2, 1, 1, 1, 2], dtype=int32), 'occupation': Array([ 4, 3, 1, 10, 12], dtype=int32), 'relationship': Array([2, 1, 1, 1, 2], dtype=int32), 'race': Array([1, 3, 1, 2, 1], dtype=int32), 'sex': Array([1, 1, 1, 1, 1], dtype=int32), 'capital_gain': Array([2174, 0, 0, 0, 0], dtype=int32), 'capital_loss': Array([0, 0, 0, 0, 0], dtype=int32), 'hours_per_week': Array([40, 40, 60, 40, 52], dtype=int32), 'native_country': Array([1, 0, 1, 1, 1], dtype=int32)}
Then, we generate the predictions.
jax_predictions = jax_model.predict(jax_selected_examples)
jax_predictions
Array([0.01860434, 0.36130956, 0.83858865, 0.04385566, 0.02917648], dtype=float32)
For comparison, here are the predictions of the original YDF model:
model.predict(selected_examples)
array([0.01860435, 0.36130956, 0.83858865, 0.04385567, 0.02917649], dtype=float32)
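The two sets of predictions agree up to float32 rounding. As a quick sanity check (a minimal sketch of ours, not part of the original tutorial), we can assert that they are numerically close:

import numpy as np

# The JAX predictions should match the YDF predictions up to rounding.
np.testing.assert_allclose(
    np.asarray(jax_predictions),
    model.predict(selected_examples),
    atol=1e-5,
)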
Using `jax2tf`, the JAX `predict` function can be wrapped into a TensorFlow module and exported as a SavedModel.
# Create a TF module with the model.
tf_model = tf.Module()
tf_model.predict = tf.function(
    jax2tf.convert(jax_model.predict, with_gradient=False),
    jit_compile=True,
    autograph=False,
)
# Check the predictions of the TF module.
tf_selected_examples = {
    k: tf.constant(v) for k, v in jax_selected_examples.items()
}
tf_predictions = tf_model.predict(tf_selected_examples)
tf_predictions
<tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.01860434, 0.36130956, 0.83858865, 0.04385566, 0.02917648], dtype=float32)>
# Save the TF module to file.
with tempfile.TemporaryDirectory() as tempdir:
    tf.saved_model.save(tf_model, tempdir)
INFO:tensorflow:Assets written to: /tmp/tmp90flesgr/assets
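As a sketch (our addition), the SavedModel can be reloaded and queried to verify that the export round-trips, as long as this happens before the temporary directory is deleted:

# Save, reload, and re-run the module within the same temporary directory.
with tempfile.TemporaryDirectory() as tempdir:
    tf.saved_model.save(tf_model, tempdir)
    reloaded_model = tf.saved_model.load(tempdir)
    print(reloaded_model.predict(tf_selected_examples))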
The `to_tensorflow_saved_model` function can also create a SavedModel directly. Models produced this way are faster, but this requires TensorFlow Decision Forests to be installed.

try:
    with tempfile.TemporaryDirectory() as tempdir:
        # Save the YDF model to a SavedModel directly.
        model.to_tensorflow_saved_model(tempdir, mode="tf")
except Exception as e:
    print("Could not save YDF model to SavedModel with to_tensorflow_saved_model:", e)
[INFO 24-06-14 14:31:56.6553 CEST kernel.cc:1233] Loading model from path /tmp/tmp71lnhoy9/tmp83xu8mjt/ with prefix e57777e0_
[INFO 24-06-14 14:31:56.6795 CEST quick_scorer_extended.cc:911] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference.
[INFO 24-06-14 14:31:56.6803 CEST abstract_model.cc:1362] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO 24-06-14 14:31:56.6803 CEST kernel.cc:1061] Use fast generic engine
INFO:tensorflow:Assets written to: /tmp/tmpi0fp69xz/assets
Finetuning a YDF model with JAX
Distribution drift is the problem that arises when the examples we care about (the serving examples) follow a different distribution than the training dataset. For example, in a hospital, distribution drift can occur when a model is trained on data acquired with different devices: although datasets from different devices should be compatible, small differences between them can cause a model trained on one dataset to underperform on another. A machine learning model trained to detect tumors in images captured by one brand of device might not work as well on images captured by another brand. Distribution drift is also common in dynamic systems that change over time, such as user behavior.
In this section, we use finetuning to address a distribution drift problem. For this, we use a modified version of the Adult dataset: we assume we only care about people with `relationship=Wife`. However, only 5% of people fall into this category, so we have few training examples.
We will first observe that neither training a model only on the `relationship=Wife` examples nor training on all available examples produces the best model. Instead, we will train a YDF model on all examples, finetune it on the `relationship=Wife` examples with JAX, and observe that this finetuned model performs better. Finally, we convert the finetuned JAX model back into a YDF model and analyze it with YDF tools.
Let's look at the distribution of `relationship`. Our goal is to optimize the quality of the model on the 483 `relationship == Wife` test examples.

test_ds["relationship"].value_counts()
relationship
Husband           4002
Not-in-family     2505
Own-child         1521
Unmarried          948
Wife               483
Other-relative     310
Name: count, dtype: int64
Group A contains the `relationship != Wife` examples, and group B contains the `relationship == Wife` examples. We split the train and test datasets accordingly.
def is_group_B(ds):
    return ds["relationship"] == "Wife"
train_ds_group_A = train_ds[~is_group_B(train_ds)]
test_ds_group_A = test_ds[~is_group_B(test_ds)]
train_ds_group_B = train_ds[is_group_B(train_ds)]
test_ds_group_B = test_ds[is_group_B(test_ds)]
print("Number of examples per group")
print("\tTrain Group A:", len(train_ds_group_A))
print("\tTest Group A:", len(test_ds_group_A))
print("\tTrain Group B:", len(train_ds_group_B))
print("\tTest Group B:", len(test_ds_group_B))
Number of examples per group
	Train Group A: 21707
	Test Group A: 9286
	Train Group B: 1085
	Test Group B: 483
We train a model on group A only, a model on group B only, and a model on both groups, then evaluate each of them on group B's test set.
# Train model on group A
model_group_A = ydf.GradientBoostedTreesLearner(label=label).train(
    train_ds_group_A, verbose=0
)
# Train model on group B
model_group_B = ydf.GradientBoostedTreesLearner(label=label).train(
    train_ds_group_B, verbose=0
)
# Train model on group A + B
model_group_AB = ydf.GradientBoostedTreesLearner(label=label).train(
    train_ds, verbose=0
)
# Evaluate the models on group B
accuracy_test_B_model_A = model_group_A.evaluate(test_ds_group_B).accuracy
accuracy_test_B_model_B = model_group_B.evaluate(test_ds_group_B).accuracy
accuracy_test_B_model_AB = model_group_AB.evaluate(test_ds_group_B).accuracy
print("Accuracy on B, model trained on A:", accuracy_test_B_model_A)
print("Accuracy on B, model trained on B:", accuracy_test_B_model_B)
print("Accuracy on B, model trained on A+B:", accuracy_test_B_model_AB)
Accuracy on B, model trained on A: 0.7204968944099379
Accuracy on B, model trained on B: 0.7329192546583851
Accuracy on B, model trained on A+B: 0.7556935817805382
We export the A+B model to a JAX function with the leaf values as trainable parameters.
jax_model_group_AB = model_group_AB.to_jax_function(
    apply_activation=False,
    leaves_as_params=True,
)
jax_model_group_AB.params
{'leaf_values': Array([-0.1233467 , -0.0927111 , 0.2927755 , ..., 0.05464426, 0.12556875, -0.11374608], dtype=float32), 'initial_predictions': Array([-1.1630996], dtype=float32)}
`leaves_as_params=True` exports the leaf values as model parameters in `params`. This is necessary to finetune the model.

Note that `apply_activation=False` removes the activation function from the model. This way, the model loss can be computed on logits instead of probabilities, which makes finetuning numerically more stable.
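Since the model was exported without its activation function, `predict` now returns logits and takes `params` as a second argument. A minimal sketch (our own check, not part of the original tutorial) of recovering probabilities by applying the sigmoid manually:

# Encode a few test examples and predict with explicit parameters.
encoded = jax_model_group_AB.encoder(test_ds[:5].drop(label, axis=1))
logits = jax_model_group_AB.predict(encoded, jax_model_group_AB.params)
# With apply_activation=False the outputs are logits; apply the sigmoid.
probabilities = jax.nn.sigmoid(logits)
print(probabilities)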
To finetune the model, we need to generate batches of examples. The following code produces such batches.

def get_num_examples(ds):
    return len(next(iter(ds.values())))


def prepare_dataset(ds, jax_model, batch=100):
    ds = ds.copy()
    # Make the label boolean
    ds[label] = ds[label] == ">50K"
    # Encode the input features
    encoded_ds = jax_model.encoder(ds)
    # Yield batches of examples
    n = get_num_examples(encoded_ds)
    i = 0
    while i < n:
        begin_idx = i
        end_idx = min(i + batch, n)
        yield {k: v[begin_idx:end_idx] for k, v in encoded_ds.items()}
        i += batch

# Example of utilisation of "prepare_dataset".
for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB, batch=4):
    print(examples)
    break  # We only print the first batch
{'age': Array([44, 67, 26, 30], dtype=int32), 'workclass': Array([1, 5, 0, 1], dtype=int32), 'fnlwgt': Array([228057, 171564, 167835, 118551], dtype=int32), 'education': Array([9, 1, 3, 3], dtype=int32), 'education_num': Array([ 4, 9, 13, 13], dtype=int32), 'marital_status': Array([1, 1, 1, 1], dtype=int32), 'occupation': Array([ 7, 1, 0, 11], dtype=int32), 'relationship': Array([5, 5, 5, 5], dtype=int32), 'race': Array([1, 1, 1, 1], dtype=int32), 'sex': Array([2, 2, 2, 2], dtype=int32), 'capital_gain': Array([ 0, 20051, 0, 0], dtype=int32), 'capital_loss': Array([0, 0, 0, 0], dtype=int32), 'hours_per_week': Array([40, 30, 20, 16], dtype=int32), 'native_country': Array([12, 10, 1, 1], dtype=int32), 'income': Array([False, True, False, True], dtype=bool)}
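For instance (a quick sketch of ours), with the default batch size of 100, group B's 1085 training examples produce ceil(1085 / 100) = 11 batches:

# Count the batches produced by "prepare_dataset" for group B.
num_batches = sum(
    1 for _ in prepare_dataset(train_ds_group_B, jax_model_group_AB)
)
print(num_batches)  # 11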
Let's define utility functions to compute and print the loss and accuracy of the model.
@jax.jit
def compute_accuracy(params, examples):
    examples = examples.copy()
    labels = examples.pop(model.label())
    predictions = jax_model_group_AB.predict(examples, params)
    return ((predictions >= 0.0) == labels).mean()


@jax.jit
def compute_loss(params, examples):
    examples = examples.copy()
    labels = examples.pop(model.label())
    logits = jax_model_group_AB.predict(examples, params)
    return optax.sigmoid_binary_cross_entropy(logits, labels).mean()


def compute_metric(metric_fn, ds):
    sum_metrics = 0
    num_examples = 0
    for examples in prepare_dataset(ds, jax_model_group_AB):
        n = get_num_examples(examples)
        sum_metrics += n * metric_fn(jax_model_group_AB.params, examples)
        num_examples += n
    return float(sum_metrics / num_examples)


def print_logs(stage):
    train_accuracy = compute_metric(compute_accuracy, train_ds_group_B)
    train_loss = compute_metric(compute_loss, train_ds_group_B)
    test_accuracy = compute_metric(compute_accuracy, test_ds_group_B)
    test_loss = compute_metric(compute_loss, test_ds_group_B)
    print(
        f"stage:{stage:10} "
        f"test-accuracy:{test_accuracy:.5f} test-loss:{test_loss:.5f} "
        f"train-accuracy:{train_accuracy:.5f} train-loss:{train_loss:.5f}"
    )


# Metrics of the model before training.
print_logs("initial")
stage:initial test-accuracy:0.75569 test-loss:0.47798 train-accuracy:0.83963 train-loss:0.37099
Below is the training loop. We use the Adam optimizer from Optax to update the model parameters.
optimizer = optax.adam(0.001)


@jax.jit
def train_step(opt_state, mdl_state, examples):
    loss, grads = jax.value_and_grad(compute_loss)(mdl_state, examples)
    updates, opt_state = optimizer.update(grads, opt_state)
    mdl_state = optax.apply_updates(mdl_state, updates)
    return opt_state, mdl_state, loss


opt_state = optimizer.init(jax_model_group_AB.params)

for epoch_idx in range(10):
    print_logs(f"epoch_{epoch_idx}")
    for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB):
        opt_state, jax_model_group_AB.params, _ = train_step(
            opt_state, jax_model_group_AB.params, examples
        )

print_logs("final")
stage:epoch_0   test-accuracy:0.75569 test-loss:0.47798 train-accuracy:0.83963 train-loss:0.37099
stage:epoch_1   test-accuracy:0.75155 test-loss:0.48035 train-accuracy:0.84424 train-loss:0.36520
stage:epoch_2   test-accuracy:0.75776 test-loss:0.47823 train-accuracy:0.84240 train-loss:0.35878
stage:epoch_3   test-accuracy:0.75983 test-loss:0.48016 train-accuracy:0.84608 train-loss:0.35352
stage:epoch_4   test-accuracy:0.75776 test-loss:0.48063 train-accuracy:0.84793 train-loss:0.34862
stage:epoch_5   test-accuracy:0.75569 test-loss:0.48173 train-accuracy:0.85069 train-loss:0.34419
stage:epoch_6   test-accuracy:0.75776 test-loss:0.48283 train-accuracy:0.85346 train-loss:0.34008
stage:epoch_7   test-accuracy:0.75776 test-loss:0.48381 train-accuracy:0.85806 train-loss:0.33622
stage:epoch_8   test-accuracy:0.75983 test-loss:0.48495 train-accuracy:0.86175 train-loss:0.33260
stage:epoch_9   test-accuracy:0.75983 test-loss:0.48595 train-accuracy:0.86267 train-loss:0.32917
stage:final     test-accuracy:0.75983 test-loss:0.48703 train-accuracy:0.86359 train-loss:0.32592
Note that the train loss decreases steadily while the test loss slowly increases: running the finetuning for many more epochs would overfit the small group B training set.
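A simple remedy is to keep the parameters from the epoch with the lowest test loss. A minimal early-stopping sketch (our own variation, assuming the finetuning is rerun from freshly exported parameters):

# Early-stopping sketch: track the parameters with the lowest test loss.
opt_state = optimizer.init(jax_model_group_AB.params)
best_loss = float("inf")
best_params = dict(jax_model_group_AB.params)
for epoch_idx in range(10):
    for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB):
        opt_state, jax_model_group_AB.params, _ = train_step(
            opt_state, jax_model_group_AB.params, examples
        )
    test_loss = compute_metric(compute_loss, test_ds_group_B)
    if test_loss < best_loss:
        best_loss = test_loss
        # JAX arrays are immutable, so a shallow copy of the dict suffices.
        best_params = dict(jax_model_group_AB.params)
jax_model_group_AB.params = best_params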
We can now update the YDF model with the finetuned parameters.
model_group_AB.update_with_jax_params(jax_model_group_AB.params)
`model_group_AB` is now the finetuned model. Let's evaluate it and compare it to the other models.
accuracy_test_B_model_AB_finetuned_B = model_group_AB.evaluate(
    test_ds_group_B
).accuracy
print("Accuracy on B, model trained on A:", accuracy_test_B_model_A)
print("Accuracy on B, model trained on B:", accuracy_test_B_model_B)
print("Accuracy on B, model trained on A+B:", accuracy_test_B_model_AB)
print("==================================")
print(
    "Accuracy on B, model trained on A+B, finetuned on B:",
    accuracy_test_B_model_AB_finetuned_B,
)
Accuracy on B, model trained on A: 0.7204968944099379
Accuracy on B, model trained on B: 0.7329192546583851
Accuracy on B, model trained on A+B: 0.7556935817805382
==================================
Accuracy on B, model trained on A+B, finetuned on B: 0.7598343685300207
`model_group_AB` is a YDF model like any other. For example, you can save it and analyze it.
# Save the model
with tempfile.TemporaryDirectory() as tempdir:
    model_group_AB.save(tempdir)
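As a quick sketch (our addition), the saved model can be loaded back with `ydf.load_model` and re-evaluated, as long as this happens before the temporary directory is deleted:

# Save, reload, and re-evaluate the finetuned model.
with tempfile.TemporaryDirectory() as tempdir:
    model_group_AB.save(tempdir)
    loaded_model = ydf.load_model(tempdir)
    print(loaded_model.evaluate(test_ds_group_B).accuracy)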
# Analyse the model
model_group_AB.analyze(test_ds_group_B)
1. "capital_gain" 0.049689 ################ 2. "occupation" 0.045549 ############## 3. "education" 0.026915 ######## 4. "education_num" 0.026915 ######## 5. "age" 0.018634 ###### 6. "capital_loss" 0.018634 ###### 7. "workclass" 0.014493 ##### 8. "fnlwgt" 0.002070 # 9. "native_country" 0.002070 # 10. "relationship" 0.000000 11. "race" 0.000000 12. "sex" 0.000000 13. "hours_per_week" 0.000000 14. "marital_status" -0.002070
1. "capital_gain" 0.164288 ################ 2. "capital_loss" 0.048263 ##### 3. "occupation" 0.033196 ### 4. "education" 0.023903 ## 5. "education_num" 0.015137 ## 6. "age" 0.013872 # 7. "workclass" 0.006274 # 8. "race" 0.002477 9. "sex" 0.001453 10. "fnlwgt" 0.000984 11. "marital_status" 0.000722 12. "relationship" 0.000000 13. "native_country" -0.000019 14. "hours_per_week" -0.007143
1. "capital_gain" 0.083385 ################ 2. "occupation" 0.040765 ######## 3. "capital_loss" 0.030647 ###### 4. "education" 0.026051 ##### 5. "age" 0.024419 ##### 6. "education_num" 0.016887 #### 7. "workclass" 0.010427 ## 8. "race" 0.003161 # 9. "marital_status" 0.000790 # 10. "sex" 0.000704 # 11. "relationship" 0.000000 # 12. "native_country" -0.000361 # 13. "fnlwgt" -0.001022 14. "hours_per_week" -0.006107
1. "capital_gain" 0.162868 ################ 2. "capital_loss" 0.048043 ##### 3. "occupation" 0.033135 ### 4. "education" 0.023881 ## 5. "education_num" 0.015116 ## 6. "age" 0.013875 # 7. "workclass" 0.006275 # 8. "race" 0.002472 9. "sex" 0.001448 10. "fnlwgt" 0.000990 11. "marital_status" 0.000721 12. "relationship" 0.000000 13. "native_country" -0.000014 14. "hours_per_week" -0.007106
1. "age" 0.226642 ################ 2. "occupation" 0.219727 ############# 3. "capital_gain" 0.214876 ############ 4. "education" 0.213746 ########### 5. "marital_status" 0.212739 ########### 6. "relationship" 0.206040 ######### 7. "fnlwgt" 0.203843 ######## 8. "hours_per_week" 0.203735 ######## 9. "capital_loss" 0.196549 ###### 10. "native_country" 0.190548 #### 11. "workclass" 0.187795 ### 12. "education_num" 0.184215 ## 13. "race" 0.180495 14. "sex" 0.177647
1. "age" 26.000000 ################ 2. "capital_gain" 26.000000 ################ 3. "marital_status" 20.000000 ############ 4. "relationship" 17.000000 ########## 5. "capital_loss" 14.000000 ######## 6. "hours_per_week" 14.000000 ######## 7. "education" 12.000000 ####### 8. "fnlwgt" 10.000000 ##### 9. "race" 9.000000 ##### 10. "education_num" 7.000000 ### 11. "sex" 4.000000 # 12. "occupation" 2.000000 13. "workclass" 1.000000 14. "native_country" 1.000000
1. "occupation" 724.000000 ################ 2. "fnlwgt" 513.000000 ########### 3. "age" 483.000000 ########## 4. "education" 464.000000 ########## 5. "hours_per_week" 339.000000 ####### 6. "capital_gain" 326.000000 ###### 7. "native_country" 306.000000 ###### 8. "capital_loss" 297.000000 ###### 9. "relationship" 262.000000 ##### 10. "workclass" 244.000000 ##### 11. "marital_status" 210.000000 #### 12. "education_num" 82.000000 # 13. "sex" 42.000000 14. "race" 21.000000
1. "relationship" 3014.690076 ################ 2. "capital_gain" 2065.521668 ########## 3. "education" 1144.490954 ###### 4. "marital_status" 1111.389695 ##### 5. "occupation" 1094.619502 ##### 6. "education_num" 796.666823 #### 7. "capital_loss" 584.055066 ### 8. "age" 582.288569 ### 9. "hours_per_week" 366.856509 # 10. "native_country" 263.872689 # 11. "fnlwgt" 216.537764 # 12. "workclass" 196.085503 # 13. "sex" 47.217730 14. "race" 5.428727