Util¶
read_tf_record ¶
read_tf_record(
path: Path,
*,
compressed: bool = True,
process: Optional[Callable[[Example], Example]] = None,
verbose: bool = False,
threads: int = 20
) -> Data
读取 TensorFlow Records 数据并返回一个 numpy 数组字典。
此方法可用于在使用 YDF 之前读取 TFRecord。
警告:在 Python 中读取示例非常慢。建议直接将路径提供给 YDF(例如 model.predict("record:" + path)
),这样速度会快大约 20 倍。
使用示例
import ydf
# Load a dataset
ds = ydf.util.read_tf_record(path="/path/to/tfrecord")
# Apply some pre-processing
ds["my_label"] = np.log(ds["my_label"])
# Train a model
ydf.RandomForestLearner(label="my_label",
task=ydf.Task.REGRESSION).train(ds)
此方法要求所有 TF Example 具有相同的特征,并且所有特征具有相同的类型和值数量。如果您的 TF Record 通过跳过特征来编码缺失值,您可以使用 process
参数手动添加缺失值。
import math
def process(example: tf.train.Example):
# Add missing values for categorical features.
for key in ["feature_1", "feature_2]:
if key not in example.features.feature:
example.features.feature[key].bytes_list.value.append(b"")
# Add missing values for numerical features.
for key in ["feature_3", "feature_4]:
if key not in example.features.feature:
example.features.feature[key].float_list.value.append(math.nan)
return example
read_ds = tf_example.read_tf_recordio(path, process=process)
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
path
|
路径
|
TFRecord 文件路径或路径列表。支持分片路径。 |
必需 |
compressed
|
bool
|
TFRecord 是否已压缩。 |
True
|
process
|
Optional[Callable[[Example], Example]]
|
可选函数,用于处理每个 TF Example。可用于过滤掉某些特征或修复某些示例值(例如,确保所有示例具有一致的特征值)。 |
None
|
verbose
|
bool
|
如果为 True,则打印数据集读取状态。 |
False
|
threads
|
int
|
读取线程数。 |
20
|
返回值
类型 | 描述 |
---|---|
数据
|
一个 numpy 数组字典。 |
write_tf_record ¶
write_tf_record(
data: Data,
*,
path: Path,
compressed: bool = True,
process: Optional[Callable[[Example], Example]] = None,
verbose: bool = False,
threads: int = 20
) -> None
将一个 numpy 数组字典写入 TensorFlow Record。
此方法可用于准备用于分布式训练的 TFRecord。
使用示例
import ydf
import numpy as np
# Generate a dataset
dataset = {
"f1": np.array([1, 2, 3]),
"f2": np.array([1.1, 2.2, 3.3])}
# Write the dataset
ydf.util.write_tf_recordio(dataset, path="/path/to/tfrecord")
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
data
|
数据
|
一个 numpy 数组字典。支持分片路径。 |
必需 |
path
|
路径
|
TFRecord 文件路径或路径列表。 |
必需 |
compressed
|
bool
|
TFRecord 是否已压缩。 |
True
|
process
|
Optional[Callable[[Example], Example]]
|
可选函数,用于处理每个 TF Example。 |
None
|
verbose
|
bool
|
如果为 True,则打印数据集写入状态。 |
False
|
threads
|
int
|
写入线程数。 |
20
|
Google 内部
- ydf.util.read_tf_recordio
- ydf.util.write_tf_recordio
- ydf.util.read_f1_query