pip install ydf transformers torch jax tqdm datasets scikit-learn matplotlib -U
from itertools import islice
from datasets import load_dataset # The text dataset
import matplotlib.pyplot as plt # For plotting the toy dataset
import numpy as np
from sklearn.decomposition import PCA # PCA is used to reduce the embedding dimension
from sklearn.preprocessing import StandardScaler
import torch
from tqdm import tqdm # For the progress-bar
from transformers import GPT2Model, GPT2Tokenizer # To compute some embeddings
import ydf
What are vector sequence features?

Vector sequence features are a type of input feature where each value is a sequence (or list) of multi-dimensional, fixed-size numerical vectors. They are well suited for encoding collections of embeddings or time series, such as the embeddings of a collection of images or the embeddings of an intermediate layer of a large language model (LLM).

They can be seen as an extension of multi-dimensional numerical features, as follows:
Type | Example of value
---|---
Numerical (single-dimensional) | 4.3
Multi-dimensional numerical | [1,5,2]
Numerical vector sequence | [[1,2,3], [4,5,6], [7,8,9]]
While the number of vectors may differ across vector sequence values, all vectors within a given sequence must have the same dimension (shape).
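For instance, the three kinds of values can sit side by side in an in-memory dataset dictionary; a minimal sketch with hypothetical column names (vector sequences are passed as a list of NumPy arrays):

two_examples = {
    "scalar": np.array([4.3, 1.2]),  # (single-dimensional) numerical
    "multidim": np.array([[1, 5, 2], [3, 0, 7]]),  # multi-dimensional numerical
    "vector_sequence": [  # numerical vector sequence
        np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),  # 3 vectors
        np.array([[1.0, 1.0, 1.0]]),  # 1 vector, same dimension (3)
    ],
}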
About this tutorial

This tutorial has two parts. The first part shows how to create sequence features on simple toy data.

The second part shows a more complex example combining LLM embeddings, PCA, and vector sequences: we take the first hidden layer of the GPT2 model, apply PCA to reduce its dimension (an optional step that speeds up training), and use the result as a vector sequence feature for text classification with a decision forest.
Part 1: Vector sequences on a simple toy example

For in-memory datasets, vector sequences are represented as Python lists of NumPy arrays, where each array has shape <number of vectors, vector dimension>.
Note

- Pandas DataFrames do not handle multi-dimensional values well. A plain Python dictionary of values is simpler and more efficient.
- For file-based datasets, a format able to represent two-dimensional values is required. Currently, Avro is the only supported format that can do so natively.
Our toy dataset is simple: each feature value is a list of 0 to 4 two-dimensional points sampled uniformly in [-1.5, 1.5]². An example is labeled "true" if at least one of its points falls inside the unit circle (centered at (0, 0) with radius 1), and "false" otherwise. About 50% of the examples are positive. Let's build this dataset.
def make_toy_ds(num_examples=1_000):
    features = []
    labels = []
    for _ in range(num_examples):
        # Each example contains between 0 and 4 two-dimensional vectors.
        num_vectors = np.random.randint(0, 5)
        vectors = np.random.uniform(-1.5, 1.5, [num_vectors, 2])
        # Positive iff at least one vector falls inside the unit circle.
        label = np.any(np.sum(vectors**2, axis=1) < 1)
        features.append(vectors)
        labels.append(label)
    return {"label": np.array(labels), "feature": features}
# Generate 3 examples
make_toy_ds(num_examples=3)
{'label': array([False, False,  True]),
 'feature': [array([[ 0.47221051, -1.04686068],
                    [ 1.40348894, -1.00676166],
                    [ 0.63440287,  1.29930153],
                    [-1.28285904, -1.44044944]]),
             array([[-1.44968141,  1.18143043],
                    [-1.35090648,  1.05524487],
                    [-0.47623758,  1.02229518]]),
             array([[-0.73034892, -1.03253695],
                    [ 0.47363613, -0.61725529],
                    [-0.02912192, -1.47682402]])]}
Let's plot a few examples to check that the pattern is present.
num_examples = 3
dataset = make_toy_ds(num_examples)
fig, axs = plt.subplots(1, 3, figsize=(10, 3))
for example_idx, ax in enumerate(axs):
    feature = dataset["feature"][example_idx]
    ax.scatter([v[0] for v in feature], [v[1] for v in feature])
    ax.set_title(f"label={dataset['label'][example_idx]}")
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    # Show the circle
    ax.add_artist(
        plt.Circle((0, 0), 1, edgecolor="blue", facecolor="none", linewidth=1)
    )
We can then train our model.
train_ds = make_toy_ds(num_examples=10_000)
model = ydf.RandomForestLearner(label="label").train(train_ds)
Train model on 10000 examples
Model trained in 0:00:05.606750
The following cell shows the model description. In the dataspec tab, you can see statistics about the features (e.g., the distribution of the number of vectors, the vector dimension). In the structure tab, you can see the learned tree conditions.

For example, the condition "feature" contains X with | X - [0.054303, -0.062462] |² <= 0.996597

evaluates to true if and only if at least one of the vectors is within a squared distance of 0.996597 (i.e., a distance of about 0.998) from (0.054303, -0.062462). This is very close to the rule we used to generate the dataset: a distance of less than 1 from (0, 0).
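To make the semantics concrete, here is a small NumPy sketch (with made-up vectors) of how such a "closeness" condition is evaluated on one example; the projection conditions of the form X @ w >= t that also appear in the trees work analogously with a dot product:

anchor = np.array([0.054303, -0.062462])
threshold = 0.996597  # threshold on the *squared* distance
vectors = np.array([[0.3, -0.2], [1.4, 1.2]])  # one example's vector sequence
# The condition fires if any vector falls inside the ball around the anchor.
fires = np.any(np.sum((vectors - anchor) ** 2, axis=1) <= threshold)
print(fires)  # True: the first vector is inside the ball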
model.describe()
Task : Classification
Label : label
Features (1) : feature
Weights : None
Trained with tuner : No
Model size : 13257 kB
Number of records: 10000
Number of columns: 2

Number of columns by type:
    NUMERICAL_VECTOR_SEQUENCE: 1 (50%)
    CATEGORICAL: 1 (50%)

Columns:

NUMERICAL_VECTOR_SEQUENCE: 1 (50%)
    1: "feature" NUMERICAL_VECTOR_SEQUENCE mean:0.00616592 min:-1.49985 max:1.49994 sd:0.865995 dims:2 min-vecs:0 max-vecs:4 dtype:DTYPE_FLOAT64

CATEGORICAL: 1 (50%)
    0: "label" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"false" 5080 (50.8%) dtype:DTYPE_BOOL

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Number of predictions (without weights): 10000
Number of predictions (with weights): 10000
Task: CLASSIFICATION
Label: label

Accuracy: 0.9936  CI95[W][0.992125 0.994854]
LogLoss: : 0.0304112
ErrorRate: : 0.00639999

Default Accuracy: : 0.508
Default LogLoss: : 0.693019
Default ErrorRate: : 0.492

Confusion Table:
truth\prediction
       false  true
false   5063    17
 true     47  4873
Total: 10000
Variable importances measure how important an input feature is for the model. Since "feature" is our only input feature, it ranks first under each of the model's importance measures (e.g., how often it is used as a root, the number of nodes that use it, and the sum of its split scores):
1. "feature" 1.000000
1. "feature" 300.000000
1. "feature" 19649.000000
1. "feature" 2054347.537343
These variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing the model on a test dataset.
Let's print the first tree.
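A minimal call, using YDF's print_tree method (which prints tree #0 by default):

model.print_tree()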
Tree #0:
    "feature" contains X with | X - [-0.073972, -0.13002] |² <= 1.13198 [s:0.474337 n:10000 np:5257 miss:1] ; val:"false" prob:[0.5169, 0.4831]
    ├─(pos)─ "feature" contains X with | X - [0.23399, -0.012537] |² <= 1.14279 [s:0.116865 n:5257 np:4686 miss:1] ; val:"true" prob:[0.0996766, 0.900323]
    |   ├─(pos)─ "feature" contains X with | X - [0.046393, 0.021299] |² <= 0.925552 [s:0.11139 n:4686 np:4403 miss:1] ; val:"true" prob:[0.0352113, 0.964789]
    |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   └─(neg)─ "feature" contains X with X @ [2.6077, -1.6338] >= 1.79877 [s:0.0892376 n:283 np:236 miss:0] ; val:"false" prob:[0.583039, 0.416961]
    |   |   ├─(pos)─ "feature" contains X with | X - [-0.89351, 0.41496] |² <= 3.78143 [s:0.0793153 n:236 np:194 miss:1] ; val:"false" prob:[0.673729, 0.326271]
    |   |   |   ├─(pos)─ "feature" contains X with | X - [-1.0493, -0.92765] |² <= 0.194547 [s:0.0745684 n:194 np:38 miss:1] ; val:"false" prob:[0.603093, 0.396907]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [0.15594, -1.0044] |² <= 0.0241888 [s:0.117638 n:38 np:5 miss:1] ; val:"false" prob:[0.947368, 0.0526316]
    |   |   |   |   |   ├─(pos)─ val:"false" prob:[0.6, 0.4]
    |   |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   |   └─(neg)─ "feature" contains X with X @ [0.63711, -1.3639] >= 1.4673 [s:0.0778821 n:156 np:99 miss:0] ; val:"false" prob:[0.519231, 0.480769]
    |   |   |   |   ├─(pos)─ "feature" contains X with X @ [-2.2051, 0.86293] >= 2.3505 [s:0.116979 n:99 np:31 miss:0] ; val:"false" prob:[0.666667, 0.333333]
    |   |   |   |   |   ├─(pos)─ "feature" contains X with | X - [0.93742, -0.45437] |² <= 0.79547 [s:0.0617958 n:31 np:26 miss:1] ; val:"false" prob:[0.967742, 0.0322581]
    |   |   |   |   |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   |   |   |   |   └─(neg)─ val:"false" prob:[0.8, 0.2]
    |   |   |   |   |   └─(neg)─ "feature" contains X with X @ [-1.1863, -0.37657] >= 0.371011 [s:0.254621 n:68 np:35 miss:0] ; val:"false" prob:[0.529412, 0.470588]
    |   |   |   |   |   ├─(pos)─ "feature" contains X with X @ [2.0587, -1.4712] >= 3.88402 [s:0.290462 n:35 np:5 miss:0] ; val:"true" prob:[0.2, 0.8]
    |   |   |   |   |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   |   |   |   |   └─(neg)─ "feature" contains X with X @ [-0.97637, -1.6315] >= 2.05231 [s:0.132761 n:30 np:5 miss:0] ; val:"true" prob:[0.0666667, 0.933333]
    |   |   |   |   |   |   ├─(pos)─ val:"true" prob:[0.4, 0.6]
    |   |   |   |   |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   |   |   |   |   └─(neg)─ "feature" contains X with | X - [0.46516, -0.84671] |² <= 0.0262009 [s:0.253603 n:33 np:6 miss:1] ; val:"false" prob:[0.878788, 0.121212]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0.333333, 0.666667]
    |   |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   |   └─(neg)─ "feature" contains X with | X - [1.0294, -0.05871] |² <= 0.135678 [s:0.139053 n:57 np:18 miss:1] ; val:"true" prob:[0.263158, 0.736842]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [-0.92035, -0.2734] |² <= 0.0374038 [s:0.358182 n:18 np:5 miss:1] ; val:"false" prob:[0.611111, 0.388889]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   |   └─(neg)─ "feature" contains X with | X - [1.2494, 0.90085] |² <= 0.453873 [s:0.170472 n:13 np:5 miss:1] ; val:"false" prob:[0.846154, 0.153846]
    |   |   |   |   |   ├─(pos)─ val:"false" prob:[0.6, 0.4]
    |   |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   |   └─(neg)─ "feature" contains X with | X - [0.87904, -0.54081] |² <= 0.0324876 [s:0.145801 n:39 np:11 miss:1] ; val:"true" prob:[0.102564, 0.897436]
    |   |   |   |   ├─(pos)─ "feature" contains X with X @ [-0.19497, 1.5814] >= 1.58562 [s:0.428026 n:11 np:6 miss:0] ; val:"true" prob:[0.363636, 0.636364]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   |   └─(neg)─ val:"false" prob:[0.8, 0.2]
    |   |   |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ "feature" contains X with X @ [-0.10046, 2.3265] >= 2.95329 [s:0.220656 n:47 np:11 miss:0] ; val:"true" prob:[0.12766, 0.87234]
    |   |   ├─(pos)─ "feature" contains X with | X - [0.48423, 0.96221] |² <= 0.128651 [s:0.689009 n:11 np:6 miss:1] ; val:"false" prob:[0.545455, 0.454545]
    |   |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   └─(neg)─ "feature" contains X with | X - [-0.83747, 0.27321] |² <= 0.959256 [s:0.113909 n:571 np:390 miss:1] ; val:"false" prob:[0.628722, 0.371278]
    |   ├─(pos)─ "feature" contains X with X @ [-2.0488, -0.56637] >= 2.05562 [s:0.0916326 n:390 np:233 miss:0] ; val:"true" prob:[0.484615, 0.515385]
    |   |   ├─(pos)─ "feature" contains X with X @ [-0.18616, -0.088117] >= 0.252543 [s:0.130844 n:233 np:92 miss:0] ; val:"false" prob:[0.656652, 0.343348]
    |   |   |   ├─(pos)─ "feature" contains X with | X - [1.4429, -0.35152] |² <= 3.80915 [s:0.0996367 n:92 np:40 miss:1] ; val:"true" prob:[0.358696, 0.641304]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [-1.3635, 0.20028] |² <= 0.19175 [s:0.292586 n:40 np:20 miss:1] ; val:"false" prob:[0.6, 0.4]
    |   |   |   |   |   ├─(pos)─ "feature" contains X with X @ [-1.6536, -1.8412] >= 3.44988 [s:0.0734147 n:20 np:15 miss:0] ; val:"false" prob:[0.95, 0.05]
    |   |   |   |   |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   |   |   |   |   └─(neg)─ val:"false" prob:[0.8, 0.2]
    |   |   |   |   |   └─(neg)─ "feature" contains X with X @ [0.21358, -0.18956] >= 0.187504 [s:0.191258 n:20 np:14 miss:0] ; val:"true" prob:[0.25, 0.75]
    |   |   |   |   |   ├─(pos)─ "feature" contains X with X @ [1.2553, 2.2469] >= 1.13286 [s:0.0786036 n:14 np:9 miss:0] ; val:"true" prob:[0.0714286, 0.928571]
    |   |   |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   |   |   └─(neg)─ val:"true" prob:[0.2, 0.8]
    |   |   |   |   |   └─(neg)─ val:"false" prob:[0.666667, 0.333333]
    |   |   |   |   └─(neg)─ "feature" contains X with | X - [-1.1892, -0.49788] |² <= 0.088806 [s:0.234692 n:52 np:17 miss:1] ; val:"true" prob:[0.173077, 0.826923]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [-1.1892, -0.49788] |² <= 0.0223469 [s:0.691416 n:17 np:8 miss:1] ; val:"false" prob:[0.529412, 0.470588]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   |   |   └─(neg)─ "feature" contains X with | X - [-0.91476, 0.31935] |² <= 0.00975047 [s:0.0820785 n:141 np:8 miss:1] ; val:"false" prob:[0.851064, 0.148936]
    |   |   |   ├─(pos)─ val:"true" prob:[0.125, 0.875]
    |   |   |   └─(neg)─ "feature" contains X with | X - [-0.67918, 0.58657] |² <= 0.0723841 [s:0.044536 n:133 np:6 miss:1] ; val:"false" prob:[0.894737, 0.105263]
    |   |   |   ├─(pos)─ val:"true" prob:[0.333333, 0.666667]
    |   |   |   └─(neg)─ "feature" contains X with | X - [-0.77116, -0.52501] |² <= 0.0114129 [s:0.0381457 n:127 np:5 miss:1] ; val:"false" prob:[0.92126, 0.0787402]
    |   |   |   ├─(pos)─ val:"true" prob:[0.4, 0.6]
    |   |   |   └─(neg)─ "feature" contains X with X @ [-1.8854, -1.1367] >= 2.09149 [s:0.0480761 n:122 np:90 miss:0] ; val:"false" prob:[0.942623, 0.057377]
    |   |   |   ├─(pos)─ "feature" contains X with | X - [-0.91822, 1.101] |² <= 0.0467917 [s:0.0332469 n:90 np:5 miss:1] ; val:"false" prob:[0.988889, 0.0111111]
    |   |   |   |   ├─(pos)─ val:"false" prob:[0.8, 0.2]
    |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   └─(neg)─ "feature" contains X with X @ [-1.7734, -0.97891] >= 1.87687 [s:0.2205 n:32 np:7 miss:0] ; val:"false" prob:[0.8125, 0.1875]
    |   |   |   ├─(pos)─ val:"true" prob:[0.285714, 0.714286]
    |   |   |   └─(neg)─ "feature" contains X with X @ [-0.055621, 0.0095099] >= 0.0746306 [s:0.0678637 n:25 np:5 miss:0] ; val:"false" prob:[0.96, 0.04]
    |   |   |   ├─(pos)─ val:"false" prob:[0.8, 0.2]
    |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ "feature" contains X with X @ [-2.0516, 0.99254] >= 2.24698 [s:0.1663 n:157 np:52 miss:0] ; val:"true" prob:[0.229299, 0.770701]
    |   |   ├─(pos)─ "feature" contains X with | X - [-1.3386, 1.4541] |² <= 0.73135 [s:0.325203 n:52 np:22 miss:1] ; val:"false" prob:[0.576923, 0.423077]
    |   |   |   ├─(pos)─ "feature" contains X with X @ [-2.1489, -0.094533] >= 1.86235 [s:0.24535 n:22 np:17 miss:0] ; val:"true" prob:[0.136364, 0.863636]
    |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   └─(neg)─ val:"false" prob:[0.6, 0.4]
    |   |   |   └─(neg)─ "feature" contains X with | X - [-0.98063, 0.40346] |² <= 0.034739 [s:0.0748818 n:30 np:15 miss:1] ; val:"false" prob:[0.9, 0.1]
    |   |   |   ├─(pos)─ "feature" contains X with | X - [-1.025, 0.4303] |² <= 0.00414973 [s:0.223144 n:15 np:9 miss:1] ; val:"false" prob:[0.8, 0.2]
    |   |   |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   |   |   └─(neg)─ val:"false" prob:[0.5, 0.5]
    |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ "feature" contains X with | X - [-0.84759, 0.17248] |² <= 0.582581 [s:0.0840044 n:105 np:92 miss:1] ; val:"true" prob:[0.0571429, 0.942857]
    |   |   ├─(pos)─ "feature" contains X with X @ [-0.12316, 0.011823] >= 0.118263 [s:0.0327644 n:92 np:5 miss:0] ; val:"true" prob:[0.0108696, 0.98913]
    |   |   |   ├─(pos)─ val:"true" prob:[0.2, 0.8]
    |   |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   |   └─(neg)─ "feature" contains X with | X - [-0.69287, -0.60329] |² <= 0.0086635 [s:0.458327 n:13 np:7 miss:1] ; val:"true" prob:[0.384615, 0.615385]
    |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   └─(neg)─ val:"false" prob:[0.833333, 0.166667]
    |   └─(neg)─ "feature" contains X with | X - [-0.5638, -0.88036] |² <= 0.0402295 [s:0.0537754 n:181 np:78 miss:1] ; val:"false" prob:[0.939227, 0.0607735]
    |   ├─(pos)─ "feature" contains X with X @ [-1.5154, -2.4025] >= 2.82972 [s:0.174021 n:78 np:72 miss:0] ; val:"false" prob:[0.858974, 0.141026]
    |   |   ├─(pos)─ "feature" contains X with | X - [-1.3237, -1.1666] |² <= 0.280085 [s:0.214653 n:72 np:6 miss:1] ; val:"false" prob:[0.930556, 0.0694444]
    |   |   |   ├─(pos)─ val:"true" prob:[0.166667, 0.833333]
    |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   └─(neg)─ val:"false" prob:[1, 0]
    └─(neg)─ "feature" contains X with | X - [0.44434, 0.9401] |² <= 0.654926 [s:0.027429 n:4743 np:1042 miss:1] ; val:"false" prob:[0.979338, 0.020662]
    ├─(pos)─ "feature" contains X with X @ [0.42567, 1.8565] >= 1.83566 [s:0.0636289 n:1042 np:883 miss:0] ; val:"false" prob:[0.909789, 0.0902111]
    |   ├─(pos)─ "feature" contains X with | X - [-0.27788, 1.2018] |² <= 0.351519 [s:0.026306 n:883 np:357 miss:1] ; val:"false" prob:[0.961495, 0.0385051]
    |   |   ├─(pos)─ "feature" contains X with | X - [0.74848, 0.6849] |² <= 0.472482 [s:0.0761831 n:357 np:162 miss:1] ; val:"false" prob:[0.910364, 0.0896359]
    |   |   |   ├─(pos)─ "feature" contains X with | X - [0.33268, 0.8648] |² <= 0.0237821 [s:0.0884767 n:162 np:13 miss:1] ; val:"false" prob:[0.802469, 0.197531]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [1.0219, 0.78696] |² <= 0.319171 [s:0.170472 n:13 np:5 miss:1] ; val:"true" prob:[0.153846, 0.846154]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0.4, 0.6]
    |   |   |   |   |   └─(neg)─ val:"true" prob:[0, 1]
    |   |   |   |   └─(neg)─ "feature" contains X with | X - [-0.057619, 0.97193] |² <= 0.0473645 [s:0.0832656 n:149 np:19 miss:1] ; val:"false" prob:[0.85906, 0.14094]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [-0.093403, 1.0975] |² <= 0.0167449 [s:0.49947 n:19 np:7 miss:1] ; val:"true" prob:[0.421053, 0.578947]
    |   |   |   |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   |   |   |   └─(neg)─ "feature" contains X with | X - [-0.74371, 1.055] |² <= 0.604696 [s:0.0783349 n:12 np:7 miss:1] ; val:"true" prob:[0.0833333, 0.916667]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   |   └─(neg)─ val:"true" prob:[0.2, 0.8]
    |   |   |   |   └─(neg)─ "feature" contains X with X @ [-0.057704, 2.4017] >= 3.23595 [s:0.0917971 n:130 np:43 miss:0] ; val:"false" prob:[0.923077, 0.0769231]
    |   |   |   |   ├─(pos)─ "feature" contains X with | X - [0.79814, 0.54855] |² <= 0.0079685 [s:0.247604 n:43 np:6 miss:1] ; val:"false" prob:[0.767442, 0.232558]
    |   |   |   |   |   ├─(pos)─ val:"true" prob:[0, 1]
    |   |   |   |   |   └─(neg)─ "feature" contains X with X @ [-0.50249, -0.070775] >= -0.221604 [s:0.131325 n:37 np:32 miss:1] ; val:"false" prob:[0.891892, 0.108108]
    |   |   |   |   |   ├─(pos)─ "feature" contains X with | X - [0.65113, 0.83079] |² <= 0.0219412 [s:0.0608729 n:32 np:5 miss:1] ; val:"false" prob:[0.96875, 0.03125]
    |   |   |   |   |   |   ├─(pos)─ val:"false" prob:[0.8, 0.2]
    |   |   |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   |   |   └─(neg)─ val:"true" prob:[0.4, 0.6]
    |   |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ "feature" contains X with | X - [0.83084, 1.1891] |² <= 0.284236 [s:0.0140663 n:526 np:512 miss:1] ; val:"false" prob:[0.996198, 0.00380228]
    |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ "feature" contains X with | X - [0.28383, 0.97257] |² <= 0.255275 [s:0.169755 n:14 np:9 miss:1] ; val:"false" prob:[0.857143, 0.142857]
    |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ val:"false" prob:[0.6, 0.4]
    |   └─(neg)─ "feature" contains X with X @ [2.1172, 1.7939] >= 2.75332 [s:0.394315 n:159 np:94 miss:0] ; val:"false" prob:[0.622642, 0.377358]
    |   ├─(pos)─ "feature" contains X with | X - [0.3891, 0.89861] |² <= 0.0924216 [s:0.135347 n:94 np:6 miss:1] ; val:"false" prob:[0.957447, 0.0425532]
    |   |   ├─(pos)─ val:"true" prob:[0.333333, 0.666667]
    |   |   └─(neg)─ val:"false" prob:[1, 0]
    |   └─(neg)─ "feature" contains X with | X - [-0.25747, 1.0062] |² <= 0.00479214 [s:0.402161 n:65 np:9 miss:1] ; val:"true" prob:[0.138462, 0.861538]
    |   ├─(pos)─ val:"false" prob:[1, 0]
    |   └─(neg)─ val:"true" prob:[0, 1]
    └─(neg)─ "feature" contains X with | X - [1.4578, 0.72475] |² <= 0.735278 [s:0.00250899 n:3701 np:365 miss:1] ; val:"false" prob:[0.998919, 0.00108079]
    ├─(pos)─ "feature" contains X with | X - [1.025, 0.18544] |² <= 0.0345168 [s:0.0244523 n:365 np:41 miss:1] ; val:"false" prob:[0.989041, 0.0109589]
    |   ├─(pos)─ "feature" contains X with | X - [1.1412, 0.27389] |² <= 0.0336243 [s:0.133394 n:41 np:29 miss:1] ; val:"false" prob:[0.902439, 0.097561]
    |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ "feature" contains X with | X - [1.0641, 0.041744] |² <= 0.00282113 [s:0.428013 n:12 np:7 miss:1] ; val:"false" prob:[0.666667, 0.333333]
    |   |   ├─(pos)─ val:"false" prob:[1, 0]
    |   |   └─(neg)─ val:"true" prob:[0.2, 0.8]
    |   └─(neg)─ val:"false" prob:[1, 0]
    └─(neg)─ val:"false" prob:[1, 0]
Finally, we can evaluate the model.
test_ds = make_toy_ds(num_examples=1000)
model.evaluate(test_ds)
Evaluation of a classification model

- Accuracy: The simplest metric. It is the percentage of predictions that are correct (i.e., that match the ground truth). Example: if a model correctly identifies 90 out of 100 images as cat or dog, the accuracy is 90%.
- Confusion matrix: A table showing the counts of:
  - True positives (TP): the model correctly predicted positive.
  - True negatives (TN): the model correctly predicted negative.
  - False positives (FP): the model incorrectly predicted positive (a "false alarm").
  - False negatives (FN): the model incorrectly predicted negative (a "miss").
- Threshold: YDF classification models predict a probability for each class. The threshold is the cutoff for classifying something as positive or negative. Example: with a threshold of 0.5, any prediction above 0.5 might be classified as "spam" and anything below as "not spam".
- ROC curve (receiver operating characteristic): A plot of the true positive rate (TPR) against the false positive rate (FPR) at various thresholds.
  - TPR (sensitivity or recall): TP / (TP + FN). How many of the actual positives did the model catch?
  - FPR: FP / (FP + TN). How many negatives did the model incorrectly classify as positive?
  Interpretation: a good model has a ROC curve that hugs the top-left corner (high TPR, low FPR).
- AUC (area under the ROC curve): A single number summarizing the overall performance shown by the ROC curve. AUC is a more stable metric than accuracy. Multi-class classification models evaluate each class against all the other classes. Interpretation: ranges from 0 to 1. A perfect model has an AUC of 1, while a random model has an AUC of 0.5. Higher is better.
- Precision-recall curve: A plot of precision against recall at various thresholds.
  - Precision: TP / (TP + FP). Of all the predictions the model labeled positive, how many were actually positive?
  - Recall (same as TPR): TP / (TP + FN). Of all the actual positive cases, how many did the model correctly identify?
  Interpretation: a good model keeps the curve high (high precision and high recall). It is especially useful for imbalanced datasets (e.g., when one class is much rarer than the other).
- PR-AUC (area under the precision-recall curve): Like AUC, but for the precision-recall curve. A single number summarizing performance. Multi-class classification models evaluate each class against all the other classes. Higher is better.
- Threshold / accuracy curve: A plot showing how the model's accuracy changes as the classification threshold varies.
- Threshold / volume curve: A plot showing how the number of data points classified as positive changes as the threshold varies.
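The object returned by evaluate also exposes these metrics programmatically; a minimal sketch, assuming the accuracy accessor of recent YDF releases:

evaluation = model.evaluate(test_ds)
print(f"Test accuracy: {evaluation.accuracy:.4f}")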
Part 2: LLM embeddings + PCA + vector sequences

Now that we understand vector sequences, let's build a text classifier using the first hidden layer of GPT2 and a Random Forest.
Our model pipeline:

- GPT2 tokenizer: converts the text into tokens, e.g., "the cat is red" → [362, 82, 673, 6543].
- GPT2 token embeddings: looks up the embedding of each token. Output shape: <number of tokens, 768>.
- GPT2 first hidden layer: applies the first layer of GPT2 (attention and other neural network operations). Output shape: <number of tokens, 768>.
- PCA: reduces the dimension of the embeddings. Output shape: <number of tokens, 50> (matching the n_components=50 used below).
- Random Forest: trains a Random Forest to classify the text using the PCA-projected embeddings.
The GPT2 model

We load the GPT2 tokenizer and weights.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2Model.from_pretrained("gpt2", output_hidden_states=True)
The tokenizer encodes text into a list of token indices.
tokens = gpt2_tokenizer("This is a good movie", return_tensors="pt")
print(tokens["input_ids"])
tensor([[1212, 318, 257, 922, 3807]])
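To see which token each index corresponds to, the ids can be mapped back with the tokenizer (the Ġ prefix marks a leading space in GPT2's byte-pair encoding):

# Map the token ids from the output above back to token strings.
print(gpt2_tokenizer.convert_ids_to_tokens([1212, 318, 257, 922, 3807]))
# ['This', 'Ġis', 'Ġa', 'Ġgood', 'Ġmovie']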
Then, we apply the GPT2 model to the tokens and extract the output of layer #0. This computes the token embeddings and applies one layer of attention.
selected_hidden_layer = 0
gpt2_model(**tokens).hidden_states[selected_hidden_layer]
tensor([[[ 0.0065, -0.2930,  0.0762,  ...,  0.0184, -0.0275,  0.1638],
         [ 0.0142, -0.0437, -0.0393,  ...,  0.1487, -0.0278, -0.0255],
         [-0.0464, -0.0791,  0.1016,  ...,  0.0623,  0.0928, -0.0598],
         [-0.0841, -0.1244,  0.1423,  ..., -0.1435, -0.0718, -0.1183],
         [ 0.0331, -0.0645,  0.3507,  ..., -0.0210,  0.0279,  0.1440]]],
       grad_fn=<AddBackward0>)
We wrap these two steps into a function that returns a NumPy array.
def text_to_embedding(text: str) -> np.ndarray:
    tokens = gpt2_tokenizer(text, return_tensors="pt")
    return (
        gpt2_model(**tokens)
        .hidden_states[selected_hidden_layer]
        .detach()
        .numpy()[0]
    )
text_to_embedding("This is a good movie")
array([[ 0.00649832, -0.29302013,  0.07615747, ...,  0.01843522, -0.02754061,  0.16376127],
       [ 0.01423593, -0.0437407 , -0.0392998 , ...,  0.14866675, -0.02783391, -0.02553328],
       [-0.04641282, -0.07912885,  0.10156769, ...,  0.06225622,  0.09284618, -0.05983091],
       [-0.08413801, -0.12438498,  0.14228812, ..., -0.14347112, -0.07182924, -0.1183255 ],
       [ 0.03311015, -0.06451828,  0.35070336, ..., -0.02101075,  0.0278743 ,  0.14398581]],
      dtype=float32)
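Since we only run inference, the forward pass can optionally be wrapped in torch.no_grad() to skip gradient bookkeeping; a sketch of a variant of the function above (the .detach() call then becomes unnecessary):

def text_to_embedding_no_grad(text: str) -> np.ndarray:
    # Same as text_to_embedding, but without building the autograd graph.
    with torch.no_grad():
        tokens = gpt2_tokenizer(text, return_tensors="pt")
        return gpt2_model(**tokens).hidden_states[selected_hidden_layer].numpy()[0]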
Load the dataset

AG News is a text classification dataset where the task is to predict an article's category from its content. Let's load it.
def ag_news_dataset(split: str):
    class_mapping = {
        0: "World",
        1: "Sports",
        2: "Business",
        3: "Sci/Tech",
    }
    for example in load_dataset("ag_news")[split]:
        yield {
            "text": example["text"],
            "label": class_mapping[example["label"]],
        }
# Print the first 3 training examples
for example_idx, example in enumerate(islice(ag_news_dataset("train"), 3)):
    print(f"==========\nExample #{example_idx}\n----------")
    print(example)
==========
Example #0
----------
{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 'Business'}
==========
Example #1
----------
{'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.', 'label': 'Business'}
==========
Example #2
----------
{'text': "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.", 'label': 'Business'}
We now load more examples and compute their embeddings.
num_examples = 1000  # Only load 1k examples for this tutorial.
labels = []
embeddings = []
for example in tqdm(
    islice(ag_news_dataset("train"), num_examples), total=num_examples
):
    embeddings.append(text_to_embedding(example["text"]))
    labels.append(example["label"])
# raw_dataset = {"label": np.array(labels), "embedding": embeddings}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:41<00:00, 24.00it/s]
Each example has a different number of tokens, i.e., the embedding matrices have different shapes.
print("First example embedding size:", embeddings[0].shape)
print("Second example embedding size:", embeddings[1].shape)
First example embedding size: (37, 768)
Second example embedding size: (55, 768)
Compress the embeddings

Our dataset is small, so we can speed up training without losing much accuracy by reducing the 768-dimensional embeddings to 50 dimensions with PCA. Let's do that.
# Collect all the embeddings into a single matrix.
combined_embeddings = np.concatenate(embeddings, axis=0)
# Normalize the embedding (this is necessary for PCA).
scaler = StandardScaler()  # Keep the scaler so it can be reapplied at transform time.
normalized_combined_embeddings = scaler.fit_transform(combined_embeddings)
# Learn the compressed representation.
pca = PCA(n_components=50)
_ = pca.fit(normalized_combined_embeddings)
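Optionally, we can check how much of the variance the 50 components retain; explained_variance_ratio_ is a standard scikit-learn attribute of a fitted PCA:

# Optional sanity check: fraction of variance kept by the 50 components.
print(f"Retained variance: {pca.explained_variance_ratio_.sum():.1%}")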
We compress the actual embedding values, applying the same normalization before the PCA projection (the PCA was fit on normalized embeddings, so raw embeddings must be normalized before being transformed).

reduced_embeddings = [pca.transform(scaler.transform(e)) for e in embeddings]
print("First example embedding size:", reduced_embeddings[0].shape)
print("Second example embedding size:", reduced_embeddings[1].shape)
First example embedding size: (37, 50)
Second example embedding size: (55, 50)
Finally, we assemble the data into a dictionary.
dataset = {"label": np.array(labels), "reduced_embeddings": reduced_embeddings}
Train the model

We can now train our model.
model = ydf.RandomForestLearner(label="label").train(dataset, verbose=2)
Train model on 1000 examples
Model trained in 0:00:49.579565
Since the model is a Random Forest, we can look at the model's self-evaluation (also called the out-of-bag evaluation) to estimate its quality.
model.self_evaluation()
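To classify a new text with this model, the same pipeline must be applied at inference time. A minimal sketch, where classify_text is a hypothetical helper combining the pieces defined above (text_to_embedding, scaler, pca, and the trained model):

def classify_text(text: str) -> np.ndarray:
    # Embed, normalize, and project the text exactly as during training.
    embedding = pca.transform(scaler.transform(text_to_embedding(text)))
    # YDF returns the per-class probabilities; the class names are
    # available via model.label_classes().
    return model.predict({"reduced_embeddings": [embedding]})[0]

print(classify_text("The stock market rallied after strong earnings."))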