247 lines
8.1 KiB
Python
247 lines
8.1 KiB
Python
|
import typing
|
||
|
from enum import Enum
|
||
|
|
||
|
|
||
|
from tml.core import config as base_config
|
||
|
|
||
|
import pydantic
|
||
|
|
||
|
|
||
|
class ExplicitDateInputs(base_config.BaseConfig):
|
||
|
"""Arguments to select train/validation data using end_date and days of data."""
|
||
|
|
||
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||
|
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
|
||
|
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
|
||
|
num_missing_days_tol: int = pydantic.Field(
|
||
|
0, description="We tolerate <= num_missing_days_tol days of missing data."
|
||
|
)
|
||
|
|
||
|
|
||
|
class ExplicitDatetimeInputs(base_config.BaseConfig):
|
||
|
"""Arguments to select train/validation data using end_datetime and hours of data."""
|
||
|
|
||
|
data_root: str = pydantic.Field(..., description="Data path prefix.")
|
||
|
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
|
||
|
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
|
||
|
num_missing_hours_tol: int = pydantic.Field(
|
||
|
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
|
||
|
)
|
||
|
|
||
|
|
||
|
class DdsCompressionOption(str, Enum):
|
||
|
"""The only valid compression option is 'AUTO'"""
|
||
|
|
||
|
AUTO = "AUTO"
|
||
|
|
||
|
|
||
|
class DatasetConfig(base_config.BaseConfig):
|
||
|
inputs: str = pydantic.Field(
|
||
|
None, description="A glob for selecting data.", one_of="date_inputs_format"
|
||
|
)
|
||
|
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
|
||
|
None, one_of="date_inputs_format"
|
||
|
)
|
||
|
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
|
||
|
|
||
|
global_batch_size: pydantic.PositiveInt
|
||
|
|
||
|
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
|
||
|
None, description="Number of shards to keep."
|
||
|
)
|
||
|
repeat_files: bool = pydantic.Field(
|
||
|
True, description="DEPRICATED. Files are repeated no matter what this is set to."
|
||
|
)
|
||
|
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
|
||
|
|
||
|
cache: bool = pydantic.Field(
|
||
|
False,
|
||
|
description="Cache dataset in memory. Careful to only use this when you"
|
||
|
" have enough memory to fit entire dataset.",
|
||
|
)
|
||
|
|
||
|
data_service_dispatcher: str = pydantic.Field(None)
|
||
|
ignore_data_errors: bool = pydantic.Field(
|
||
|
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
|
||
|
)
|
||
|
dataset_service_compression: DdsCompressionOption = pydantic.Field(
|
||
|
None,
|
||
|
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
|
||
|
)
|
||
|
|
||
|
# tf.data.Dataset options
|
||
|
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
|
||
|
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||
|
None, description="Number of parallel calls."
|
||
|
)
|
||
|
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
|
||
|
None, description="Number of shards to interleave."
|
||
|
)
|
||
|
|
||
|
|
||
|
class TruncateAndSlice(base_config.BaseConfig):
|
||
|
# Apply truncation and then slice.
|
||
|
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
||
|
None, description="Experimental. Truncates continuous features to this amount for efficiency."
|
||
|
)
|
||
|
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
|
||
|
None, description="Experimental. Truncates binary features to this amount for efficiency."
|
||
|
)
|
||
|
|
||
|
continuous_feature_mask_path: str = pydantic.Field(
|
||
|
None, description="Path of mask used to slice input continuous features."
|
||
|
)
|
||
|
binary_feature_mask_path: str = pydantic.Field(
|
||
|
None, description="Path of mask used to slice input binary features."
|
||
|
)
|
||
|
|
||
|
|
||
|
class DataType(str, Enum):
|
||
|
BFLOAT16 = "bfloat16"
|
||
|
BOOL = "bool"
|
||
|
|
||
|
FLOAT32 = "float32"
|
||
|
FLOAT16 = "float16"
|
||
|
|
||
|
UINT8 = "uint8"
|
||
|
|
||
|
|
||
|
class DownCast(base_config.BaseConfig):
|
||
|
# Apply down casting to selected features.
|
||
|
features: typing.Dict[str, DataType] = pydantic.Field(
|
||
|
None, description="Map features to down cast data types."
|
||
|
)
|
||
|
|
||
|
|
||
|
class TaskData(base_config.BaseConfig):
|
||
|
pos_downsampling_rate: float = pydantic.Field(
|
||
|
1.0,
|
||
|
description="Downsampling rate of positives used to generate dataset.",
|
||
|
)
|
||
|
neg_downsampling_rate: float = pydantic.Field(
|
||
|
1.0,
|
||
|
description="Downsampling rate of negatives used to generate dataset.",
|
||
|
)
|
||
|
|
||
|
|
||
|
class SegDenseSchema(base_config.BaseConfig):
|
||
|
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
|
||
|
features: typing.List[str] = pydantic.Field(
|
||
|
[],
|
||
|
description="List of features (in addition to the renamed features) to read from schema path above.",
|
||
|
)
|
||
|
renamed_features: typing.Dict[str, str] = pydantic.Field(
|
||
|
{}, description="Dictionary of renamed features."
|
||
|
)
|
||
|
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
|
||
|
{},
|
||
|
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
|
||
|
)
|
||
|
|
||
|
|
||
|
class RectifyLabels(base_config.BaseConfig):
|
||
|
label_rectification_window_in_hours: float = pydantic.Field(
|
||
|
3.0, description="overlap time in hours for which to flip labels"
|
||
|
)
|
||
|
served_timestamp_field: str = pydantic.Field(
|
||
|
..., description="input field corresponding to served time"
|
||
|
)
|
||
|
impressed_timestamp_field: str = pydantic.Field(
|
||
|
..., description="input field corresponding to impressed time"
|
||
|
)
|
||
|
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
|
||
|
..., description="label to the input field corresponding to engagement time"
|
||
|
)
|
||
|
|
||
|
|
||
|
class ExtractFeaturesRow(base_config.BaseConfig):
|
||
|
name: str = pydantic.Field(
|
||
|
...,
|
||
|
description="name of the new field name to be created",
|
||
|
)
|
||
|
source_tensor: str = pydantic.Field(
|
||
|
...,
|
||
|
description="name of the dense tensor to look for the feature",
|
||
|
)
|
||
|
index: int = pydantic.Field(
|
||
|
...,
|
||
|
description="index of the feature in the dense tensor",
|
||
|
)
|
||
|
|
||
|
|
||
|
class ExtractFeatures(base_config.BaseConfig):
|
||
|
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
|
||
|
[],
|
||
|
description="list of features to be extracted with their name, source tensor and index",
|
||
|
)
|
||
|
|
||
|
|
||
|
class DownsampleNegatives(base_config.BaseConfig):
|
||
|
batch_multiplier: int = pydantic.Field(
|
||
|
None,
|
||
|
description="batch multiplier",
|
||
|
)
|
||
|
engagements_list: typing.List[str] = pydantic.Field(
|
||
|
[],
|
||
|
description="engagements with kept positives",
|
||
|
)
|
||
|
num_engagements: int = pydantic.Field(
|
||
|
...,
|
||
|
description="number engagements used in the model, including ones excluded in engagements_list",
|
||
|
)
|
||
|
|
||
|
|
||
|
class Preprocess(base_config.BaseConfig):
|
||
|
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
|
||
|
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
|
||
|
rectify_labels: RectifyLabels = pydantic.Field(
|
||
|
None, description="Rectify labels for a given overlap window"
|
||
|
)
|
||
|
extract_features: ExtractFeatures = pydantic.Field(
|
||
|
None, description="Extract features from dense tensors."
|
||
|
)
|
||
|
downsample_negatives: DownsampleNegatives = pydantic.Field(
|
||
|
None, description="Downsample negatives."
|
||
|
)
|
||
|
|
||
|
|
||
|
class Sampler(base_config.BaseConfig):
|
||
|
"""Assumes function is defined in data/samplers.py.
|
||
|
|
||
|
Only use this for quick experimentation.
|
||
|
If samplers are useful, we should sample from upstream data generation.
|
||
|
|
||
|
DEPRICATED, DO NOT USE.
|
||
|
"""
|
||
|
|
||
|
name: str
|
||
|
kwargs: typing.Dict
|
||
|
|
||
|
|
||
|
class RecapDataConfig(DatasetConfig):
|
||
|
seg_dense_schema: SegDenseSchema
|
||
|
|
||
|
tasks: typing.Dict[str, TaskData] = pydantic.Field(
|
||
|
description="Description of individual tasks in this dataset."
|
||
|
)
|
||
|
evaluation_tasks: typing.List[str] = pydantic.Field(
|
||
|
[], description="If specified, lists the tasks we're generating metrics for."
|
||
|
)
|
||
|
|
||
|
preprocess: Preprocess = pydantic.Field(
|
||
|
None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
|
||
|
)
|
||
|
|
||
|
sampler: Sampler = pydantic.Field(
|
||
|
None,
|
||
|
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
|
||
|
)
|
||
|
|
||
|
@pydantic.root_validator()
|
||
|
def _validate_evaluation_tasks(cls, values):
|
||
|
if values.get("evaluation_tasks") is not None:
|
||
|
for task in values["evaluation_tasks"]:
|
||
|
if task not in values["tasks"]:
|
||
|
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
|
||
|
return values
|