478 lines
16 KiB
Python
478 lines
16 KiB
Python
|
from dataclasses import dataclass
|
||
|
from typing import Callable, List, Optional, Tuple, Dict
|
||
|
import functools
|
||
|
|
||
|
import torch
|
||
|
import tensorflow as tf
|
||
|
|
||
|
from tml.common.batch import DataclassBatch
|
||
|
from tml.projects.home.recap.data.config import RecapDataConfig, TaskData
|
||
|
from tml.projects.home.recap.data import preprocessors
|
||
|
from tml.projects.home.recap.config import JobMode
|
||
|
from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn
|
||
|
from tml.projects.home.recap.data.util import (
|
||
|
keyed_jagged_tensor_from_tensors_dict,
|
||
|
sparse_or_dense_tf_to_torch,
|
||
|
)
|
||
|
from absl import logging
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
|
||
|
@dataclass
|
||
|
class RecapBatch(DataclassBatch):
|
||
|
"""Holds features and labels from the Recap dataset."""
|
||
|
|
||
|
continuous_features: torch.Tensor
|
||
|
binary_features: torch.Tensor
|
||
|
discrete_features: torch.Tensor
|
||
|
sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821
|
||
|
labels: torch.Tensor
|
||
|
user_embedding: torch.Tensor = None
|
||
|
user_eng_embedding: torch.Tensor = None
|
||
|
author_embedding: torch.Tensor = None
|
||
|
weights: torch.Tensor = None
|
||
|
|
||
|
def __post_init__(self):
|
||
|
if self.weights is None:
|
||
|
self.weights = torch.ones_like(self.labels)
|
||
|
for feature_name, feature_value in self.as_dict().items():
|
||
|
if ("embedding" in feature_name) and (feature_value is None):
|
||
|
setattr(self, feature_name, torch.empty([0, 0]))
|
||
|
|
||
|
|
||
|
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
|
||
|
"""Converts a torch data loader output into `RecapBatch`."""
|
||
|
|
||
|
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
|
||
|
try:
|
||
|
features_in, labels = x
|
||
|
except ValueError:
|
||
|
# For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple
|
||
|
features_in, labels = x, None
|
||
|
|
||
|
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
|
||
|
if sparse_feature_names:
|
||
|
sparse_features = keyed_jagged_tensor_from_tensors_dict(
|
||
|
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
|
||
|
)
|
||
|
|
||
|
user_embedding, user_eng_embedding, author_embedding = None, None, None
|
||
|
if "user_embedding" in features_in:
|
||
|
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
|
||
|
raise ValueError("Only one source of embedding for user is supported")
|
||
|
else:
|
||
|
user_embedding = features_in["user_embedding"]
|
||
|
|
||
|
if "user_eng_embedding" in features_in:
|
||
|
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
|
||
|
raise ValueError("Only one source of embedding for user is supported")
|
||
|
else:
|
||
|
user_eng_embedding = features_in["user_eng_embedding"]
|
||
|
|
||
|
if "author_embedding" in features_in:
|
||
|
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
|
||
|
raise ValueError("Only one source of embedding for user is supported")
|
||
|
else:
|
||
|
author_embedding = features_in["author_embedding"]
|
||
|
|
||
|
return RecapBatch(
|
||
|
continuous_features=features_in["continuous"],
|
||
|
binary_features=features_in["binary"],
|
||
|
discrete_features=features_in["discrete"],
|
||
|
sparse_features=sparse_features,
|
||
|
user_embedding=user_embedding,
|
||
|
user_eng_embedding=user_eng_embedding,
|
||
|
author_embedding=author_embedding,
|
||
|
labels=labels,
|
||
|
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
|
||
|
)
|
||
|
|
||
|
|
||
|
def _chain(param, f1, f2):
|
||
|
"""
|
||
|
Reduce multiple functions into one chained function
|
||
|
_chain(x, f1, f2) -> f2(f1(x))
|
||
|
"""
|
||
|
output = param
|
||
|
fns = [f1, f2]
|
||
|
for f in fns:
|
||
|
output = f(output)
|
||
|
return output
|
||
|
|
||
|
|
||
|
def _add_weights(inputs, tasks: Dict[str, TaskData]):
|
||
|
"""Adds weights based on label sampling for positive and negatives.
|
||
|
|
||
|
This is useful for numeric calibration etc. This mutates inputs.
|
||
|
|
||
|
Args:
|
||
|
inputs: A dictionary of strings to tensor-like structures.
|
||
|
tasks: A dict of string (label) to `TaskData` specifying inputs.
|
||
|
|
||
|
Returns:
|
||
|
A tuple of features and labels; weights are added to features.
|
||
|
"""
|
||
|
|
||
|
weights = []
|
||
|
for key, task in tasks.items():
|
||
|
label = inputs[key]
|
||
|
float_label = tf.cast(label, tf.float32)
|
||
|
|
||
|
weights.append(
|
||
|
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
|
||
|
)
|
||
|
|
||
|
# Ensure we are batch-major (assumes we batch before this call).
|
||
|
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
|
||
|
return inputs
|
||
|
|
||
|
|
||
|
def get_datetimes(explicit_datetime_inputs):
|
||
|
"""Compute list datetime strings for train/validation data."""
|
||
|
datetime_format = "%Y/%m/%d/%H"
|
||
|
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
|
||
|
dates = sorted(
|
||
|
[
|
||
|
(end - timedelta(hours=i + 1)).strftime(datetime_format)
|
||
|
for i in range(int(explicit_datetime_inputs.hours))
|
||
|
]
|
||
|
)
|
||
|
return dates
|
||
|
|
||
|
|
||
|
def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
|
||
|
"""
|
||
|
Compile list of files for training/validation.
|
||
|
|
||
|
Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
|
||
|
For each hour of data, if the directory is missing or empty, we increment a counter to keep
|
||
|
track of the number of missing data hours.
|
||
|
Returns only files with a `.gz` extension.
|
||
|
|
||
|
Args:
|
||
|
explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object
|
||
|
|
||
|
Returns:
|
||
|
data_files: Sorted list of files to read corresponding to data at the desired datetimes
|
||
|
num_hours_missing: Number of hours that we are missing data
|
||
|
|
||
|
"""
|
||
|
datetimes = get_datetimes(explicit_datetime_inputs)
|
||
|
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
|
||
|
data_files = []
|
||
|
num_hours_missing = 0
|
||
|
for folder in folders:
|
||
|
try:
|
||
|
files = tf.io.gfile.listdir(folder)
|
||
|
if not files:
|
||
|
logging.warning(f"{folder} contained no data files")
|
||
|
num_hours_missing += 1
|
||
|
data_files.extend(
|
||
|
[
|
||
|
os.path.join(folder, filename)
|
||
|
for filename in files
|
||
|
if filename.rsplit(".", 1)[-1].lower() == "gz"
|
||
|
]
|
||
|
)
|
||
|
except tf.errors.NotFoundError as e:
|
||
|
num_hours_missing += 1
|
||
|
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
|
||
|
return sorted(data_files), num_hours_missing
|
||
|
|
||
|
|
||
|
def _map_output_for_inference(
|
||
|
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
|
||
|
):
|
||
|
if preprocessor:
|
||
|
raise ValueError("No preprocessor should be used at inference time.")
|
||
|
if add_weights:
|
||
|
raise NotImplementedError()
|
||
|
|
||
|
# Add zero weights.
|
||
|
inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
|
||
|
for label in tasks:
|
||
|
del inputs[label]
|
||
|
return inputs
|
||
|
|
||
|
|
||
|
def _map_output_for_train_eval(
|
||
|
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
|
||
|
):
|
||
|
if add_weights:
|
||
|
inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
|
||
|
|
||
|
# Warning this has to happen first as it changes the input
|
||
|
if preprocessor:
|
||
|
inputs = preprocessor(inputs)
|
||
|
|
||
|
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
|
||
|
|
||
|
for label in tasks:
|
||
|
del inputs[label]
|
||
|
|
||
|
return inputs, label_values
|
||
|
|
||
|
|
||
|
def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
|
||
|
"""Adds weights based on label sampling for positive and negatives.
|
||
|
|
||
|
This is useful for numeric calibration etc. This mutates inputs.
|
||
|
|
||
|
Args:
|
||
|
inputs: A dictionary of strings to tensor-like structures.
|
||
|
tasks: A dict of string (label) to `TaskData` specifying inputs.
|
||
|
|
||
|
Returns:
|
||
|
A tuple of features and labels; weights are added to features.
|
||
|
"""
|
||
|
weights = []
|
||
|
for key, task in tasks.items():
|
||
|
label = inputs[key]
|
||
|
float_label = tf.cast(label, tf.float32)
|
||
|
|
||
|
weights.append(
|
||
|
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
|
||
|
)
|
||
|
|
||
|
# Ensure we are batch-major (assumes we batch before this call).
|
||
|
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
|
||
|
return inputs
|
||
|
|
||
|
|
||
|
class RecapDataset(torch.utils.data.IterableDataset):
|
||
|
def __init__(
|
||
|
self,
|
||
|
data_config: RecapDataConfig,
|
||
|
dataset_service: Optional[str] = None,
|
||
|
mode: JobMode = JobMode.TRAIN,
|
||
|
compression: Optional[str] = "AUTO",
|
||
|
repeat: bool = False,
|
||
|
vocab_mapper: tf.keras.Model = None,
|
||
|
):
|
||
|
logging.info("***** Labels *****")
|
||
|
logging.info(list(data_config.tasks.keys()))
|
||
|
|
||
|
self._data_config = data_config
|
||
|
self._parse_fn = get_seg_dense_parse_fn(data_config)
|
||
|
self._mode = mode
|
||
|
self._repeat = repeat
|
||
|
self._num_concurrent_iterators = 1
|
||
|
self._vocab_mapper = vocab_mapper
|
||
|
self.dataset_service = dataset_service
|
||
|
|
||
|
preprocessor = None
|
||
|
self._batch_size_multiplier = 1
|
||
|
if data_config.preprocess:
|
||
|
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
|
||
|
if data_config.preprocess.downsample_negatives:
|
||
|
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
|
||
|
|
||
|
self._preprocessor = preprocessor
|
||
|
|
||
|
if mode == JobMode.INFERENCE:
|
||
|
if preprocessor is not None:
|
||
|
raise ValueError("Expect no preprocessor at inference time.")
|
||
|
should_add_weights = False
|
||
|
output_map_fn = _map_output_for_inference # (features,)
|
||
|
else:
|
||
|
# Only add weights if there is a reason to! If all weights will
|
||
|
# be equal to 1.0, save bandwidth between DDS and Chief by simply
|
||
|
# relying on the fact that weights default to 1.0 in `RecapBatch`
|
||
|
# WARNING: Weights may still be added as a side effect of a preprocessor
|
||
|
# such as `DownsampleNegatives`.
|
||
|
should_add_weights = any(
|
||
|
[
|
||
|
task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0
|
||
|
for task_cfg in data_config.tasks.values()
|
||
|
]
|
||
|
)
|
||
|
output_map_fn = _map_output_for_train_eval # (features, labels)
|
||
|
|
||
|
self._output_map_fn = functools.partial(
|
||
|
output_map_fn,
|
||
|
tasks=data_config.tasks,
|
||
|
preprocessor=preprocessor,
|
||
|
add_weights=should_add_weights,
|
||
|
)
|
||
|
|
||
|
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
|
||
|
|
||
|
self._tf_dataset = self._create_tf_dataset()
|
||
|
|
||
|
self._init_tensor_spec()
|
||
|
|
||
|
def _init_tensor_spec(self):
|
||
|
def _tensor_spec_to_torch_shape(spec):
|
||
|
if spec.shape is None:
|
||
|
return None
|
||
|
shape = [x if x is not None else -1 for x in spec.shape]
|
||
|
return torch.Size(shape)
|
||
|
|
||
|
self.torch_element_spec = tf.nest.map_structure(
|
||
|
_tensor_spec_to_torch_shape, self._tf_dataset.element_spec
|
||
|
)
|
||
|
|
||
|
def _create_tf_dataset(self):
|
||
|
if hasattr(self, "_tf_dataset"):
|
||
|
raise ValueError("Do not call `_create_tf_dataset` more than once.")
|
||
|
|
||
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||
|
per_replica_bsz = (
|
||
|
self._batch_size_multiplier * self._data_config.global_batch_size // world_size
|
||
|
)
|
||
|
|
||
|
dataset: tf.data.Dataset = self._create_base_tf_dataset(
|
||
|
batch_size=per_replica_bsz,
|
||
|
)
|
||
|
|
||
|
if self._repeat:
|
||
|
logging.info("Repeating dataset")
|
||
|
dataset = dataset.repeat()
|
||
|
|
||
|
if self.dataset_service:
|
||
|
if self._num_concurrent_iterators > 1:
|
||
|
if not self.machines_config:
|
||
|
raise ValueError(
|
||
|
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
|
||
|
)
|
||
|
dataset = dataset_lib.with_auto_tune_budget(
|
||
|
dataset,
|
||
|
machine_config=self.machines_config.chief,
|
||
|
num_concurrent_iterators=self.num_concurrent_iterators,
|
||
|
on_chief=False,
|
||
|
)
|
||
|
|
||
|
self.dataset_id, self.job_name = register_dataset(
|
||
|
dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
|
||
|
)
|
||
|
dataset = distribute_from_dataset_id(
|
||
|
dataset_id=self.dataset_id, # type: ignore[arg-type]
|
||
|
job_name=self.job_name,
|
||
|
dataset_service=self.dataset_service,
|
||
|
compression=self.compression,
|
||
|
)
|
||
|
|
||
|
elif self._num_concurrent_iterators > 1:
|
||
|
if not self.machines_config:
|
||
|
raise ValueError(
|
||
|
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
|
||
|
)
|
||
|
dataset = dataset_lib.with_auto_tune_budget(
|
||
|
dataset,
|
||
|
machine_config=self.machines_config.chief,
|
||
|
num_concurrent_iterators=self._num_concurrent_iterators,
|
||
|
on_chief=True,
|
||
|
)
|
||
|
|
||
|
# Vocabulary mapping happens on the training node, not in dds because of size.
|
||
|
if self._vocab_mapper:
|
||
|
dataset = dataset.map(self._vocab_mapper)
|
||
|
|
||
|
return dataset.prefetch(world_size * 2)
|
||
|
|
||
|
def _create_base_tf_dataset(self, batch_size: int):
|
||
|
if self._data_config.inputs:
|
||
|
glob = self._data_config.inputs
|
||
|
filenames = sorted(tf.io.gfile.glob(glob))
|
||
|
elif self._data_config.explicit_datetime_inputs:
|
||
|
num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
|
||
|
filenames, num_hours_missing = get_explicit_datetime_inputs_files(
|
||
|
self._data_config.explicit_datetime_inputs,
|
||
|
increment="hourly",
|
||
|
)
|
||
|
if num_hours_missing > num_missing_hours_tol:
|
||
|
raise ValueError(
|
||
|
f"We are missing {num_hours_missing} hours of data"
|
||
|
f"more than tolerance {num_missing_hours_tol}."
|
||
|
)
|
||
|
elif self._data_config.explicit_date_inputs:
|
||
|
num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
|
||
|
filenames, num_days_missing = get_explicit_datetime_inputs_files(
|
||
|
self._data_config.explicit_date_inputs,
|
||
|
increment="daily",
|
||
|
)
|
||
|
if num_days_missing > num_missing_days_tol:
|
||
|
raise ValueError(
|
||
|
f"We are missing {num_days_missing} days of data"
|
||
|
f"more than tolerance {num_missing_days_tol}."
|
||
|
)
|
||
|
else:
|
||
|
raise ValueError(
|
||
|
"Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
|
||
|
)
|
||
|
|
||
|
num_files = len(filenames)
|
||
|
logging.info(f"Found {num_files} data files")
|
||
|
if num_files < 1:
|
||
|
raise ValueError("No data files found")
|
||
|
|
||
|
if self._data_config.num_files_to_keep is not None:
|
||
|
filenames = filenames[: self._data_config.num_files_to_keep]
|
||
|
logging.info(f"Retaining only {len(filenames)} files.")
|
||
|
|
||
|
filenames_ds = (
|
||
|
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
|
||
|
# Because of drop_remainder, if our dataset does not fill
|
||
|
# up a batch, it will emit nothing without this repeat.
|
||
|
.repeat(-1)
|
||
|
)
|
||
|
|
||
|
if self._data_config.file_batch_size:
|
||
|
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
|
||
|
|
||
|
def per_shard_dataset(filename):
|
||
|
ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
|
||
|
return ds.prefetch(4)
|
||
|
|
||
|
ds = filenames_ds.interleave(
|
||
|
per_shard_dataset,
|
||
|
block_length=4,
|
||
|
deterministic=False,
|
||
|
num_parallel_calls=self._data_config.interleave_num_parallel_calls
|
||
|
or tf.data.experimental.AUTOTUNE,
|
||
|
)
|
||
|
|
||
|
# Combine functions into one map call to reduce overhead.
|
||
|
map_fn = functools.partial(
|
||
|
_chain,
|
||
|
f1=self._parse_fn,
|
||
|
f2=self._output_map_fn,
|
||
|
)
|
||
|
|
||
|
# Shuffle -> Batch -> Parse is the correct ordering
|
||
|
# Shuffling needs to be performed before batching otherwise there is not much point
|
||
|
# Batching happens before parsing because tf.Example parsing is actually vectorized
|
||
|
# and works much faster overall on batches of data.
|
||
|
ds = (
|
||
|
# DANGER DANGER: there is a default shuffle size here.
|
||
|
ds.shuffle(self._data_config.examples_shuffle_buffer_size)
|
||
|
.batch(batch_size=batch_size, drop_remainder=True)
|
||
|
.map(
|
||
|
map_fn,
|
||
|
num_parallel_calls=self._data_config.map_num_parallel_calls
|
||
|
or tf.data.experimental.AUTOTUNE,
|
||
|
)
|
||
|
)
|
||
|
|
||
|
if self._data_config.cache:
|
||
|
ds = ds.cache()
|
||
|
|
||
|
if self._data_config.ignore_data_errors:
|
||
|
ds = ds.apply(tf.data.experimental.ignore_errors())
|
||
|
|
||
|
options = tf.data.Options()
|
||
|
options.experimental_deterministic = False
|
||
|
ds = ds.with_options(options)
|
||
|
|
||
|
return ds
|
||
|
|
||
|
def _gen(self):
|
||
|
for x in self._tf_dataset:
|
||
|
yield to_batch(x)
|
||
|
|
||
|
def to_dataloader(self) -> Dict[str, torch.Tensor]:
|
||
|
return torch.utils.data.DataLoader(self, batch_size=None)
|
||
|
|
||
|
def __iter__(self):
|
||
|
return iter(self._gen())
|