迁移到 YDF¶
YDF 是 Google 用于训练决策森林的新库,也是 TensorFlow Decision Forests 的继任者。
两个库都依赖于自 2018 年以来为生产环境开发的高性能 C++ 实现,称为 YDF C++。然而,YDF 比 TF-DF 功能更丰富、效率更高、更易于使用。迁移到 YDF 将降低您的开发和维护成本,同时可能提高模型的质量。
TF-DF 中的大多数函数在 YDF 中都有对应的功能。下表显示了映射关系。
注意: YDF 中的许多功能/函数在 TF-DF 中不存在。因此,阅读 YDF 入门指南很可能是一个值得的时间投入。
注意: 通过设置环境变量 TFDF_DISABLE_WELCOME_MESSAGE,可以移除 TF-DF 中的迁移消息。
操作 | TF-DF | YDF |
---|---|---|
训练模型 |
tf_ds = tfdf.keras.pd_dataframe_to_tf_dataset(ds, label="l")
model = tfdf.keras.RandomForestModel()
model.fit(tf_ds)
|
model = ydf.GradientBoostedTreesLearner(label="l").train(ds)
|
查看模型 |
model.summary()
|
model.describe()
|
评估模型 |
model.compile(["accuracy", tf.keras.metrics.AUC()])
model.evaluate(test_ds)
|
model.evaluate(ds)
|
保存模型 |
model.save("project/model")
|
model.save("project/model")
|
加载模型 |
model = tf_keras.models.load_model("project/model")
|
model = ydf.load_model("project/model")
|
将模型导出为 TF SavedModel |
model.save("project/model")
|
model.to_tensorflow_saved_model("project/model", mode="tf")
|
以下是 TF-DF 和 YDF 代码的 1:1 等效示例。
TF-DF | YDF |
---|---|
!pip install tensorflow tensorflow_decision_forests # Install TF-DF
import tensorflow_decision_forests as tfdf
import tensorflow as tf
import pandas as pd
# Load a dataset with Pandas
train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")
# Convert the dataset to a TensorFlow Dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label")
# Train a model
model = tfdf.keras.RandomForestModel(num_trees=500)
model.fit(train_ds)
# Evaluate model.
model.compile([tf.keras.metrics.SparseCategoricalAccuracy(),tf.keras.metrics.AUC()])
model.evaluate(test_ds)
# Saved model
model.save("/tmp/my_model")
|
pip install ydf # Install YDF
import ydf
import pandas as pd
# Load a dataset with Pandas
train_ds = pd.read_csv("train.csv")
test_ds = pd.read_csv("test.csv")
# Train a model
model = ydf.RandomForestLearner(label="my_label", num_trees=500).train(train_ds)
# Evaluate model.
model.evaluate(test_ds)
# Save the model
model.save("/tmp/my_model")
|