跳至内容

Util

read_tf_record

read_tf_record(
    path: Path,
    *,
    compressed: bool = True,
    process: Optional[Callable[[Example], Example]] = None,
    verbose: bool = False,
    threads: int = 20
) -> Data

读取 TensorFlow Records 数据并返回一个 numpy 数组字典。

此方法可用于在使用 YDF 之前读取 TFRecord。

警告:在 Python 中读取示例非常慢。建议直接将路径提供给 YDF(例如 model.predict("record:" + path)),这样速度会快大约 20 倍。

使用示例

import ydf

# Load a dataset
ds = ydf.util.read_tf_record(path="/path/to/tfrecord")

# Apply some pre-processing
ds["my_label"] = np.log(ds["my_label"])

# Train a model
ydf.RandomForestLearner(label="my_label",
  task=ydf.Task.REGRESSION).train(ds)

此方法要求所有 TF Example 具有相同的特征,并且所有特征具有相同的类型和值数量。如果您的 TF Record 通过跳过特征来编码缺失值,您可以使用 process 参数手动添加缺失值。

import math

def process(example: tf.train.Example):
  # Add missing values for categorical features.
  for key in ["feature_1", "feature_2]:
    if key not in example.features.feature:
      example.features.feature[key].bytes_list.value.append(b"")

  # Add missing values for numerical features.
  for key in ["feature_3", "feature_4]:
    if key not in example.features.feature:
      example.features.feature[key].float_list.value.append(math.nan)

  return example

read_ds = tf_example.read_tf_recordio(path, process=process)

参数

名称 类型 描述 默认值
path 路径

TFRecord 文件路径或路径列表。支持分片路径。

必需
compressed bool

TFRecord 是否已压缩。

True
process Optional[Callable[[Example], Example]]

可选函数,用于处理每个 TF Example。可用于过滤掉某些特征或修复某些示例值(例如,确保所有示例具有一致的特征值)。

None
verbose bool

如果为 True,则打印数据集读取状态。

False
threads int

读取线程数。

20

返回值

类型 描述
数据

一个 numpy 数组字典。

write_tf_record

write_tf_record(
    data: Data,
    *,
    path: Path,
    compressed: bool = True,
    process: Optional[Callable[[Example], Example]] = None,
    verbose: bool = False,
    threads: int = 20
) -> None

将一个 numpy 数组字典写入 TensorFlow Record。

此方法可用于准备用于分布式训练的 TFRecord。

使用示例

import ydf
import numpy as np

# Generate a dataset
dataset = {
    "f1": np.array([1, 2, 3]),
    "f2": np.array([1.1, 2.2, 3.3])}

# Write the dataset
ydf.util.write_tf_recordio(dataset, path="/path/to/tfrecord")

参数

名称 类型 描述 默认值
data 数据

一个 numpy 数组字典。支持分片路径。

必需
path 路径

TFRecord 文件路径或路径列表。

必需
compressed bool

TFRecord 是否已压缩。

True
process Optional[Callable[[Example], Example]]

可选函数,用于处理每个 TF Example。

None
verbose bool

如果为 True,则打印数据集写入状态。

False
threads int

写入线程数。

20

Google 内部

  • ydf.util.read_tf_recordio
  • ydf.util.write_tf_recordio
  • ydf.util.read_f1_query