pip install ydf
什么是自定义损失函数?¶
在梯度提升树中,损失函数是一个接受标签值和预测值并返回该预测值“误差量”的函数。模型的训练目标是最小化所有训练示例的平均损失。YDF 实现了各种常见的损失函数。您可以通过“loss”参数配置它们。您可以在此处查看可用损失函数的列表。如果您未指定损失函数,它将根据模型任务自动选择。例如,如果任务是回归,损失函数将默认为均方误差。
如果 YDF 不支持您需要的损失函数,您可以手动定义它。这称为“自定义损失函数”。
在本入门教程中,我们将创建一个名为均方对数误差的自定义回归损失函数。
YDF 中的自定义损失函数¶
在 YDF 中,自定义损失函数由四个部分组成
- 初始预测:模型的初始预测,例如标签的平均值。
- 梯度和 Hessian:一个函数,用于计算给定标签和激活函数(又称链接函数)之前的模型预测值的损失函数的梯度和 Hessian 矩阵对角线。
- 损失函数:一个衡量当前解质量的函数。虽然理论上梯度和 Hessian 应该是损失函数的梯度和 Hessian,但在实践中近似表现得很好。
- 激活函数:应用于预测值以将其转换为正确空间(例如,分类问题的概率)的函数
使用自定义损失函数训练梯度提升树¶
我们首先配置一个回归数据集。
# Load libraries
import ydf # Yggdrasil Decision Forests
import pandas as pd # We use Pandas to load small datasets
import numpy as np # We use numpy for numerical operation
import numpy.typing as npty
from typing import Tuple
# Download a regression dataset and load it as a Pandas DataFrame.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
all_ds = pd.read_csv(f"{ds_path}/abalone.csv")
# Randomly split the dataset into a training (70%) and testing (30%) dataset
all_ds = all_ds.sample(frac=1)
split_idx = len(all_ds) * 7 // 10
train_ds = all_ds.iloc[:split_idx]
test_ds = all_ds.iloc[split_idx:]
# Print the first 5 training examples
train_ds.head(5)
Type | LongestShell | Diameter | Height | WholeWeight | ShuckedWeight | VisceraWeight | ShellWeight | Rings | |
---|---|---|---|---|---|---|---|---|---|
1681 | F | 0.620 | 0.540 | 0.165 | 1.1390 | 0.4995 | 0.2435 | 0.3570 | 11 |
1168 | M | 0.620 | 0.450 | 0.200 | 0.8580 | 0.4285 | 0.1525 | 0.2405 | 8 |
484 | M | 0.630 | 0.480 | 0.145 | 1.0115 | 0.4235 | 0.2370 | 0.3050 | 12 |
1594 | I | 0.525 | 0.400 | 0.140 | 0.6540 | 0.3050 | 0.1600 | 0.1690 | 7 |
1192 | M | 0.700 | 0.565 | 0.180 | 1.7510 | 0.8950 | 0.3355 | 0.4460 | 9 |
均方对数误差¶
本教程使用均方对数误差 (MSLE) 损失函数。MSLE 的计算公式为
MSLE = $\frac{1}{n} \sum_{i=1}^n (\log(p_i + 1) - \log(a_i+1))^2$,
其中 $n$ 是观测值的总数,$p_i$ 和 $a_i$ 分别是示例 $i$ 的预测值和标签值,$\log$ 表示自然对数。
MSLE 损失函数对预测值 $p_i$ 的梯度为
$\frac{1}{n} \cdot \frac{2(\log(p_i + 1) - \log(a_i+1))}{p_i + 1}$
MSLE 损失函数的 Hessian 矩阵是一个矩阵。为了简单起见和性能考虑,YDF 只使用 Hessian 矩阵的对角线。对角线的第 $i$ 个元素为
$\frac{1}{n} \cdot \frac{2(1 - \log(p_i + 1) + \log(a_i+1))}{(p_i + 1)^2}$
# If predictions are close to -1, numerical instabilities will distort the
# results. The predictions are therefore capped slightly above -1.
PREDICTION_MINIMUM = -1 + 1e-6
def loss_msle(
labels: npty.NDArray[np.float32],
predictions: npty.NDArray[np.float32],
weights: npty.NDArray[np.float32],
) -> np.float32:
clipped_pred = np.maximum(PREDICTION_MINIMUM, predictions)
return np.sum((np.log1p(clipped_pred) - np.log1p(labels))**2) / len(labels)
def initial_predictions_msle(
labels: npty.NDArray[np.float32], _: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
return np.exp(np.mean(np.log1p(labels))) - 1
def grad_msle(
labels: npty.NDArray[np.float32], predictions: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
gradient = (2/ len(labels))*(np.log1p(predictions) - np.log1p(labels)) / (predictions + 1)
return gradient
def hessian_msle(
labels: npty.NDArray[np.float32], predictions: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
hessian = (2/ len(labels))*(1 - np.log1p(predictions) + np.log1p(labels)) / (predictions + 1)**2
return hessian
def gradient_and_hessian_msle(
labels: npty.NDArray[np.float32], predictions: npty.NDArray[np.float32]
) -> Tuple[npty.NDArray[np.float32], npty.NDArray[np.float32]]:
clipped_pred = np.maximum(PREDICTION_MINIMUM, predictions)
return [grad_msle(labels, clipped_pred), hessian_msle(labels, clipped_pred)]
# Construct the loss object.
msle_custom_loss = ydf.RegressionLoss(
initial_predictions=initial_predictions_msle,
gradient_and_hessian=gradient_and_hessian_msle,
loss=loss_msle,
activation=ydf.Activation.IDENTITY,
)
模型照常训练,损失对象作为超参数。
model = ydf.GradientBoostedTreesLearner(label="Rings", task=ydf.Task.REGRESSION, loss=msle_custom_loss).train(train_ds)
Train model on 2923 examples Using a custom loss. Note when using custom losses, hyperparameter `apply_link_function` is ignored. Use the losses' activation function instead. Model trained in 0:00:01.596486
模型描述展示了训练损失和验证损失的演变。
model.describe()
任务 : REGRESSION
标签 : Rings
特征 (8) : Type LongestShell Diameter Height WholeWeight ShuckedWeight VisceraWeight ShellWeight
权重 : None
是否使用调优器训练 : No
模型大小 : 2263 kB
Number of records: 2923 Number of columns: 9 Number of columns by type: NUMERICAL: 8 (88.8889%) CATEGORICAL: 1 (11.1111%) Columns: NUMERICAL: 8 (88.8889%) 0: "Rings" NUMERICAL mean:9.97366 min:1 max:29 sd:3.26558 dtype:DTYPE_INT64 2: "LongestShell" NUMERICAL mean:0.524798 min:0.075 max:0.815 sd:0.119372 dtype:DTYPE_FLOAT64 3: "Diameter" NUMERICAL mean:0.408751 min:0.055 max:0.65 sd:0.0987606 dtype:DTYPE_FLOAT64 4: "Height" NUMERICAL mean:0.139512 min:0.01 max:0.515 sd:0.0386353 dtype:DTYPE_FLOAT64 5: "WholeWeight" NUMERICAL mean:0.830059 min:0.002 max:2.657 sd:0.488709 dtype:DTYPE_FLOAT64 6: "ShuckedWeight" NUMERICAL mean:0.360019 min:0.001 max:1.488 sd:0.221456 dtype:DTYPE_FLOAT64 7: "VisceraWeight" NUMERICAL mean:0.180917 min:0.0005 max:0.6415 sd:0.108618 dtype:DTYPE_FLOAT64 8: "ShellWeight" NUMERICAL mean:0.238848 min:0.0015 max:1.005 sd:0.138498 dtype:DTYPE_FLOAT64 CATEGORICAL: 1 (11.1111%) 1: "Type" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"M" 1087 (37.1878%) dtype:DTYPE_BYTES Terminology: nas: Number of non-available (i.e. missing) values. ood: Out of dictionary. manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred. tokenized: The attribute value is obtained through tokenization. has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string. vocab-size: Number of unique values.
以下评估是在验证数据集或袋外数据集上计算的。
变量重要性衡量输入特征对模型的重要性。
1. "ShellWeight" 0.532153 ################ 2. "WholeWeight" 0.351222 ####### 3. "ShuckedWeight" 0.244100 ## 4. "LongestShell" 0.235793 ## 5. "Height" 0.223162 # 6. "VisceraWeight" 0.210488 # 7. "Diameter" 0.206625 8. "Type" 0.185702
1. "WholeWeight" 145.000000 ################ 2. "ShellWeight" 87.000000 ######### 3. "Height" 20.000000 # 4. "VisceraWeight" 16.000000 # 5. "LongestShell" 12.000000 6. "Diameter" 11.000000 7. "ShuckedWeight" 5.000000
1. "ShuckedWeight" 1160.000000 ################ 2. "ShellWeight" 919.000000 ############ 3. "WholeWeight" 660.000000 ######## 4. "Height" 437.000000 ##### 5. "LongestShell" 424.000000 ##### 6. "Diameter" 406.000000 ##### 7. "VisceraWeight" 322.000000 #### 8. "Type" 31.000000
1. "ShellWeight" 0.000008 ################ 2. "WholeWeight" 0.000005 ######### 3. "VisceraWeight" 0.000002 ### 4. "LongestShell" 0.000002 ### 5. "ShuckedWeight" 0.000002 ### 6. "Diameter" 0.000001 ## 7. "Height" 0.000000 8. "Type" 0.000000
这些变量重要性是在训练期间计算的。在测试数据集上分析模型时,可以获得更多且可能更具参考价值的变量重要性。
仅打印第一棵树。
Tree #0: "ShellWeight">=0.15375 [s:1.50679e-10 n:2648 np:1817 miss:1] ; pred:-6.08611e-07 ├─(pos)─ "ShellWeight">=0.28975 [s:3.06419e-11 n:1817 np:904 miss:0] ; pred:1.50836 | ├─(pos)─ "ShellWeight">=0.40975 [s:2.08793e-11 n:904 np:293 miss:0] ; pred:1.25334 | | ├─(pos)─ "ShuckedWeight">=0.63925 [s:3.76591e-11 n:293 np:173 miss:0] ; pred:0.599562 | | | ├─(pos)─ "ShellWeight">=0.57775 [s:3.09601e-11 n:173 np:33 miss:0] ; pred:0.265588 | | | | ├─(pos)─ pred:0.0884813 | | | | └─(neg)─ pred:0.177107 | | | └─(neg)─ "ShellWeight">=0.5075 [s:3.65286e-11 n:120 np:29 miss:0] ; pred:0.333974 | | | ├─(pos)─ pred:0.111759 | | | └─(neg)─ pred:0.222215 | | └─(neg)─ "ShuckedWeight">=0.43875 [s:4.23098e-11 n:611 np:455 miss:0] ; pred:0.653777 | | ├─(pos)─ "ShellWeight">=0.38925 [s:4.69731e-12 n:455 np:69 miss:0] ; pred:0.313559 | | | ├─(pos)─ pred:0.0829214 | | | └─(neg)─ pred:0.230638 | | └─(neg)─ "WholeWeight">=1.09075 [s:1.84243e-11 n:156 np:31 miss:0] ; pred:0.340218 | | ├─(pos)─ pred:0.0943271 | | └─(neg)─ pred:0.245891 | └─(neg)─ "ShuckedWeight">=0.27975 [s:9.24118e-12 n:913 np:621 miss:1] ; pred:0.255019 | ├─(pos)─ "ShellWeight">=0.23475 [s:1.3242e-11 n:621 np:353 miss:1] ; pred:0.0440083 | | ├─(pos)─ "ShuckedWeight">=0.39825 [s:1.79254e-11 n:353 np:190 miss:0] ; pred:0.136942 | | | ├─(pos)─ pred:-0.000799999 | | | └─(neg)─ pred:0.137742 | | └─(neg)─ "ShellWeight">=0.18475 [s:9.53197e-12 n:268 np:192 miss:1] ; pred:-0.0929341 | | ├─(pos)─ pred:-0.0292848 | | └─(neg)─ pred:-0.0636493 | └─(neg)─ "ShellWeight">=0.18975 [s:4.18104e-11 n:292 np:132 miss:1] ; pred:0.211011 | ├─(pos)─ "Type" is in [BITMAP] {M, F} [s:4.07772e-11 n:132 np:102 miss:0] ; pred:0.189359 | | ├─(pos)─ pred:0.181646 | | └─(neg)─ pred:0.00771209 | └─(neg)─ "ShuckedWeight">=0.21575 [s:2.21914e-11 n:160 np:97 miss:1] ; pred:0.0216526 | ├─(pos)─ pred:-0.0236986 | └─(neg)─ pred:0.0453512 └─(neg)─ "Diameter">=0.2225 [s:1.05556e-10 n:831 np:697 miss:1] ; pred:-1.50836 ├─(pos)─ "Type" is in [BITMAP] {<OOD>, M, F} [s:2.913e-11 n:697 np:242 miss:1] ; pred:-0.951147 | ├─(pos)─ "ShuckedWeight">=0.233 [s:1.74077e-11 n:242 np:47 miss:1] ; pred:-0.151146 | | ├─(pos)─ "VisceraWeight">=0.1525 [s:9.15881e-12 n:47 np:5 miss:1] ; pred:-0.0692974 | | | ├─(pos)─ pred:-0.00298646 | | | └─(neg)─ pred:-0.0663109 | | └─(neg)─ "Height">=0.1025 [s:1.99501e-11 n:195 np:117 miss:1] ; pred:-0.0818482 | | ├─(pos)─ pred:-0.00643989 | | └─(neg)─ pred:-0.0754083 | └─(neg)─ "ShellWeight">=0.112 [s:2.31179e-11 n:455 np:158 miss:1] ; pred:-0.800002 | ├─(pos)─ "Height">=0.1325 [s:1.79824e-11 n:158 np:15 miss:1] ; pred:-0.173648 | | ├─(pos)─ pred:0.00315428 | | └─(neg)─ pred:-0.176802 | └─(neg)─ "ShellWeight">=0.06875 [s:9.09147e-12 n:297 np:177 miss:1] ; pred:-0.626354 | ├─(pos)─ pred:-0.329338 | └─(neg)─ pred:-0.297016 └─(neg)─ "ShellWeight">=0.02175 [s:7.89323e-11 n:134 np:78 miss:1] ; pred:-0.557212 ├─(pos)─ "LongestShell">=0.2525 [s:8.02273e-12 n:78 np:70 miss:1] ; pred:-0.265629 | ├─(pos)─ "VisceraWeight">=0.01875 [s:6.70755e-12 n:70 np:58 miss:1] ; pred:-0.231683 | | ├─(pos)─ pred:-0.198798 | | └─(neg)─ pred:-0.0328844 | └─(neg)─ pred:-0.0339468 └─(neg)─ "WholeWeight">=0.0165 [s:7.82022e-11 n:56 np:51 miss:1] ; pred:-0.291582 ├─(pos)─ "VisceraWeight">=0.01025 [s:7.9962e-12 n:51 np:21 miss:1] ; pred:-0.251427 | ├─(pos)─ pred:-0.096431 | └─(neg)─ pred:-0.154996 └─(neg)─ pred:-0.0401556
我们可以将此模型与使用 RMSE 损失函数训练的模型进行比较。
model.evaluate(test_ds)
回归模型的评估
- RMSE (均方根误差)
- 预测值与真实值之间平方差的平均值的平方根。
解释:RMSE 越低越好。它与目标变量单位相同,使其具有一定的可解释性。 - 残差
- 每个示例的预测值与真实值之间的差值 (预测值 - 真实值)。
- 残差直方图
- 显示残差分布的直方图。
解释:理想情况下,您希望看到一个大致对称、钟形分布且中心在零附近的直方图,这表明误差是随机的且没有偏差。 - 真实值直方图
- 显示数据集中实际目标值分布的直方图。
- 预测值直方图
- 显示模型预测值分布的直方图。
- 真实值 vs 预测值曲线
- 一个散点图,其中每个点代表一个数据点。x 轴是真实值,y 轴是模型的预测值。
解释:完美的模型会使所有点落在对角线上(即预测值 = 真实值)。偏离这条线表示存在误差。 - 预测值 vs 残差曲线
- 一个散点图,其中 x 轴是模型的预测值,y 轴是残差。
解释:理想情况下,您希望看到点围绕零水平线随机散布。模式(例如,漏斗状)可能表明模型存在问题。 - 预测值 vs 真实值曲线
- 有时这会绘制一条穿过真实值 vs 预测值散点图上的点的拟合曲线来可视化趋势。它可以帮助查看模型是否在特定范围内系统性地高估或低估。
# A model trained with default regression loss (i.e. RMSE loss)
model_rmse_loss = ydf.GradientBoostedTreesLearner(label="Rings", task=ydf.Task.REGRESSION).train(train_ds)
model_rmse_loss.evaluate(test_ds)
Train model on 2923 examples Model trained in 0:00:01.017847
回归模型的评估
- RMSE (均方根误差)
- 预测值与真实值之间平方差的平均值的平方根。
解释:RMSE 越低越好。它与目标变量单位相同,使其具有一定的可解释性。 - 残差
- 每个示例的预测值与真实值之间的差值 (预测值 - 真实值)。
- 残差直方图
- 显示残差分布的直方图。
解释:理想情况下,您希望看到一个大致对称、钟形分布且中心在零附近的直方图,这表明误差是随机的且没有偏差。 - 真实值直方图
- 显示数据集中实际目标值分布的直方图。
- 预测值直方图
- 显示模型预测值分布的直方图。
- 真实值 vs 预测值曲线
- 一个散点图,其中每个点代表一个数据点。x 轴是真实值,y 轴是模型的预测值。
解释:完美的模型会使所有点落在对角线上(即预测值 = 真实值)。偏离这条线表示存在误差。 - 预测值 vs 残差曲线
- 一个散点图,其中 x 轴是模型的预测值,y 轴是残差。
解释:理想情况下,您希望看到点围绕零水平线随机散布。模式(例如,漏斗状)可能表明模型存在问题。 - 预测值 vs 真实值曲线
- 有时这会绘制一条穿过真实值 vs 预测值散点图上的点的拟合曲线来可视化趋势。它可以帮助查看模型是否在特定范围内系统性地高估或低估。
def binomial_initial_predictions(
labels: npty.NDArray[np.int32], weights: npty.NDArray[np.float32]
) -> np.float32:
sum_weights = np.sum(weights)
sum_weights_positive = np.sum((labels == 2) * weights)
ratio_positive = sum_weights_positive / sum_weights
if ratio_positive == 0.0:
return -np.iinfo(np.float32).max
elif ratio_positive == 1.0:
return np.iinfo(np.float32).max
return np.log(ratio_positive / (1 - ratio_positive))
def binomial_gradient_and_hessian(
labels: npty.NDArray[np.int32], predictions: npty.NDArray[np.float32]
) -> Tuple[npty.NDArray[np.float32], npty.NDArray[np.float32]]:
pred_probability = 1.0 / (1.0 + np.exp(-predictions))
binary_labels = labels == 2
return (
pred_probability - binary_labels,
pred_probability * (pred_probability - 1),
)
def binomial_loss(
labels: npty.NDArray[np.int32],
predictions: npty.NDArray[np.float32],
weights: npty.NDArray[np.float32],
) -> np.float32:
binary_labels = labels == 2
return (-2.0 * np.sum(
binary_labels * predictions- np.log(1.0 + np.exp(predictions))
) / len(labels)
)
binomial_custom_loss = ydf.BinaryClassificationLoss(
initial_predictions=binomial_initial_predictions,
gradient_and_hessian=binomial_gradient_and_hessian,
loss=binomial_loss,
activation=ydf.Activation.SIGMOID,
)
多元分类¶
对于多元分类问题,标签是从 1 开始的整数。损失函数必须为每个类别标签提供梯度和 Hessian。梯度和 Hessian 必须返回 d x n 矩阵,其中 n 是示例数量,d 是类别标签数量。类似地,模型必须为每个类别标签提供初始预测值,形式为一个包含 d 个元素的向量。
YDF 支持 Softmax 激活函数用于不操作在概率空间中的损失函数。
为了演示目的,下面的代码重新实现了多项对数似然损失函数作为一个自定义损失函数。请注意,此损失函数也可以直接通过 loss=MULTINOMIAL_LOG_LIKELIHOOD
超参数获得。
def multinomial_initial_predictions(
labels: npty.NDArray[np.int32], _: npty.NDArray[np.float32]
) -> npty.NDArray[np.float32]:
dimension = np.max(labels)
return np.zeros(dimension, dtype=np.float32)
def multinomial_gradient(
labels: npty.NDArray[np.int32], predictions: npty.NDArray[np.float32]
) -> Tuple[npty.NDArray[np.float32], npty.NDArray[np.float32]]:
dimension = np.max(labels)
normalization = 1.0 / np.sum(np.exp(predictions), axis=1)
normalized_predictions = np.exp(predictions) * normalization[:, None]
label_indicator = (
(labels - 1)[:, np.newaxis] == np.arange(dimension)
).astype(int)
gradient = normalized_predictions - label_indicator
hessian = np.abs(gradient) * (np.abs(gradient) - 1)
return (np.transpose(gradient), np.transpose(hessian))
def multinomial_loss(
labels: npty.NDArray[np.int32],
predictions: npty.NDArray[np.float32],
weights: npty.NDArray[np.float32],
) -> np.float32:
dimension = np.max(labels)
sum_exp_pred = np.sum(np.exp(predictions), axis=1)
indicator_matrix = (
(labels - 1)[:, np.newaxis] == np.arange(dimension)
).astype(int)
label_exp_pred = np.exp(np.sum(predictions * indicator_matrix, axis=1))
return (
-np.sum(np.log(label_exp_pred / sum_exp_pred)) / len(labels)
)
multinomial_custom_loss = ydf.MultiClassificationLoss(
initial_predictions=multinomial_initial_predictions,
gradient_and_hessian=multinomial_gradient,
loss=multinomial_loss,
activation=ydf.Activation.SOFTMAX,
)
使用 JAX 的自定义损失函数¶
JAX 允许使用自动微分定义损失函数。在此示例中,我们定义了回归的 Huber 损失函数。
import jax
import jax.numpy as jnp
@jax.jit
def huber_loss(labels, pred, delta=1.0):
abs_diff = jnp.abs(labels - pred)
return jnp.average(jnp.where(abs_diff > delta,delta * (abs_diff - .5 * delta), 0.5 * abs_diff ** 2))
huber_grad = jax.jit(jax.grad(huber_loss, argnums=1))
huber_hessian = jax.jit(jax.jacfwd(jax.jacrev(huber_loss, argnums=1)))
huber_init = jax.jit(lambda labels, weights: jnp.average(labels))
huber = ydf.RegressionLoss(
initial_predictions=jax.block_until_ready(huber_init),
gradient_and_hessian=lambda label, pred: (
huber_grad(label, pred).block_until_ready(),
jnp.diagonal(huber_hessian(label, pred)).block_until_ready()
),
loss=lambda label, pred, weight: huber_loss(label, pred).block_until_ready(),
activation=ydf.Activation.IDENTITY,
)
model = ydf.GradientBoostedTreesLearner(label="Rings", task=ydf.Task.REGRESSION, loss=huber).train(train_ds)
Train model on 2923 examples Using a custom loss. Note when using custom losses, hyperparameter `apply_link_function` is ignored. Use the losses' activation function instead.
INFO:2025-02-11 10:26:09,072:jax._src.xla_bridge:924: Unable to initialize backend 'cuda': Your process properly initialized the GPU backend, but //learning/brain/research/jax:gpu_support is not linked in. You most likely should add that build dependency to your program. INFO:2025-02-11 10:26:09,073:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': Your process properly initialized the GPU backend, but //learning/brain/research/jax:gpu_support is not linked in. You most likely should add that build dependency to your program. INFO:2025-02-11 10:26:09,073:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': No TPU backend found. Make sure //learning/brain/research/jax:tpu_support is included in your deps. WARNING:2025-02-11 10:26:09,074:jax._src.xla_bridge:966: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Model trained in 0:00:02.142357
其他详细信息和技巧¶
- 为简化阐述,以上示例假设单位权重。
- 损失函数不应创建对标签、预测值和权重数组的引用。这些数组由 C++ 内存支持,可能随时在 C++ 侧被删除。
- 使用自定义损失函数时,YDF 可能会触发 GC 以捕获非法内存访问。在损失对象上设置
may_trigger_gc=False
可避免此问题,但请注意,在这种情况下 YDF 可能不会警告非法内存访问。 - 自定义损失函数返回的数组可能会被 YDF 修改。
- 使用自定义损失函数进行训练通常比训练内置损失函数慢约 10%。
- 自定义损失函数尚未完全支持模型检查和分析——尚无法在 YDF 中计算模型在测试集上的自定义损失。