The Algorithm ML

This commit is contained in:
twitter-service 2023-03-31 10:31:35 -05:00
commit abdcff1c75
119 changed files with 12012 additions and 0 deletions

37
.github/workflows/main.yml vendored Normal file
View file

@ -0,0 +1,37 @@
name: Python package
on: [push]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3
# - uses: pre-commit/action@v3.0.0
# name: Run pre-commit checks (pylint/yapf/isort)
# env:
# SKIP: insert-license
# with:
# extra_args: --hook-stage push --all-files
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: "pip" # caching pip dependencies
- name: install packages
run: |
/usr/bin/python -m pip install --upgrade pip
pip install --no-deps -r images/requirements.txt
# - name: ssh access
# uses: lhotari/action-upterm@v1
# with:
# limit-access-to-actor: true
# limit-access-to-users: arashd
- name: run tests
run: |
# Environment variables are reset in between steps.
export PYTHONPATH="$GITHUB_WORKSPACE/../:$PYTHONPATH"
pytest -vv

35
.gitignore vendored Normal file
View file

@ -0,0 +1,35 @@
# Mac
.DS_Store
# Vim
*.py.swp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
.hypothesis
venv

16
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,16 @@
repos:
- repo: https://github.com/pausan/cblack
rev: release-22.3.0
hooks:
- id: cblack
name: cblack
description: "Black: The uncompromising Python code formatter - 2 space indent fork"
entry: cblack . -l 100
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-merge-conflict

33
LICENSE.torchrec Normal file
View file

@ -0,0 +1,33 @@
A few files here (where it is specifically noted in comments) are based on code from torchrec but
adapted for our use. Torchrec license is below:
BSD 3-Clause License
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

16
README.md Normal file
View file

@ -0,0 +1,16 @@
This project open sources some of the ML models used at Twitter.
Note these are released to be open about our recommendation system. However, these are not supported Twitter products and we are unable to provide support for the use of this code.
Currently these are:
1. The "For You" Heavy Ranker (projects/home/recap).
2. TwHIN embeddings (projects/twhin) https://arxiv.org/abs/2202.05387
This project can be run inside a python virtualenv. We have only tried this on Linux machines and because we use torchrec it works best with an Nvidia GPU. To setup run
`./images/init_venv.sh` (Linux only).
The READMEs of each project contain instructions about how to run each project.

0
common/__init__.py Normal file
View file

85
common/batch.py Normal file
View file

@ -0,0 +1,85 @@
"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
"""
# flake8: noqa
from __future__ import annotations
from typing import Dict
import abc
from dataclasses import dataclass
import dataclasses
import torch
from torchrec.streamable import Pipelineable
class BatchBase(Pipelineable, abc.ABC):
@abc.abstractmethod
def as_dict(self) -> Dict:
raise NotImplementedError
def to(self, device: torch.device, non_blocking: bool = False):
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.to(device=device, non_blocking=non_blocking)
return self.__class__(**args)
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature_value in self.as_dict().values():
feature_value.record_stream(stream)
def pin_memory(self):
args = {}
for feature_name, feature_value in self.as_dict().items():
args[feature_name] = feature_value.pin_memory()
return self.__class__(**args)
def __repr__(self) -> str:
def obj2str(v):
return f"{v.size()}" if hasattr(v, "size") else f"{v.length_per_key()}"
return "\n".join([f"{k}: {obj2str(v)}," for k, v in self.as_dict().items()])
@property
def batch_size(self) -> int:
for tensor in self.as_dict().values():
if tensor is None:
continue
if not isinstance(tensor, torch.Tensor):
continue
return tensor.shape[0]
raise Exception("Could not determine batch size from tensors.")
@dataclass
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls):
return list(cls.__dataclass_fields__.keys())
def as_dict(self):
return {
feature_name: getattr(self, feature_name)
for feature_name in self.feature_names()
if hasattr(self, feature_name)
}
@staticmethod
def from_schema(name: str, schema):
"""Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
return dataclasses.make_dataclass(
cls_name=name,
fields=[(name, torch.Tensor, dataclasses.field(default=None)) for name in schema.names],
bases=(DataclassBatch,),
)
@staticmethod
def from_fields(name: str, fields: dict):
return dataclasses.make_dataclass(
cls_name=name,
fields=[(_name, _type, dataclasses.field(default=None)) for _name, _type in fields.items()],
bases=(DataclassBatch,),
)
class DictionaryBatch(BatchBase, dict):
def as_dict(self) -> Dict:
return self

View file

@ -0,0 +1 @@
from tml.common.checkpointing.snapshot import get_checkpoint, Snapshot

View file

@ -0,0 +1,260 @@
import os
import time
from typing import Any, Dict, List, Optional
from tml.ml_logging.torch_logging import logging
from tml.common.filesystem import infer_fs, is_gcs_fs
import torchsnapshot
DONE_EVAL_SUBDIR = "evaled_by"
GCS_PREFIX = "gs://"
class Snapshot:
"""Checkpoints using torchsnapshot.
Also saves step to be updated by the training loop.
"""
def __init__(self, save_dir: str, state: Dict[str, Any]) -> None:
self.save_dir = save_dir
self.state = state
self.state["extra_state"] = torchsnapshot.StateDict(step=0, walltime=0.0)
@property
def step(self):
return self.state["extra_state"]["step"]
@step.setter
def step(self, step: int) -> None:
self.state["extra_state"]["step"] = step
@property
def walltime(self):
return self.state["extra_state"]["walltime"]
@walltime.setter
def walltime(self, walltime: float) -> None:
self.state["extra_state"]["walltime"] = walltime
def save(self, global_step: int) -> "PendingSnapshot":
"""Saves checkpoint with given global_step."""
path = os.path.join(self.save_dir, str(global_step))
logging.info(f"Saving snapshot global_step {global_step} to {path}.")
start_time = time.time()
# Take a snapshot in async manner, the snapshot is consistent that state changes after this method returns have no effect on the snapshot. It performs storage I/O in the background.
snapshot = torchsnapshot.Snapshot.async_take(
app_state=self.state,
path=path,
# commented out because DistributedModelParallel model saving
# errors with this on multi-GPU. With it removed, CPU, single
# GPU, and multi-GPU training all successfully checkpoint.
# replicated=["**"],
)
logging.info(f"Snapshot saved to {snapshot.path} ({time.time() - start_time:.05}s")
return snapshot
def restore(self, checkpoint: str) -> None:
"""Restores a given checkpoint."""
snapshot = torchsnapshot.Snapshot(path=checkpoint)
logging.info(f"Restoring snapshot from {snapshot.path}.")
start_time = time.time()
# We can remove the try-except when we are confident that we no longer need to restore from
# checkpoints from before walltime was added
try:
# checkpoints that do not have extra_state[walltime] will fail here
snapshot.restore(self.state)
except RuntimeError:
# extra_state[walltime] does not exist in the checkpoint, but step should be there so restore it
self.state["extra_state"] = torchsnapshot.StateDict(step=0)
snapshot.restore(self.state)
# we still need to ensure that extra_state has walltime in it
self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)
logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s")
@classmethod
def get_torch_snapshot(
cls,
snapshot_path: str,
global_step: Optional[int] = None,
missing_ok: bool = False,
) -> torchsnapshot.Snapshot:
"""Get torch stateless snapshot, without actually loading it.
Args:
snapshot_path: path to the model snapshot
global_step: restores from this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
path = get_checkpoint(snapshot_path, global_step, missing_ok)
logging.info(f"Loading snapshot from {path}.")
return torchsnapshot.Snapshot(path=path)
@classmethod
def load_snapshot_to_weight(
cls,
embedding_snapshot: torchsnapshot.Snapshot,
snapshot_emb_name: str,
weight_tensor,
) -> None:
"""Loads pretrained embedding from the snapshot to the model.
Utilise partial lodaing meachanism from torchsnapshot.
Args:
embedding_snapshot: Path to the snapshot containing pretrained embeddings (EBC).
snapshot_emb_name: Name of the layer in the *snapshot* model, containing the EBC.
weight_tensor: embeddings tensor of *current* model, where the embeddings will be loaded.
"""
start_time = time.time()
manifest = embedding_snapshot.get_manifest()
for path in manifest.keys():
if path.startswith("0") and snapshot_emb_name in path:
snapshot_path_to_load = path
embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)
logging.info(
f"Loaded embedding snapshot from {snapshot_path_to_load}: {time.time() - start_time:.05}s",
rank=-1,
)
logging.info(f"Snapshot loaded to {weight_tensor.metadata()}", rank=-1)
def _eval_subdir(checkpoint_path: str) -> str:
return os.path.join(checkpoint_path, DONE_EVAL_SUBDIR)
def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str:
return os.path.join(_eval_subdir(checkpoint_path), f"{eval_partition}_DONE")
def is_done_eval(checkpoint_path: str, eval_partition: str):
return get_checkpoint(checkpoint_path).exists(_eval_done_path(checkpoint_path, eval_partition))
def mark_done_eval(checkpoint_path: str, eval_partition: str):
infer_fs(checkpoint_path).touch(_eval_done_path(checkpoint_path, eval_partition))
def step_from_checkpoint(checkpoint: str) -> int:
return int(os.path.basename(checkpoint))
def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, timeout: int = 1800):
"""Simplified equivalent of tf.train.checkpoints_iterator.
Args:
seconds_to_sleep: time between polling calls.
timeout: how long to wait for a new checkpoint.
"""
def _poll(last_checkpoint: Optional[str] = None):
stop_time = time.time() + timeout
while True:
_checkpoint_path = get_checkpoint(save_dir, missing_ok=True)
if not _checkpoint_path or _checkpoint_path == last_checkpoint:
if time.time() + seconds_to_sleep > stop_time:
logging.info(
f"Timed out waiting for next available checkpoint from {save_dir} for {timeout}s."
)
return None
logging.info(f"Waiting for next available checkpoint from {save_dir}.")
time.sleep(seconds_to_sleep)
else:
logging.info(f"Found latest checkpoint {_checkpoint_path}.")
return _checkpoint_path
checkpoint_path = None
while True:
new_checkpoint = _poll(checkpoint_path)
if not new_checkpoint:
return
checkpoint_path = new_checkpoint
yield checkpoint_path
def get_checkpoint(
save_dir: str,
global_step: Optional[int] = None,
missing_ok: bool = False,
) -> str:
"""Gets latest checkpoint or checkpoint at specified global_step.
Args:
global_step: Finds this checkpoint if specified.
missing_ok: if True and checkpoints do not exist, returns without restoration.
"""
checkpoints = get_checkpoints(save_dir)
if not checkpoints:
if not missing_ok:
raise Exception(f"No checkpoints found at {save_dir}")
else:
logging.info(f"No checkpoints found for restoration at {save_dir}.")
return ""
if global_step is None:
return checkpoints[-1]
logging.info(f"Found checkpoints: {checkpoints}")
for checkpoint in checkpoints:
step = step_from_checkpoint(checkpoint)
if global_step == step:
chosen_checkpoint = checkpoint
break
else:
raise Exception(f"Desired checkpoint at {global_step} not found in {save_dir}")
return chosen_checkpoint
def get_checkpoints(save_dir: str) -> List[str]:
"""Gets all checkpoints that have been fully written."""
checkpoints = []
fs = infer_fs(save_dir)
if fs.exists(save_dir):
prefix = GCS_PREFIX if is_gcs_fs(fs) else ""
checkpoints = list(f"{prefix}{elem}" for elem in fs.ls(save_dir, detail=False))
# Only take checkpoints that were fully written.
checkpoints = list(
filter(
lambda path: fs.exists(f"{path}/{torchsnapshot.snapshot.SNAPSHOT_METADATA_FNAME}"),
checkpoints,
)
)
checkpoints = sorted(checkpoints, key=lambda path: int(os.path.basename(path)))
return checkpoints
def wait_for_evaluators(
save_dir: str,
partition_names: List[str],
global_step: int,
timeout: int,
) -> None:
logging.info("Waiting for all evaluators to finish.")
start_time = time.time()
for checkpoint in checkpoints_iterator(save_dir):
step = step_from_checkpoint(checkpoint)
logging.info(f"Considering checkpoint {checkpoint} for global step {global_step}.")
if step == global_step:
while partition_names:
if is_done_eval(checkpoint, partition_names[-1]):
logging.info(
f"Checkpoint {checkpoint} marked as finished eval for partition {partition_names[-1]} at step {step}, still waiting for {partition_names}."
)
partition_names.pop()
if time.time() - start_time >= timeout:
logging.warning(
f"Not all evaluators finished after waiting for {time.time() - start_time}"
)
return
time.sleep(10)
logging.info("All evaluators finished.")
return
if time.time() - start_time >= timeout:
logging.warning(f"Not all evaluators finished after waiting for {time.time() - start_time}")
return

30
common/device.py Normal file
View file

@ -0,0 +1,30 @@
import os
import torch
import torch.distributed as dist
def maybe_setup_tensorflow():
try:
import tensorflow as tf
except ImportError:
pass
else:
tf.config.set_visible_devices([], "GPU") # disable tf gpu
def setup_and_get_device(tf_ok: bool = True) -> torch.device:
if tf_ok:
maybe_setup_tensorflow()
device = torch.device("cpu")
backend = "gloo"
if torch.cuda.is_available():
rank = os.environ["LOCAL_RANK"]
device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
if not torch.distributed.is_initialized():
dist.init_process_group(backend)
return device

View file

@ -0,0 +1 @@
from tml.common.filesystem.util import infer_fs, is_gcs_fs, is_local_fs

View file

@ -0,0 +1,16 @@
"""Minimal test for infer_fs.
Mostly a test that it returns an object
"""
from tml.common.filesystem import infer_fs
def test_infer_fs():
local_path = "/tmp/local_path"
gcs_path = "gs://somebucket/somepath"
local_fs = infer_fs(local_path)
gcs_fs = infer_fs(gcs_path)
# This should return two different objects
assert local_fs != gcs_fs

25
common/filesystem/util.py Normal file
View file

@ -0,0 +1,25 @@
"""Utilities for interacting with the file systems."""
from fsspec.implementations.local import LocalFileSystem
import gcsfs
GCS_FS = gcsfs.GCSFileSystem(cache_timeout=-1)
LOCAL_FS = LocalFileSystem()
def infer_fs(path: str):
if path.startswith("gs://"):
return GCS_FS
elif path.startswith("hdfs://"):
# We can probably use pyarrow HDFS to support this.
raise NotImplementedError("HDFS not yet supported")
else:
return LOCAL_FS
def is_local_fs(fs):
return fs == LOCAL_FS
def is_gcs_fs(fs):
return fs == GCS_FS

90
common/log_weights.py Normal file
View file

@ -0,0 +1,90 @@
"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
def weights_to_log(
model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
"""Creates dict of reduced weights to log to give sense of training.
Args:
model: model to traverse.
how_to_log: if a function, then applies this to every parameter, if a dict
then only applies and logs specified parameters.
"""
if not how_to_log:
return
to_log = dict()
named_parameters = model.named_parameters()
logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
if isinstance(model, DistributedModelParallel):
named_parameters = itertools.chain(
named_parameters, model._dmp_wrapped_module.named_parameters()
)
logging.info(
f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
)
for param_name, params in named_parameters:
if callable(how_to_log):
how = how_to_log
else:
how = how_to_log.get(param_name) # type: ignore[assignment]
if not how:
continue # type: ignore
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
return to_log
def log_ebc_norms(
model_state_dict,
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict: model.state_dict()
ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
norm_logs = dict()
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
if emb_key in model_state_dict:
emb_weight = model_state_dict[emb_key]
try:
emb_weight_tensor = emb_weight.local_tensor()
except AttributeError as e:
logging.info(e)
emb_weight_tensor = emb_weight
logging.info("Running Tensor.detach()")
emb_weight_tensor = emb_weight_tensor.detach()
sample_mask = torch.randperm(emb_weight_tensor.shape[0])[
: min(sample_size, emb_weight_tensor.shape[0])
]
# WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks
# Change sample_size if the you observe frequent OOM errors or remove weight logging.
norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}"))
all_norms = [
torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size())
]
dist.all_gather(all_norms, norms)
for idx, norm in enumerate(all_norms):
if norm != -1.0:
norm_logs[f"{emb_key}-norm-{idx}"] = norm
logging.info(f"Norm Logs are {norm_logs}")
return norm_logs

View file

@ -0,0 +1,59 @@
from typing import List
from enum import Enum
import tml.core.config as base_config
from tml.optimizers.config import OptimizerConfig
import pydantic
class DataType(str, Enum):
FP32 = "fp32"
FP16 = "fp16"
class EmbeddingSnapshot(base_config.BaseConfig):
"""Configuration for Embedding snapshot"""
emb_name: str = pydantic.Field(
..., description="Name of the embedding table from the loaded snapshot"
)
embedding_snapshot_uri: str = pydantic.Field(
..., description="Path to torchsnapshot of the embedding"
)
class EmbeddingBagConfig(base_config.BaseConfig):
"""Configuration for EmbeddingBag."""
name: str = pydantic.Field(..., description="name of embedding bag")
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
vocab: str = pydantic.Field(
None, description="Directory to parquet files of mapping from entity ID to table index."
)
# make sure to use an optimizer that matches:
# https://github.com/pytorch/FBGEMM/blob/4c58137529d221390575e47e88d3c05ce65b66fd/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py#L15
optimizer: OptimizerConfig
data_type: DataType
class LargeEmbeddingsConfig(base_config.BaseConfig):
"""Configuration for EmbeddingBagCollection.
The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
"""
tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
tables_to_log: List[str] = pydantic.Field(
None, description="list of embedding table names that we want to log during training"
)
class Mode(str, Enum):
"""Job modes."""
TRAIN = "train"
EVALUATE = "evaluate"
INFERENCE = "inference"

View file

@ -0,0 +1,58 @@
from tml.common.modules.embedding.config import LargeEmbeddingsConfig, DataType
from tml.ml_logging.torch_logging import logging
import torch
from torch import nn
import torchrec
from torchrec.modules import embedding_configs
from torchrec import EmbeddingBagConfig, EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
import numpy as np
class LargeEmbeddings(nn.Module):
def __init__(
self,
large_embeddings_config: LargeEmbeddingsConfig,
):
super().__init__()
tables = []
for table in large_embeddings_config.tables:
data_type = (
embedding_configs.DataType.FP32
if (table.data_type == DataType.FP32)
else embedding_configs.DataType.FP16
)
tables.append(
EmbeddingBagConfig(
embedding_dim=table.embedding_dim,
feature_names=[table.name], # restricted to 1 feature per table for now
name=table.name,
num_embeddings=table.num_embeddings,
pooling=torchrec.PoolingType.SUM,
data_type=data_type,
)
)
self.ebc = EmbeddingBagCollection(
device="meta",
tables=tables,
)
logging.info("********************** EBC named params are **********")
logging.info(list(self.ebc.named_parameters()))
# This hook is used to perform post-processing surgery
# on large_embedding models to prep them for serving
self.surgery_cut_point = torch.nn.Identity()
def forward(
self,
sparse_features: KeyedJaggedTensor,
) -> KeyedTensor:
pooled_embs = self.ebc(sparse_features)
# a KeyedTensor
return self.surgery_cut_point(pooled_embs)

109
common/run_training.py Normal file
View file

@ -0,0 +1,109 @@
import os
import subprocess
import sys
from typing import Optional
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils
import torch
import torch.distributed.run
def is_distributed_worker():
world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None
def maybe_run_training(
train_fn,
module_name,
nproc_per_node: Optional[int] = None,
num_nodes: Optional[int] = None,
set_python_path_in_subprocess: bool = False,
is_chief: Optional[bool] = False,
**training_kwargs,
):
"""Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables
(WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn: The function that is responsible for training
module_name: The name of the module that this function was called from;
used to indicate torchrun entrypoint.
nproc_per_node: Number of workers per node; supported values.
num_nodes: Number of nodes, otherwise inferred from environment.
is_chief: If process is running on chief.
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
"""
machines = utils.machine_from_env()
if num_nodes is None:
num_nodes = 1
if machines.num_workers:
num_nodes += machines.num_workers
if is_distributed_worker():
# world_size, rank, etc are set; assuming any other env vars are set (checks to come)
# start the actual training!
train_fn(**training_kwargs)
else:
if nproc_per_node is None:
if torch.cuda.is_available():
nproc_per_node = torch.cuda.device_count()
else:
nproc_per_node = machines.chief.num_accelerators
# Rejoin all arguments to send back through torchrec
# this is a temporary measure, will replace the os.system call
# with torchrun API calls
args = list(f"--{key}={val}" for key, val in training_kwargs.items())
cmd = [
"--nnodes",
str(num_nodes),
]
if nproc_per_node:
cmd.extend(["--nproc_per_node", str(nproc_per_node)])
if num_nodes > 1:
cluster_resolver = utils.cluster_resolver()
backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
cmd.extend(
[
"--rdzv_backend",
"c10d",
"--rdzv_id",
backend_address,
]
)
# Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
if is_chief:
cmd.extend(["--rdzv_endpoint", "localhost:2222"])
else:
cmd.extend(["--rdzv_endpoint", backend_address])
else:
cmd.append("--standalone")
cmd.extend(
[
str(module_name),
*args,
]
)
logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")
# Call torchrun on this module; will spawn new processes and re-run this
# function, eventually calling "train_fn". The following line sets the PYTHONPATH to accommodate
# bazel stubbing for the main binary.
if set_python_path_in_subprocess:
subprocess.run(["torchrun"] + cmd, env={**os.environ, "PYTHONPATH": ":".join(sys.path)})
else:
torch.distributed.run.main(cmd)

13
common/test_device.py Normal file
View file

@ -0,0 +1,13 @@
"""Minimal test for device.
Mostly a test that this can be imported properly even tho moved.
"""
from unittest.mock import patch
import tml.common.device as device_utils
def test_device():
with patch("tml.common.device.dist.init_process_group"):
device = device_utils.setup_and_get_device(tf_ok=False)
assert device.type == "cpu"

33
common/testing_utils.py Normal file
View file

@ -0,0 +1,33 @@
from contextlib import contextmanager
import datetime
import os
from unittest.mock import patch
import torch.distributed as dist
from tml.ml_logging.torch_logging import logging
MOCK_ENV = {
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"LOCAL_WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29501",
"RANK": "0",
}
@contextmanager
def mock_pg():
with patch.dict(os.environ, MOCK_ENV):
try:
dist.init_process_group(
backend="gloo",
timeout=datetime.timedelta(1),
)
yield
except:
dist.destroy_process_group()
raise
finally:
dist.destroy_process_group()

44
common/utils.py Normal file
View file

@ -0,0 +1,44 @@
import yaml
import getpass
import os
import string
from typing import Tuple, Type, TypeVar
from tml.core.config import base_config
import fsspec
C = TypeVar("C", bound=base_config.BaseConfig)
def _read_file(f):
with fsspec.open(f) as f:
return f.read()
def setup_configuration(
config_type: Type[C],
yaml_path: str,
substitute_env_variable: bool = False,
) -> Tuple[C, str]:
"""Resolves a config at a yaml path.
Args:
config_type: Pydantic config class to load.
yaml_path: yaml path of the config file.
substitute_env_variable: If True substitute string in the format $VAR or ${VAR} by their
environment variable value whenever possible. If an environment variable doesn't exist,
the string is left unchanged.
Returns:
The pydantic config object.
"""
def _substitute(s):
if substitute_env_variable:
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
return s
assert config_type is not None, "can't use all_config without config_type"
content = _substitute(yaml.safe_load(_read_file(yaml_path)))
return config_type.parse_obj(content)

20
common/wandb.py Normal file
View file

@ -0,0 +1,20 @@
from typing import Any, Dict, List
import tml.core.config as base_config
import pydantic
class WandbConfig(base_config.BaseConfig):
host: str = pydantic.Field(
"https://https--wandb--prod--wandb.service.qus1.twitter.biz/",
description="Host of Weights and Biases instance, passed to login.",
)
key_path: str = pydantic.Field(description="Path to key file.")
name: str = pydantic.Field(None, description="Name of the experiment, passed to init.")
entity: str = pydantic.Field(None, description="Name of user/service account, passed to init.")
project: str = pydantic.Field(None, description="Name of wandb project, passed to init.")
tags: List[str] = pydantic.Field([], description="List of tags, passed to init.")
notes: str = pydantic.Field(None, description="Notes, passed to init.")
metadata: Dict[str, Any] = pydantic.Field(None, description="Additional metadata to log.")

0
core/__init__.py Normal file
View file

5
core/config/__init__.py Normal file
View file

@ -0,0 +1,5 @@
from tml.core.config.base_config import BaseConfig
from tml.core.config.config_load import load_config_from_yaml
# Make mypy happy by explicitly rexporting the symbols intended for end user use.
__all__ = ["BaseConfig", "load_config_from_yaml"]

View file

@ -0,0 +1,66 @@
"""Base class for all config (forbids extra fields)."""
import collections
import functools
import yaml
import pydantic
class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes.
This class provides some convenient functionality:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
```
class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok
ExampleConfig(y=1) # ok
ExampleConfig(x=1, y=1) # throws error
```
"""
class Config:
"""Forbids extras."""
extra = pydantic.Extra.forbid # noqa
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with provided the field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields are appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human legible (yaml) representation of the config useful for logging."""
return yaml.dump(self.dict())

View file

@ -0,0 +1,38 @@
from unittest import TestCase
from tml.core.config import BaseConfig
import pydantic
class BaseConfigTest(TestCase):
def test_extra_forbidden(self):
class Config(BaseConfig):
x: int
Config(x=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=2)
def test_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, one_of="f")
y: int = pydantic.Field(None, one_of="f")
with self.assertRaises(pydantic.ValidationError):
Config()
Config(x=1)
Config(y=1)
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y=3)
def test_at_most_one_of(self):
class Config(BaseConfig):
x: int = pydantic.Field(None, at_most_one_of="f")
y: str = pydantic.Field(None, at_most_one_of="f")
Config()
Config(x=1)
Config(y="a")
with self.assertRaises(pydantic.ValidationError):
Config(x=1, y="a")

View file

@ -0,0 +1,24 @@
import yaml
import string
import getpass
import os
from typing import Type
from tml.core.config.base_config import BaseConfig
def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str):
"""Recommend method to load a config file (a yaml file) and parse it.
Because we have a shared filesystem the recommended route to running jobs it put modified config
files with the desired parameters somewhere on the filesytem and run jobs pointing to them.
"""
def _substitute(s):
return string.Template(s).safe_substitute(os.environ, USER=getpass.getuser())
with open(yaml_path, "r") as f:
raw_contents = f.read()
obj = yaml.safe_load(_substitute(raw_contents))
return config_type.parse_obj(obj)

View file

@ -0,0 +1,23 @@
from unittest import TestCase
from tml.core.config import BaseConfig, load_config_from_yaml
import pydantic
import getpass
import pydantic
class _PointlessConfig(BaseConfig):
a: int
user: str
def test_load_config_from_yaml(tmp_path):
yaml_path = tmp_path.joinpath("test.yaml").as_posix()
with open(yaml_path, "w") as yaml_file:
yaml_file.write("""a: 3\nuser: ${USER}\n""")
pointless_config = load_config_from_yaml(_PointlessConfig, yaml_path)
assert pointless_config.a == 3
assert pointless_config.user == getpass.getuser()

40
core/config/training.py Normal file
View file

@ -0,0 +1,40 @@
from typing import Any, Dict, List, Optional
from tml.common.wandb import WandbConfig
from tml.core.config import base_config
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.models.config import TwhinModelConfig
import pydantic
class RuntimeConfig(base_config.BaseConfig):
wandb: WandbConfig = pydantic.Field(None)
enable_tensorfloat32: bool = pydantic.Field(
False, description="Use tensorfloat32 if on Ampere devices."
)
enable_amp: bool = pydantic.Field(False, description="Enable automatic mixed precision.")
class TrainingConfig(base_config.BaseConfig):
save_dir: str = pydantic.Field("/tmp/model", description="Directory to save checkpoints.")
num_train_steps: pydantic.PositiveInt = 10000
initial_checkpoint_dir: str = pydantic.Field(
None, description="Directory of initial checkpoints", at_most_one_of="initialization"
)
checkpoint_every_n: pydantic.PositiveInt = 1000
checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
)
train_log_every_n: pydantic.PositiveInt = 1000
num_eval_steps: int = pydantic.Field(
16384, description="Number of evaluation steps. If < 0 the entire dataset will be used."
)
eval_log_every_n: pydantic.PositiveInt = 5000
eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
gradient_accumulation: int = pydantic.Field(
None, description="Number of replica steps to accumulate gradients."
)
num_epochs: pydantic.PositiveInt = 1

View file

@ -0,0 +1,316 @@
"""Torch and torchrec specific training and evaluation loops.
Features (go/100_enablements):
- CUDA data-fetch, compute, gradient-push overlap
- Large learnable embeddings through torchrec
- On/off-chief evaluation
- Warmstart/checkpoint management
- go/dataset-service 0-copy integration
"""
import datetime
import os
from typing import Callable, Dict, Iterable, List, Mapping, Optional
from tml.common import log_weights
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.core.losses import get_global_loss_detached
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
from tml.core.train_pipeline import TrainPipelineSparseDist
import tree
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm
def get_new_iterator(iterable: Iterable):
"""
This obtain a new iterator from the iterable. If the iterable uses tf.data.Dataset internally,
getting a new iterator each N steps will avoid memory leak. To avoid the memory leak
calling iter(iterable) should return a "fresh" iterator using a fresh
(new instance of) tf.data.Iterator.
In particular, iterable can be a torch.utils.data.IterableDataset or a
torch.utils.data.DataLoader.
When using DDS, performing this reset does not change the order in which elements are received
(excluding elements already prefetched) provided that iter(iterable) internally uses
a new instance of tf.data.Dataset created by calling from_dataset_id.
This requirement is satisfied by RecapDataset.
:param iterable:
:return:
"""
return iter(iterable)
def _get_step_fn(pipeline, data_iterator, training: bool):
def step_fn():
# It turns out that model.train() and model.eval() simply switch a single field inside the model
# class,so it's somewhat safer to wrap in here.
if training:
pipeline._model.train()
else:
pipeline._model.eval()
outputs = pipeline.progress(data_iterator)
return tree.map_structure(lambda elem: elem.detach(), outputs)
return step_fn
@torch.no_grad()
def _run_evaluation(
pipeline,
dataset,
eval_steps: int,
metrics: tm.MetricCollection,
eval_batch_size: int,
logger=None,
):
"""Runs the evaluation loop over all evaluation iterators."""
dataset = get_new_iterator(dataset)
step_fn = _get_step_fn(pipeline, dataset, training=False)
last_time = datetime.datetime.now()
logging.info(f"Starting {eval_steps} steps of evaluation.")
for _ in range(eval_steps):
outputs = step_fn()
metrics.update(outputs)
eval_ex_per_s = (
eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds()
)
logging.info(f"eval examples_per_s : {eval_ex_per_s}")
metrics_result = metrics.compute()
# Resetting at end to release metrics memory not in use.
# Reset metrics to prevent accumulation between multiple evaluation splits and not report a
# running average.
metrics.reset()
return metrics_result
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: str,
save_dir: str,
logging_interval: int,
train_steps: int,
checkpoint_frequency: int,
dataset: Iterable,
worker_batch_size: int,
num_workers: Optional[int] = 0,
enable_amp: bool = False,
initial_checkpoint_dir: Optional[str] = None,
gradient_accumulation: Optional[int] = None,
logger_initializer: Optional[Callable] = None,
scheduler: _LRScheduler = None,
metrics: Optional[tm.MetricCollection] = None,
parameters_to_log: Optional[Dict[str, Callable]] = None,
tables_to_log: Optional[List[str]] = None,
) -> None:
"""Runs training and eval on the given TrainPipeline
Args:
dataset: data iterator for the training set
evaluation_iterators: data iterators for the different evaluation sets
scheduler: optional learning rate scheduler
output_transform_for_metrics: optional transformation functions to transorm the model
output and labels into a format the metrics can understand
"""
train_pipeline = TrainPipelineSparseDist(
model=model,
optimizer=optimizer,
device=device,
enable_amp=enable_amp,
grad_accum=gradient_accumulation,
) # type: ignore[var-annotated]
# We explicitly initialize optimizer state here so that checkpoint will work properly.
if hasattr(train_pipeline._optimizer, "init_state"):
train_pipeline._optimizer.init_state()
save_state = {
"model": train_pipeline._model,
"optimizer": train_pipeline._optimizer,
"scaler": train_pipeline._grad_scaler,
}
chosen_checkpoint = None
checkpoint_handler = snapshot_lib.Snapshot(
save_dir=save_dir,
state=save_state,
)
if save_dir:
chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)
start_step = 0
start_walltime = 0.0
if chosen_checkpoint:
# Skip restoration and exit if we should be finished.
chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)
if not chosen_checkpoint_global_step < dist.get_world_size() * train_steps:
logging.info(
"Not restoring and finishing training as latest checkpoint "
f"{chosen_checkpoint} found "
f"at global_step ({chosen_checkpoint_global_step}) >= "
f"train_steps ({dist.get_world_size() * train_steps})"
)
return
logging.info(f"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}")
checkpoint_handler.restore(chosen_checkpoint)
start_step = checkpoint_handler.step
start_walltime = checkpoint_handler.walltime
elif initial_checkpoint_dir:
base, ckpt_step = os.path.split(initial_checkpoint_dir)
warmstart_handler = snapshot_lib.Snapshot(
save_dir=base,
state=save_state,
)
ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))
logging.info(
f"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping starting step as 0."
)
warmstart_handler.restore(ckpt)
train_logger = logger_initializer(mode="train") if logger_initializer else None
train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)
# Counting number of parameters in the model directly when creating it.
nb_param = 0
for p in model.parameters():
nb_param += p.numel()
logging.info(f"Model has {nb_param} parameters")
last_time = datetime.datetime.now()
start_time = last_time
last_pending_snapshot = None
for step in range(start_step, train_steps + 1):
checkpoint_handler.step = step
outputs = train_step_fn()
step_done_time = datetime.datetime.now()
checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + start_walltime
if scheduler:
scheduler.step()
if step % logging_interval == 0:
interval_time = (step_done_time - last_time).total_seconds()
steps_per_s = logging_interval / interval_time
worker_example_per_s = steps_per_s * worker_batch_size
global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))
global_step = step
log_values = {
"global_step": global_step,
"loss": get_global_loss_detached(outputs["loss"]),
"steps_per_s": steps_per_s,
"global_example_per_s": global_example_per_s,
"worker_examples_per_s": worker_example_per_s,
"active_training_walltime": checkpoint_handler.walltime,
}
if parameters_to_log:
log_values.update(
log_weights.weights_to_log(
model=model,
how_to_log=parameters_to_log,
)
)
log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)
if tables_to_log:
log_values.update(
log_weights.log_ebc_norms(
model_state_dict=train_pipeline._model.state_dict(),
ebc_keys=tables_to_log,
)
)
if train_logger:
train_logger.log(log_values, step=global_step)
log_line = ", ".join(f"{name}: {value}" for name, value in log_values.items())
logging.info(f"Step: {step}, training. {log_line}")
last_time = step_done_time
# If we just restored, do not save again.
if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:
if last_pending_snapshot and not last_pending_snapshot.done():
logging.warning(
"Begin a new snapshot and the last one hasn't finished. That probably indicates "
"either you're snapshotting really often or something is wrong. Will now block and "
"wait for snapshot to finish before beginning the next one."
)
last_pending_snapshot.wait()
last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
# Save if we did not just save.
if checkpoint_frequency and step % checkpoint_frequency != 0:
# For the final save, wait for the checkpoint to write to make sure the process doesn't finish
# before its completed.
last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
logging.info(f"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}")
if last_pending_snapshot:
logging.info(f"Waiting for any checkpoints to finish.")
last_pending_snapshot.wait()
def log_eval_results(
results,
eval_logger,
partition_name: str,
step: int,
):
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
logging.info(f"Step: {step}, evaluation ({partition_name}).")
for metric_name, metric_value in results.items():
logging.info(f"\t{metric_name}: {metric_value:1.4e}")
if eval_logger:
eval_logger.log(results, step=step, commit=True)
def only_evaluate(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: str,
save_dir: str,
num_train_steps: int,
dataset: Iterable,
eval_batch_size: int,
num_eval_steps: int,
eval_timeout_in_s: int,
eval_logger: Callable,
partition_name: str,
metrics: Optional[tm.MetricCollection] = None,
):
logging.info(f"Evaluating on partition {partition_name}.")
logging.info("Computing metrics:")
logging.info(metrics)
eval_pipeline = TrainPipelineSparseDist(model, optimizer, device) # type: ignore[var-annotated]
save_state = {
"model": eval_pipeline._model,
"optimizer": eval_pipeline._optimizer,
}
checkpoint_handler = snapshot_lib.Snapshot(
save_dir=save_dir,
state=save_state,
)
for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):
checkpoint_handler.restore(checkpoint_path)
step = checkpoint_handler.step
dataset = get_new_iterator(dataset)
results = _run_evaluation(
pipeline=eval_pipeline,
dataset=dataset,
eval_steps=num_eval_steps,
eval_batch_size=eval_batch_size,
metrics=metrics,
)
log_eval_results(results, eval_logger, partition_name, step=step)
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 0:
snapshot_lib.mark_done_eval(checkpoint_path, partition_name)
if step >= num_train_steps:
return

View file

@ -0,0 +1,45 @@
"""This is a very limited feature training loop useful for interactive debugging.
It is not intended for actual model tranining (it is not fast, doesn't compile the model).
It does not support checkpointing.
suggested use:
from tml.core import debug_training_loop
debug_training_loop.train(...)
"""
from typing import Iterable, Optional, Dict, Callable, List
import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchmetrics as tm
from tml.ml_logging.torch_logging import logging
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
train_steps: int,
dataset: Iterable,
scheduler: _LRScheduler = None,
# Accept any arguments (to be compatible with the real training loop)
# but just ignore them.
*args,
**kwargs,
) -> None:
logging.warning("Running debug training loop, don't use for model training.")
data_iter = iter(dataset)
for step in range(0, train_steps + 1):
x = next(data_iter)
optimizer.zero_grad()
loss, outputs = model.forward(x)
loss.backward()
optimizer.step()
if scheduler:
scheduler.step()
logging.info(f"Step {step} completed. Loss = {loss}")

7
core/loss_type.py Normal file
View file

@ -0,0 +1,7 @@
"""Loss type enums."""
from enum import Enum
class LossType(str, Enum):
CROSS_ENTROPY = "cross_entropy"
BCE_WITH_LOGITS = "bce_with_logits"

109
core/losses.py Normal file
View file

@ -0,0 +1,109 @@
"""Loss functions -- including multi task ones."""
import typing
from tml.core.loss_type import LossType
from tml.ml_logging.torch_logging import logging
import torch
def _maybe_warn(reduction: str):
"""
Warning for reduction different than mean.
"""
if reduction != "mean":
logging.warn(
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
f"to the gradient without DDP only for mean reduction. If you need this property for"
f"the provided reduction {reduction}, it needs to be implemented."
)
def build_loss(
loss_type: LossType,
reduction="mean",
):
_maybe_warn(reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
def loss_fn(logits, labels):
return f(logits, labels.type_as(logits), reduction=reduction)
return loss_fn
def get_global_loss_detached(local_loss, reduction="mean"):
"""
Perform all_reduce to obtain the global loss function using the provided reduction.
:param local_loss: The local loss of the current rank.
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
:return: The reduced & detached global loss.
"""
if reduction != "mean":
logging.warn(
f"The reduction used in this function should be the same as the one used by "
f"the DDP model. By default DDP uses mean, So ensure that DDP is appropriately"
f"modified for reduction {reduction}."
)
if reduction not in ["mean", "sum"]:
raise ValueError(f"Reduction {reduction} is currently unsupported.")
global_loss = local_loss.detach()
if reduction == "mean":
global_loss.div_(torch.distributed.get_world_size())
torch.distributed.all_reduce(global_loss)
return global_loss
def build_multi_task_loss(
loss_type: LossType,
tasks: typing.List[str],
task_loss_reduction="mean",
global_reduction="mean",
pos_weights=None,
):
_maybe_warn(global_reduction)
_maybe_warn(task_loss_reduction)
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
loss_reduction_fns = {
"mean": torch.mean,
"sum": torch.sum,
"min": torch.min,
"max": torch.max,
"median": torch.median,
}
def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
if pos_weights is None:
torch_weights = torch.ones([len(tasks)])
else:
torch_weights = torch.tensor(pos_weights)
losses = {}
for task_idx, task in enumerate(tasks):
task_logits = logits[:, task_idx]
label = labels[:, task_idx].type_as(task_logits)
loss = f(
task_logits,
label,
reduction=task_loss_reduction,
pos_weight=torch_weights[task_idx],
weight=weights[:, task_idx],
)
losses[f"loss/{task}"] = loss
losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
return losses
return loss_fn
_LOSS_TYPE_TO_FUNCTION = {
LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits
}

104
core/metric_mixin.py Normal file
View file

@ -0,0 +1,104 @@
"""
Mixin that requires a transform to munge output dictionary of tensors a
model produces to a form that the torchmetrics.Metric.update expects.
By unifying on our signature for `update`, we can also now use
torchmetrics.MetricCollection which requires all metrics have
the same call signature.
To use, override this with a transform that munges `outputs`
into a kwargs dict that the inherited metric.update accepts.
Here are two examples of how to extend torchmetrics.SumMetric so that it accepts
an output dictionary of tensors and munges it to what SumMetric expects (single `value`)
for its update method.
1. Using as a mixin to inherit from or define a new metric class.
class Count(MetricMixin, SumMetric):
def transform(self, outputs):
return {'value': 1}
2. Redefine an existing metric class.
SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})
"""
from abc import abstractmethod
from typing import Callable, Dict, List
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torchmetrics
class MetricMixin:
@abstractmethod
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
...
def update(self, outputs: Dict[str, torch.Tensor]):
results = self.transform(outputs)
# Do not try to update if any tensor is empty as a result of stratification.
for value in results.values():
if torch.is_tensor(value) and not value.nelement():
return
super().update(**results)
class TaskMixin:
def __init__(self, task_idx: int = -1, **kwargs):
super().__init__(**kwargs)
self._task_idx = task_idx
class StratifyMixin:
def __init__(
self,
stratifier=None,
**kwargs,
):
super().__init__(**kwargs)
self._stratifier = stratifier
def maybe_apply_stratification(
self, outputs: Dict[str, torch.Tensor], value_names: List[str]
) -> Dict[str, torch.Tensor]:
"""Pick out examples with values for which the stratifier feature is equal to a specific stratifier indicator value."""
outputs = outputs.copy()
if not self._stratifier:
return outputs
stratifiers = outputs.get("stratifiers")
if not stratifiers:
return outputs
if stratifiers.get(self._stratifier.name) is None:
return outputs
mask = torch.flatten(outputs["stratifiers"][self._stratifier.name] == self._stratifier.value)
target_slice = torch.squeeze(mask.nonzero(), -1)
for value_name in value_names:
target = outputs[value_name]
outputs[value_name] = torch.index_select(target, 0, target_slice)
return outputs
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
"""Returns new class using MetricMixin and given base_metric.
Functionally the same using inheritance, just saves some lines of code
if no need for class attributes.
"""
def transform_method(_self, *args, **kwargs):
return transform(*args, **kwargs)
return type(
base_metric.__name__,
(
MetricMixin,
base_metric,
),
{"transform": transform_method},
)

148
core/metrics.py Normal file
View file

@ -0,0 +1,148 @@
"""Common metrics that also support multi task.
We assume multi task models will output [task_idx, ...] predictions
"""
from typing import Any, Dict
from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin
import torch
import torchmetrics as tm
def probs_and_labels(
outputs: Dict[str, torch.Tensor],
task_idx: int,
) -> Dict[str, torch.Tensor]:
preds = outputs["probabilities"]
target = outputs["labels"]
if task_idx >= 0:
preds = preds[:, task_idx]
target = target[:, task_idx]
return {
"preds": preds,
"target": target.int(),
}
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
value = value[:, self._task_idx]
return {"value": value}
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["labels"])
value = outputs["labels"]
if self._task_idx >= 0:
value = value[:, self._task_idx]
return {"value": value}
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities"])
value = outputs["probabilities"]
if self._task_idx >= 0:
value = value[:, self._task_idx]
return {"value": value}
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
def transform(self, outputs):
outputs = self.maybe_apply_stratification(outputs, ["probabilities", "labels"])
return probs_and_labels(outputs, self._task_idx)
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420
"""
def __init__(self, num_samples, **kwargs):
super().__init__(**kwargs)
self.num_samples = num_samples
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
pos_scores = scores[labels == 1]
neg_scores = scores[labels == 0]
result = {
"value": pos_scores[torch.randint(len(pos_scores), (self.num_samples,))]
> neg_scores[torch.randint(len(neg_scores), (self.num_samples,))]
}
return result
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The ranks of all positives
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
_, sorted_indices = scores.sort(descending=True)
pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1
result = {"value": pos_ranks}
return result
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The reciprocal of the ranks of all
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
_, sorted_indices = scores.sort(descending=True)
pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1
result = {"value": torch.div(torch.ones_like(pos_ranks), pos_ranks)}
return result
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
"""
The fraction of positives that rank in the top K among their negatives
Note that this is basically precision@k
Based on:
https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
"""
def __init__(self, k: int, **kwargs):
super().__init__(**kwargs)
self.k = k
def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
scores, labels = outputs["logits"], outputs["labels"]
_, sorted_indices = scores.sort(descending=True)
pos_ranks = labels[sorted_indices].nonzero(as_tuple=True)[0] + 1 # all ranks start from 1
result = {"value": (pos_ranks <= self.k).float()}
return result

187
core/test_metrics.py Normal file
View file

@ -0,0 +1,187 @@
from dataclasses import dataclass
from tml.core import metrics as core_metrics
from tml.core.metric_mixin import MetricMixin, prepend_transform
import torch
from torchmetrics import MaxMetric, MetricCollection, SumMetric
@dataclass
class MockStratifierConfig:
name: str
index: int
value: int
class Count(MetricMixin, SumMetric):
def transform(self, outputs):
return {"value": 1}
Max = prepend_transform(MaxMetric, lambda outputs: {"value": outputs["value"]})
def test_count_metric():
num_examples = 123
examples = [
{"stuff": 0},
] * num_examples
metric = Count()
for outputs in examples:
metric.update(outputs)
assert metric.compute().item() == num_examples
def test_collections():
max_metric = Max()
count_metric = Count()
metric = MetricCollection([max_metric, count_metric])
examples = [{"value": idx} for idx in range(123)]
for outputs in examples:
metric.update(outputs)
assert metric.compute() == {
max_metric.__class__.__name__: len(examples) - 1,
count_metric.__class__.__name__: len(examples),
}
def test_task_dependent_ctr():
num_examples = 144
batch_size = 1024
outputs = [
{
"stuff": 0,
"labels": torch.arange(0, 6).repeat(batch_size, 1),
}
for idx in range(num_examples)
]
for task_idx in range(5):
metric = core_metrics.Ctr(task_idx=task_idx)
for output in outputs:
metric.update(output)
assert metric.compute().item() == task_idx
def test_stratified_ctr():
outputs = [
{
"stuff": 0,
# [bsz, tasks]
"labels": torch.tensor(
[
[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 0],
]
),
"stratifiers": {
# [bsz]
"level": torch.tensor(
[9, 0, 9],
),
},
}
]
stratifier = MockStratifierConfig(name="level", index=2, value=9)
for task_idx in range(5):
metric = core_metrics.Ctr(task_idx=1, stratifier=stratifier)
for output in outputs:
metric.update(output)
# From the dataset of:
# [
# [0, 1, 2, 3],
# [1, 2, 3, 4],
# [2, 3, 4, 0],
# ]
# we pick out
# [
# [0, 1, 2, 3],
# [2, 3, 4, 0],
# ]
# and with Ctr task_idx, we pick out
# [
# [1,],
# [3,],
# ]
assert metric.compute().item() == (1 + 3) / 2
def test_auc():
num_samples = 10000
metric = core_metrics.Auc(num_samples)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 1.0, 1.0, 1.0])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.0, -1.0, -1.0])
outputs_bad = {"logits": preds_bad, "labels": target}
metric.update(outputs_correct)
assert metric.compute().item() == 1.0
metric.reset()
metric.update(outputs_bad)
assert metric.compute().item() == 0.0
def test_pos_rank():
metric = core_metrics.PosRanks()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
outputs_bad = {"logits": preds_bad, "labels": target}
metric.update(outputs_correct)
assert metric.compute().item() == 2.0
metric.reset()
metric.update(outputs_bad)
assert metric.compute().item() == 4.0
def test_reciprocal_rank():
metric = core_metrics.ReciprocalRank()
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, -1.0, 0.5, 1.0, 1.5])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
outputs_bad = {"logits": preds_bad, "labels": target}
metric.update(outputs_correct)
assert abs(metric.compute().item() - 0.6111) < 0.001
metric.reset()
metric.update(outputs_bad)
assert abs(metric.compute().item() == 0.2611) < 0.001
def test_hit_k():
hit1_metric = core_metrics.HitAtK(1)
target = torch.tensor([0, 0, 1, 1, 1])
preds_correct = torch.tensor([-1.0, 1.0, 0.5, -0.1, 1.5])
outputs_correct = {"logits": preds_correct, "labels": target}
preds_bad = torch.tensor([1.0, 1.0, -1.5, -1.0, -0.5])
outputs_bad = {"logits": preds_bad, "labels": target}
hit1_metric.update(outputs_correct)
assert abs(hit1_metric.compute().item() - 0.3333) < 0.0001
hit1_metric.reset()
hit1_metric.update(outputs_bad)
assert hit1_metric.compute().item() == 0
hit3_metric = core_metrics.HitAtK(3)
hit3_metric.update(outputs_correct)
assert (hit3_metric.compute().item() - 0.66666) < 0.0001
hit3_metric.reset()
hit3_metric.update(outputs_bad)
assert abs(hit3_metric.compute().item() - 0.3333) < 0.0001

View file

@ -0,0 +1,91 @@
from dataclasses import dataclass
from typing import Tuple
from tml.common.batch import DataclassBatch
from tml.common.testing_utils import mock_pg
from tml.core import train_pipeline
import torch
from torchrec.distributed import DistributedModelParallel
@dataclass
class MockDataclassBatch(DataclassBatch):
continuous_features: torch.Tensor
labels: torch.Tensor
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.model = torch.nn.Linear(10, 1)
self.loss_fn = torch.nn.BCEWithLogitsLoss()
def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
pred = self.model(batch.continuous_features)
loss = self.loss_fn(pred, batch.labels)
return (loss, pred)
def create_batch(bsz: int):
return MockDataclassBatch(
continuous_features=torch.rand(bsz, 10).float(),
labels=torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float(),
)
def test_sparse_pipeline():
device = torch.device("cpu")
model = MockModule().to(device)
steps = 8
example = create_batch(1)
dataloader = iter(example for _ in range(steps + 2))
results = []
with mock_pg():
d_model = DistributedModelParallel(model)
pipeline = train_pipeline.TrainPipelineSparseDist(
model=d_model,
optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
device=device,
grad_accum=2,
)
for _ in range(steps):
results.append(pipeline.progress(dataloader))
results = [elem.detach().numpy() for elem in results]
# Check gradients are accumulated, i.e. results do not change for every 0th and 1th.
for first, second in zip(results[::2], results[1::2]):
assert first == second, results
# Check we do update gradients, i.e. results do change for every 1th and 2nd.
for first, second in zip(results[1::2], results[2::2]):
assert first != second, results
def test_amp():
device = torch.device("cpu")
model = MockModule().to(device)
steps = 8
example = create_batch(1)
dataloader = iter(example for _ in range(steps + 2))
results = []
with mock_pg():
d_model = DistributedModelParallel(model)
pipeline = train_pipeline.TrainPipelineSparseDist(
model=d_model,
optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
device=device,
enable_amp=True,
# Not supported on CPU.
enable_grad_scaling=False,
)
for _ in range(steps):
results.append(pipeline.progress(dataloader))
results = [elem.detach() for elem in results]
for value in results:
assert value.dtype == torch.bfloat16

626
core/train_pipeline.py Normal file
View file

@ -0,0 +1,626 @@
"""
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
with TrainPipelineSparseDist.progress modified to support gradient accumulation.
"""
import abc
from dataclasses import dataclass, field
import logging
from typing import (
Any,
cast,
Dict,
Generic,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import torch
from torch.autograd.profiler import record_function
from torch.fx.node import Node
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
ShardedModule,
)
from torchrec.distributed.types import Awaitable
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.streamable import Multistreamable, Pipelineable
logger: logging.Logger = logging.getLogger(__name__)
In = TypeVar("In", bound=Pipelineable)
Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
@abc.abstractmethod
def progress(self, dataloader_iter: Iterator[In]) -> Out:
pass
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
assert isinstance(
batch, (torch.Tensor, Pipelineable)
), f"{type(batch)} must implement Pipelineable interface"
return cast(In, batch.to(device=device, non_blocking=non_blocking))
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
# PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
# freed, its memory is likely to be reused by newly constructed tenosrs. By default,
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
# to tell the allocator about all these streams. Otherwise, the allocator might free the
# underlying memory of the tensor once it is no longer used by the creator stream. This is
# a notable programming trick when we write programs using multi CUDA streams.
cur_stream = torch.cuda.current_stream()
assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(cur_stream)
class TrainPipelineBase(TrainPipeline[In, Out]):
"""
This class runs training iterations using a pipeline of two stages, each as a CUDA
stream, namely, the current (default) stream and `self._memcpy_stream`. For each
iteration, `self._memcpy_stream` moves the input from host (CPU) memory to GPU
memory, and the default stream runs forward, backward, and optimization.
"""
def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> None:
self._model = model
self._optimizer = optimizer
self._device = device
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = (
torch.cuda.Stream() if device.type == "cuda" else None
)
self._cur_batch: Optional[In] = None
self._connected = False
def _connect(self, dataloader_iter: Iterator[In]) -> None:
cur_batch = next(dataloader_iter)
self._cur_batch = cur_batch
with torch.cuda.stream(self._memcpy_stream):
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
self._connected = True
def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self._connected:
self._connect(dataloader_iter)
# Fetch next batch
with record_function("## next_batch ##"):
next_batch = next(dataloader_iter)
cur_batch = self._cur_batch
assert cur_batch is not None
if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()
with record_function("## wait_for_batch ##"):
_wait_for_batch(cur_batch, self._memcpy_stream)
with record_function("## forward ##"):
(losses, output) = self._model(cur_batch)
if self._model.training:
with record_function("## backward ##"):
torch.sum(losses, dim=0).backward()
# Copy the next batch to GPU
self._cur_batch = cur_batch = next_batch
with record_function("## copy_batch_to_gpu ##"):
with torch.cuda.stream(self._memcpy_stream):
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=True)
# Update
if self._model.training:
with record_function("## optimizer ##"):
self._optimizer.step()
return output
class Tracer(torch.fx.Tracer):
# Disable proxying buffers during tracing. Ideally, proxying buffers would
# be disabled, but some models are currently mutating buffer values, which
# causes errors during tracing. If those models can be rewritten to not do
# that, we can likely remove this line
proxy_buffer_attributes = False
def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
super().__init__()
self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
if isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules:
return True
return super().is_leaf_module(m, module_qualified_name)
@dataclass
class TrainPipelineContext:
# pyre-ignore [4]
input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
# pyre-ignore [4]
feature_processor_forwards: List[Any] = field(default_factory=list)
@dataclass
class ArgInfo:
# attributes of input batch, e.g. batch.attr1.attr2 call
# will produce ["attr1", "attr2"]
input_attrs: List[str]
# batch[attr1].attr2 will produce [True, False]
is_getitems: List[bool]
# name for kwarg of pipelined forward() call or None
# for a positional arg
name: Optional[str]
class PipelinedForward:
def __init__(
self,
name: str,
args: List[ArgInfo],
module: ShardedModule,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> None:
self._name = name
self._args = args
self._module = module
self._context = context
self._dist_stream = dist_stream
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
assert self._name in self._context.input_dist_requests
request = self._context.input_dist_requests[self._name]
assert isinstance(request, Awaitable)
with record_function("## wait_sparse_data_dist ##"):
# Finish waiting on the dist_stream,
# in case some delayed stream scheduling happens during the wait() call.
with torch.cuda.stream(self._dist_stream):
data = request.wait()
# Make sure that both result of input_dist and context
# are properly transferred to the current stream.
if self._dist_stream is not None:
torch.cuda.current_stream().wait_stream(self._dist_stream)
cur_stream = torch.cuda.current_stream()
assert isinstance(
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
data.record_stream(cur_stream)
ctx = self._context.module_contexts[self._name]
ctx.record_stream(cur_stream)
if len(self._context.feature_processor_forwards) > 0:
with record_function("## feature_processor ##"):
for sparse_feature in data:
if sparse_feature.id_score_list_features is not None:
for fp_forward in self._context.feature_processor_forwards:
sparse_feature.id_score_list_features = fp_forward(
sparse_feature.id_score_list_features
)
return self._module.compute_and_output_dist(self._context.module_contexts[self._name], data)
@property
def name(self) -> str:
return self._name
@property
def args(self) -> List[ArgInfo]:
return self._args
def _start_data_dist(
pipelined_modules: List[ShardedModule],
batch: In,
context: TrainPipelineContext,
) -> None:
context.input_dist_requests.clear()
context.module_contexts.clear()
for module in pipelined_modules:
forward = module.forward
assert isinstance(forward, PipelinedForward)
# Retrieve argument for the input_dist of EBC
# is_getitem True means this argument could be retrieved by a list
# False means this argument is getting while getattr
# and this info was done in the _rewrite_model by tracing the
# entire model to get the arg_info_list
args = []
kwargs = {}
for arg_info in forward.args:
if arg_info.input_attrs:
arg = batch
for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
if is_getitem:
arg = arg[attr]
else:
arg = getattr(arg, attr)
if arg_info.name:
kwargs[arg_info.name] = arg
else:
args.append(arg)
else:
args.append(None)
# Start input distribution.
module_ctx = module.create_context()
context.module_contexts[forward.name] = module_ctx
context.input_dist_requests[forward.name] = module.input_dist(module_ctx, *args, **kwargs)
# Call wait on the first awaitable in the input dist for the tensor splits
for key, awaitable in context.input_dist_requests.items():
context.input_dist_requests[key] = awaitable.wait()
def _get_node_args_helper(
# pyre-ignore
arguments,
num_found: int,
feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
"""
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
It also counts the number of (args + kwargs) found.
"""
arg_info_list = [ArgInfo([], [], None) for _ in range(len(arguments))]
for arg, arg_info in zip(arguments, arg_info_list):
if arg is None:
num_found += 1
continue
while True:
if not isinstance(arg, torch.fx.Node):
break
child_node = arg
if child_node.op == "placeholder":
num_found += 1
break
# skip this fp node
elif feature_processor_arguments is not None and child_node in feature_processor_arguments:
arg = child_node.args[0]
elif (
child_node.op == "call_function"
and child_node.target.__module__ == "builtins"
# pyre-ignore[16]
and child_node.target.__name__ == "getattr"
):
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, False)
arg = child_node.args[0]
elif (
child_node.op == "call_function"
and child_node.target.__module__ == "_operator"
# pyre-ignore[16]
and child_node.target.__name__ == "getitem"
):
arg_info.input_attrs.insert(0, child_node.args[1])
arg_info.is_getitems.insert(0, True)
arg = child_node.args[0]
else:
break
return arg_info_list, num_found
def _get_node_args(
node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
num_found = 0
pos_arg_info_list, num_found = _get_node_args_helper(
node.args, num_found, feature_processor_nodes
)
kwargs_arg_info_list, num_found = _get_node_args_helper(node.kwargs.values(), num_found)
# Replace with proper names for kwargs
for name, arg_info_list in zip(node.kwargs, kwargs_arg_info_list):
arg_info_list.name = name
arg_info_list = pos_arg_info_list + kwargs_arg_info_list
return arg_info_list, num_found
def _get_unsharded_module_names_helper(
model: torch.nn.Module,
path: str,
unsharded_module_names: Set[str],
) -> bool:
sharded_children = set()
for name, child in model.named_children():
curr_path = path + name
if isinstance(child, ShardedModule):
sharded_children.add(name)
else:
child_sharded = _get_unsharded_module_names_helper(
child,
curr_path + ".",
unsharded_module_names,
)
if child_sharded:
sharded_children.add(name)
if len(sharded_children) > 0:
for name, _ in model.named_children():
if name not in sharded_children:
unsharded_module_names.add(path + name)
return len(sharded_children) > 0
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
"""
Returns a list of top level modules do not contain any sharded sub modules.
"""
unsharded_module_names: Set[str] = set()
_get_unsharded_module_names_helper(
model,
"",
unsharded_module_names,
)
return list(unsharded_module_names)
def _rewrite_model( # noqa C901
model: torch.nn.Module,
context: TrainPipelineContext,
dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
# Get underlying nn.Module
if isinstance(model, DistributedModelParallel):
model = model.module
# Collect a list of sharded modules.
sharded_modules = {}
fp_modules = {}
for name, m in model.named_modules():
if isinstance(m, ShardedModule):
sharded_modules[name] = m
if isinstance(m, BaseGroupedFeatureProcessor):
fp_modules[name] = m
# Trace a model.
tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))
graph = tracer.trace(model)
feature_processor_nodes = []
# find the fp node
for node in graph.nodes:
if node.op == "call_module" and node.target in fp_modules:
feature_processor_nodes.append(node)
# Select sharded modules, which are top-level in the forward call graph,
# i.e. which don't have input transformations, i.e.
# rely only on 'builtins.getattr'.
ret = []
for node in graph.nodes:
if node.op == "call_module" and node.target in sharded_modules:
total_num_args = len(node.args) + len(node.kwargs)
if total_num_args == 0:
continue
arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)
if num_found == total_num_args:
logger.info(f"Module '{node.target}'' will be pipelined")
child = sharded_modules[node.target]
child.forward = PipelinedForward(
node.target,
arg_info_list,
child,
context,
dist_stream,
)
ret.append(child)
return ret
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
forward and backward. This helps hide the all2all latency while preserving the
training forward / backward ordering.
stage 3: forward, backward - uses default CUDA stream
stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
stage 1: device transfer - uses memcpy CUDA stream
`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.
Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.
"""
synced_pipeline_id: Dict[int, int] = {}
def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
enable_amp: bool = False,
enable_grad_scaling: bool = True,
grad_accum: Optional[int] = None,
) -> None:
self._model = model
self._optimizer = optimizer
self._device = device
self._enable_amp = enable_amp
# NOTE: Pending upstream feedback, but two flags because we can run AMP without CUDA but cannot scale gradients without CUDA.
# Background on gradient/loss scaling
# https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling
# https://pytorch.org/docs/stable/amp.html#gradient-scaling
self._enable_grad_scaling = enable_grad_scaling
self._grad_scaler = torch.cuda.amp.GradScaler(
enabled=self._enable_amp and self._enable_grad_scaling
)
logging.info(f"Amp is enabled: {self._enable_amp}")
# use two data streams to support two concurrent batches
if device.type == "cuda":
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
else:
if self._enable_amp:
logging.warning("Amp is enabled, but no CUDA available")
self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None
self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None
self._batch_i: Optional[In] = None
self._batch_ip1: Optional[In] = None
self._batch_ip2: Optional[In] = None
self._connected = False
self._context = TrainPipelineContext()
self._pipelined_modules: List[ShardedModule] = []
self._progress_calls = 0
if grad_accum is not None:
assert isinstance(grad_accum, int) and grad_accum > 0
self._grad_accum = grad_accum
def _connect(self, dataloader_iter: Iterator[In]) -> None:
# batch 1
with torch.cuda.stream(self._memcpy_stream):
batch_i = next(dataloader_iter)
self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
# Try to pipeline input data dist.
self._pipelined_modules = _rewrite_model(self._model, self._context, self._data_dist_stream)
with torch.cuda.stream(self._data_dist_stream):
_wait_for_batch(batch_i, self._memcpy_stream)
_start_data_dist(self._pipelined_modules, batch_i, self._context)
# batch 2
with torch.cuda.stream(self._memcpy_stream):
batch_ip1 = next(dataloader_iter)
self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)
self._connected = True
self.__class__.synced_pipeline_id[id(self._model)] = id(self)
def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
NOTE: This method has been updated to perform gradient accumulation.
If `_grad_accum` is set, then loss values are scaled by this amount and
optimizer update/reset is skipped for `_grad_accum` calls of `progress`
(congruent to training steps), and then update/reset on every `_grad_accum`th
step.
"""
should_step_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
and (self._progress_calls + 1) % self._grad_accum == 0
) or self._grad_accum is None
should_reset_optimizer = (
self._grad_accum is not None
and self._progress_calls > 0
and (self._progress_calls + 2) % self._grad_accum == 0
) or self._grad_accum is None
if not self._connected:
self._connect(dataloader_iter)
elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
self._sync_pipeline()
self.__class__.synced_pipeline_id[id(self._model)] = id(self)
if self._model.training and should_reset_optimizer:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()
with record_function("## copy_batch_to_gpu ##"):
with torch.cuda.stream(self._memcpy_stream):
batch_ip2 = next(dataloader_iter)
self._batch_ip2 = batch_ip2 = _to_device(batch_ip2, self._device, non_blocking=True)
batch_i = cast(In, self._batch_i)
batch_ip1 = cast(In, self._batch_ip1)
with record_function("## wait_for_batch ##"):
_wait_for_batch(batch_i, self._data_dist_stream)
# Forward
with record_function("## forward ##"):
# if using multiple streams (ie. CUDA), create an event in default stream
# before starting forward pass
if self._data_dist_stream:
event = torch.cuda.current_stream().record_event()
if self._enable_amp:
# conditionally apply the model to the batch in the autocast context
# it appears that `enabled=self._enable_amp` should handle this,
# but it does not.
with torch.autocast(
device_type=self._device.type,
dtype=torch.bfloat16,
enabled=self._enable_amp,
):
(losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
else:
(losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
# Data Distribution
with record_function("## sparse_data_dist ##"):
with torch.cuda.stream(self._data_dist_stream):
_wait_for_batch(batch_ip1, self._memcpy_stream)
# Ensure event in default stream has been called before
# starting data dist
if self._data_dist_stream:
# pyre-ignore [61]: Local variable `event` is undefined, or not always defined
self._data_dist_stream.wait_event(event)
_start_data_dist(self._pipelined_modules, batch_ip1, self._context)
if self._model.training:
# Backward
with record_function("## backward ##"):
# Loss is normalize by number of accumulation steps.
# The reported loss in `output['loss']` remains the unnormalized value.
if self._grad_accum is not None:
losses = losses / self._grad_accum
self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()
if should_step_optimizer:
# Update
with record_function("## optimizer ##"):
self._grad_scaler.step(self._optimizer)
self._grad_scaler.update()
self._batch_i = batch_ip1
self._batch_ip1 = batch_ip2
if self._model.training:
self._progress_calls += 1
return output
def _sync_pipeline(self) -> None:
"""
Syncs `PipelinedForward` for sharded modules with context and dist stream of the
current train pipeline. Used when switching between train pipelines for the same
model.
"""
for module in self._pipelined_modules:
module.forward._context = self._context
module.forward._dist_stream = self._data_dist_stream

25
images/init_venv.sh Executable file
View file

@ -0,0 +1,25 @@
#! /bin/sh
if [[ "$(uname)" == "Darwin" ]]; then
echo "Only supported on Linux."
exit 1
fi
# You may need to point this to a version of python 3.10
PYTHONBIN="/opt/ee/python/3.10/bin/python3.10"
echo Using "PYTHONBIN=$PYTHONBIN"
# Put venv in tmp, these things are not made to last, just rebuild.
VENV_PATH="$HOME/tml_venv"
rm -rf "$VENV_PATH"
"$PYTHONBIN" -m venv "$VENV_PATH"
# shellcheck source=/dev/null
. "$VENV_PATH/bin/activate"
pip --require-virtual install -U pip
pip --require-virtualenv install --no-deps -r images/requirements.txt
ln -s "$(pwd)" "$VENV_PATH/lib/python3.10/site-packages/tml"
echo "Now run source ${VENV_PATH}/bin/activate" to get going.

152
images/requirements.txt Normal file
View file

@ -0,0 +1,152 @@
absl-py==1.4.0
aiofiles==22.1.0
aiohttp==3.8.3
aiosignal==1.3.1
appdirs==1.4.4
arrow==1.2.3
asttokens==2.2.1
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.1.0
backcall==0.2.0
black==22.6.0
cachetools==5.3.0
cblack==22.6.0
certifi==2022.12.7
cfgv==3.3.1
charset-normalizer==2.1.1
click==8.1.3
cmake==3.25.0
Cython==0.29.32
decorator==5.1.1
distlib==0.3.6
distro==1.8.0
dm-tree==0.1.6
docker==6.0.1
docker-pycreds==0.4.0
docstring-parser==0.8.1
exceptiongroup==1.1.0
executing==1.2.0
fbgemm-gpu-cpu==0.3.2
filelock==3.8.2
fire==0.5.0
flatbuffers==1.12
frozenlist==1.3.3
fsspec==2022.11.0
gast==0.4.0
gcsfs==2022.11.0
gitdb==4.0.10
GitPython==3.1.31
google-api-core==2.8.2
google-auth==2.16.0
google-auth-oauthlib==0.4.6
google-cloud-core==2.3.2
google-cloud-storage==2.7.0
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.4.1
googleapis-common-protos==1.56.4
grpcio==1.51.1
h5py==3.8.0
hypothesis==6.61.0
identify==2.5.17
idna==3.4
importlib-metadata==6.0.0
iniconfig==2.0.0
iopath==0.1.10
ipdb==0.13.11
ipython==8.10.0
jedi==0.18.2
Jinja2==3.1.2
keras==2.9.0
Keras-Preprocessing==1.1.2
libclang==15.0.6.1
libcst==0.4.9
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib-inline==0.1.6
moreorless==0.4.0
multidict==6.0.4
mypy==1.0.1
mypy-extensions==0.4.3
nest-asyncio==1.5.6
ninja==1.11.1
nodeenv==1.7.0
numpy==1.22.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.3
parso==0.8.3
pathspec==0.11.0
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.0.0
pluggy==1.0.0
portalocker==2.6.0
portpicker==1.5.2
pre-commit==3.0.4
prompt-toolkit==3.0.36
protobuf==3.20.2
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==10.0.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydantic==1.9.0
pyDeprecate==0.3.2
Pygments==2.14.0
pyparsing==3.0.9
pyre-extensions==0.0.27
pytest==7.2.1
pytest-mypy==0.10.3
python-dateutil==2.8.2
pytz==2022.6
PyYAML==6.0.0
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scikit-build==0.16.3
sentry-sdk==1.16.0
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
sortedcontainers==2.4.0
stack-data==0.6.2
stdlibs==2022.10.9
tabulate==0.9.0
tensorboard==2.9.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.3
tensorflow-estimator==2.9.0
tensorflow-io-gcs-filesystem==0.30.0
termcolor==2.2.0
toml==0.10.2
tomli==2.0.1
torch==1.13.1
torchmetrics==0.11.0
torchrec==0.3.2
torchsnapshot==0.1.0
torchx==0.3.0
tqdm==4.64.1
trailrunner==1.2.1
traitlets==5.9.0
typing-inspect==0.8.0
typing_extensions==4.4.0
urllib3==1.26.13
usort==1.0.5
virtualenv==20.19.0
wandb==0.13.11
wcwidth==0.2.6
websocket-client==1.4.2
Werkzeug==2.2.3
wrapt==1.14.1
yarl==1.8.2
zipp==3.12.1

108
machines/environment.py Normal file
View file

@ -0,0 +1,108 @@
import json
import os
from typing import List
KF_DDS_PORT: int = 5050
SLURM_DDS_PORT: int = 5051
FLIGHT_SERVER_PORT: int = 2222
def on_kf():
return "SPEC_TYPE" in os.environ
def has_readers():
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return machines_config_env["dataset_worker"] is not None
return os.environ.get("HAS_READERS", "False") == "True"
def get_task_type():
if on_kf():
return os.environ["SPEC_TYPE"]
return os.environ["TASK_TYPE"]
def is_chief() -> bool:
return get_task_type() == "chief"
def is_reader() -> bool:
return get_task_type() == "datasetworker"
def is_dispatcher() -> bool:
return get_task_type() == "datasetdispatcher"
def get_task_index():
if on_kf():
pod_name = os.environ["MY_POD_NAME"]
return int(pod_name.split("-")[-1])
else:
raise NotImplementedError
def get_reader_port():
if on_kf():
return KF_DDS_PORT
return SLURM_DDS_PORT
def get_dds():
if not has_readers():
return None
dispatcher_address = get_dds_dispatcher_address()
if dispatcher_address:
return f"grpc://{dispatcher_address}"
else:
raise ValueError("Job does not have DDS.")
def get_dds_dispatcher_address():
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
dds_host = f"{job_name}-datasetdispatcher-0"
else:
dds_host = os.environ["SLURM_JOB_NODELIST_HET_GROUP_0"]
return f"{dds_host}:{get_reader_port()}"
def get_dds_worker_address():
if not has_readers():
return None
if on_kf():
job_name = os.environ["JOB_NAME"]
task_index = get_task_index()
return f"{job_name}-datasetworker-{task_index}:{get_reader_port()}"
else:
node = os.environ["SLURMD_NODENAME"]
return f"{node}:{get_reader_port()}"
def get_num_readers():
if not has_readers():
return 0
if on_kf():
machines_config_env = json.loads(os.environ["MACHINES_CONFIG"])
return int(machines_config_env["num_dataset_workers"] or 0)
return len(os.environ["SLURM_JOB_NODELIST_HET_GROUP_1"].split(","))
def get_flight_server_addresses():
if on_kf():
job_name = os.environ["JOB_NAME"]
return [
f"grpc://{job_name}-datasetworker-{task_index}:{FLIGHT_SERVER_PORT}"
for task_index in range(get_num_readers())
]
else:
raise NotImplementedError
def get_dds_journaling_dir():
return os.environ.get("DATASET_JOURNALING_DIR", None)

38
machines/get_env.py Normal file
View file

@ -0,0 +1,38 @@
import tml.machines.environment as env
from absl import app, flags
FLAGS = flags.FLAGS
flags.DEFINE_string("property", None, "Which property of the current environment to fetch.")
def main(argv):
if FLAGS.property == "using_dds":
print(f"{env.has_readers()}", flush=True)
if FLAGS.property == "has_readers":
print(f"{env.has_readers()}", flush=True)
elif FLAGS.property == "get_task_type":
print(f"{env.get_task_type()}", flush=True)
elif FLAGS.property == "is_datasetworker":
print(f"{env.is_reader()}", flush=True)
elif FLAGS.property == "is_dds_dispatcher":
print(f"{env.is_dispatcher()}", flush=True)
elif FLAGS.property == "get_task_index":
print(f"{env.get_task_index()}", flush=True)
elif FLAGS.property == "get_dataset_service":
print(f"{env.get_dds()}", flush=True)
elif FLAGS.property == "get_dds_dispatcher_address":
print(f"{env.get_dds_dispatcher_address()}", flush=True)
elif FLAGS.property == "get_dds_worker_address":
print(f"{env.get_dds_worker_address()}", flush=True)
elif FLAGS.property == "get_dds_port":
print(f"{env.get_reader_port()}", flush=True)
elif FLAGS.property == "get_dds_journaling_dir":
print(f"{env.get_dds_journaling_dir()}", flush=True)
elif FLAGS.property == "should_start_dds":
print(env.is_reader() or env.is_dispatcher(), flush=True)
if __name__ == "__main__":
app.run(main)

26
machines/is_venv.py Normal file
View file

@ -0,0 +1,26 @@
"""This is intended to be run as a module.
e.g. python -m tml.machines.is_venv
Exits with 0 ii running in venv, otherwise 1.
"""
import sys
import logging
def is_venv():
# See https://stackoverflow.com/questions/1871549/determine-if-python-is-running-inside-virtualenv
return sys.base_prefix != sys.prefix
def _main():
if is_venv():
logging.info("In venv %s", sys.prefix)
sys.exit(0)
else:
logging.error("Not in venv")
sys.exit(1)
if __name__ == "__main__":
_main()

43
machines/list_ops.py Normal file
View file

@ -0,0 +1,43 @@
"""
Simple str.split() parsing of input string
usage example:
python list_ops.py --input_list=$INPUT [--sep=","] [--op=<len|select>] [--elem=$INDEX]
Args:
- input_list: input string
- sep (default ","): separator string
- elem (default 0): integer index
- op (default "select"): either `len` or `select`
- len: prints len(input_list.split(sep))
- select: prints input_list.split(sep)[elem]
Typical usage would be in a bash script, e.g.:
LIST_LEN=$(python list_ops.py --input_list=$INPUT --op=len)
"""
import tml.machines.environment as env
from absl import app, flags
FLAGS = flags.FLAGS
flags.DEFINE_string("input_list", None, "string to parse as list")
flags.DEFINE_integer("elem", 0, "which element to take")
flags.DEFINE_string("sep", ",", "separator")
flags.DEFINE_string("op", "select", "operation to do")
def main(argv):
split_list = FLAGS.input_list.split(FLAGS.sep)
if FLAGS.op == "select":
print(split_list[FLAGS.elem], flush=True)
elif FLAGS.op == "len":
print(len(split_list), flush=True)
else:
raise ValueError(f"operation {FLAGS.op} not recognized.")
if __name__ == "__main__":
app.run(main)

3
metrics/__init__.py Normal file
View file

@ -0,0 +1,3 @@
from .aggregation import StableMean # noqa
from .auroc import AUROCWithMWU # noqa
from .rce import NRCE, RCE # noqa

97
metrics/aggregation.py Normal file
View file

@ -0,0 +1,97 @@
"""
Contains aggregation metrics.
"""
from typing import Tuple, Union
import torch
import torchmetrics
def update_mean(
current_mean: torch.Tensor,
current_weight_sum: torch.Tensor,
value: torch.Tensor,
weight: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the mean according to Welford formula:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
Args:
current_mean: The value of the current accumulated mean.
current_weight_sum: The current weighted sum.
value: The new value that needs to be added to get a new mean.
weight: The weights for the new value.
Returns: The updated mean and updated weighted sum.
"""
weight = torch.broadcast_to(weight, value.shape)
# Avoiding (on purpose) in-place operation when using += in case
# current_mean and current_weight_sum share the same storage
current_weight_sum = current_weight_sum + torch.sum(weight)
current_mean = current_mean + torch.sum((weight / current_weight_sum) * (value - current_mean))
return current_mean, current_weight_sum
def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor:
"""
Merge the state from multiple workers.
Args:
state: A tensor with the first dimension indicating workers.
Returns: The accumulated mean from all workers.
"""
mean, weight_sum = update_mean(
current_mean=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
current_weight_sum=torch.as_tensor(0.0, dtype=state.dtype, device=state.device),
value=state[:, 0],
weight=state[:, 1],
)
return torch.stack([mean, weight_sum])
class StableMean(torchmetrics.Metric):
"""
This implements a numerical stable mean metrics computation using Welford algorithm according to
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version.
For example when using float32, the algorithm will give a valid output even if the "sum" is larger
than the maximum float32 as far as the mean is within the limit of float32.
See also https://nullbuffer.com/articles/welford_algorithm.html for more information.
"""
def __init__(self, **kwargs):
"""
Args:
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
self.add_state(
"mean_and_weight_sum",
default=torch.zeros(2),
dist_reduce_fx=stable_mean_dist_reduce_fn,
)
def update(self, value: torch.Tensor, weight: Union[float, torch.Tensor] = 1.0) -> None:
"""
Update the current mean.
Args:
value: Value to update the mean with.
weight: weight to use. Shape should be broadcastable to that of value.
"""
mean, weight_sum = self.mean_and_weight_sum[0], self.mean_and_weight_sum[1]
if not isinstance(weight, torch.Tensor):
weight = torch.as_tensor(weight, dtype=value.dtype, device=value.device)
self.mean_and_weight_sum[0], self.mean_and_weight_sum[1] = update_mean(
mean, weight_sum, value, torch.as_tensor(weight)
)
def compute(self) -> torch.Tensor:
"""
Compute and return the accumulated mean.
"""
return self.mean_and_weight_sum[0]

162
metrics/auroc.py Normal file
View file

@ -0,0 +1,162 @@
"""
AUROC metrics.
"""
from typing import Union
from tml.ml_logging.torch_logging import logging
import torch
import torchmetrics
from torchmetrics.utilities.data import dim_zero_cat
def _compute_helper(
predictions: torch.Tensor,
target: torch.Tensor,
weights: torch.Tensor,
max_positive_negative_weighted_sum: torch.Tensor,
min_positive_negative_weighted_sum: torch.Tensor,
equal_predictions_as_incorrect: bool,
) -> torch.Tensor:
"""
Compute AUROC.
Args:
predictions: The predictions probabilities.
target: The target.
weights: The sample weights to assign to each sample in the batch.
max_positive_negative_weighted_sum: The sum of the weights for the positive labels.
min_positive_negative_weighted_sum:
equal_predictions_as_incorrect: For positive & negative labels having identical scores,
we assume that they are correct prediction (i.e weight = 1) when ths is False. Otherwise,
we assume that they are correct prediction (i.e weight = 0).
"""
dim = 0
# Sort predictions based on key (score, true_label). The order is ascending for score.
# For true_label, order is ascending if equal_predictions_as_incorrect is True;
# otherwise it is descending.
target_order = torch.argsort(target, dim=dim, descending=equal_predictions_as_incorrect)
score_order = torch.sort(torch.gather(predictions, dim, target_order), stable=True, dim=dim)[1]
score_order = torch.gather(target_order, dim, score_order)
sorted_target = torch.gather(target, dim, score_order)
sorted_weights = torch.gather(weights, dim, score_order)
negatives_from_left = torch.cumsum((1.0 - sorted_target) * sorted_weights, 0)
numerator = torch.sum(
sorted_weights * (sorted_target * negatives_from_left / max_positive_negative_weighted_sum)
)
return numerator / min_positive_negative_weighted_sum
class AUROCWithMWU(torchmetrics.Metric):
"""
AUROC using Mann-Whitney U-test.
See https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve.
This AUROC implementation is well suited to (non-zero) low-CTR. In particular it will return
the correct AUROC even if the predicted probabilities are all close to 0.
Currently only support binary classification.
"""
def __init__(self, label_threshold: float = 0.5, raise_missing_class: bool = False, **kwargs):
"""
Args:
label_threshold: Labels strictly above this threshold are considered positive labels,
otherwise, they are considered negative.
raise_missing_class: If True, an error will be raise if negative or positive class is missing.
Otherwise, we will simply log a warning.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
self.add_state("predictions", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
self.add_state("weights", default=[], dist_reduce_fx="cat")
self.label_threshold = label_threshold
self.raise_missing_class = raise_missing_class
def update(
self,
predictions: torch.Tensor,
target: torch.Tensor,
weight: Union[float, torch.Tensor] = 1.0,
) -> None:
"""
Update the current auroc.
Args:
predictions: Predicted values, 1D Tensor or 2D Tensor of shape batch_size x 1.
target: Ground truth. Must have same shape as predictions.
weight: The weight to use for the predicted values. Shape should be
broadcastable to that of predictions.
"""
self.predictions.append(predictions)
self.target.append(target)
if not isinstance(weight, torch.Tensor):
weight = torch.as_tensor(weight, dtype=predictions.dtype, device=target.device)
self.weights.append(torch.broadcast_to(weight, predictions.size()))
def compute(self) -> torch.Tensor:
"""
Compute and return the accumulated AUROC.
"""
weights = dim_zero_cat(self.weights)
predictions = dim_zero_cat(self.predictions)
target = dim_zero_cat(self.target).type_as(predictions)
negative_mask = target <= self.label_threshold
positive_mask = torch.logical_not(negative_mask)
if not negative_mask.any():
msg = "Negative class missing. AUROC returned will be meaningless."
if self.raise_missing_class:
raise ValueError(msg)
else:
logging.warn(msg)
if not positive_mask.any():
msg = "Positive class missing. AUROC returned will be meaningless."
if self.raise_missing_class:
raise ValueError(msg)
else:
logging.warn(msg)
weighted_actual_negative_sum = torch.sum(
torch.where(negative_mask, weights, torch.zeros_like(weights))
)
weighted_actual_positive_sum = torch.sum(
torch.where(positive_mask, weights, torch.zeros_like(weights))
)
max_positive_negative_weighted_sum = torch.max(
weighted_actual_negative_sum, weighted_actual_positive_sum
)
min_positive_negative_weighted_sum = torch.min(
weighted_actual_negative_sum, weighted_actual_positive_sum
)
# Compute auroc with the weight set to 1 when positive & negative have identical scores.
auroc_le = _compute_helper(
target=target,
weights=weights,
predictions=predictions,
min_positive_negative_weighted_sum=min_positive_negative_weighted_sum,
max_positive_negative_weighted_sum=max_positive_negative_weighted_sum,
equal_predictions_as_incorrect=False,
)
# Compute auroc with the weight set to 0 when positive & negative have identical scores.
auroc_lt = _compute_helper(
target=target,
weights=weights,
predictions=predictions,
min_positive_negative_weighted_sum=min_positive_negative_weighted_sum,
max_positive_negative_weighted_sum=max_positive_negative_weighted_sum,
equal_predictions_as_incorrect=True,
)
# Compute auroc with the weight set to 1/2 when positive & negative have identical scores.
return auroc_le - (auroc_le - auroc_lt) / 2.0

280
metrics/rce.py Normal file
View file

@ -0,0 +1,280 @@
"""
Contains RCE metrics.
"""
import copy
from functools import partial
from typing import Union
from tml.metrics import aggregation
import torch
import torchmetrics
def _smooth(
value: torch.Tensor, label_smoothing: Union[float, torch.Tensor]
) -> Union[float, torch.Tensor]:
"""
Smooth given values.
Args:
value: Value to smooth.
label_smoothing: smoothing constant.
Returns: Smoothed values.
"""
return value * (1.0 - label_smoothing) + 0.5 * label_smoothing
def _binary_cross_entropy_with_clipping(
predictions: torch.Tensor,
target: torch.Tensor,
epsilon: Union[float, torch.Tensor],
reduction: str = "none",
) -> torch.Tensor:
"""
Clip Predictions and apply binary cross entropy.
This is done to match the implementation in keras at
https://github.com/keras-team/keras/blob/r2.9/keras/backend.py#L5294-L5300
Args:
predictions: Predicted probabilities.
target: Ground truth.
epsilon: Epsilon fuzz factor used to clip the predictions.
reduction: The reduction method to use.
Returns: Binary cross entropy on the clipped predictions.
"""
predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
bce = -target * torch.log(predictions + epsilon)
bce -= (1.0 - target) * torch.log(1.0 - predictions + epsilon)
if reduction == "mean":
return torch.mean(bce)
return bce
class RCE(torchmetrics.Metric):
"""
Compute the relative cross entropy (`RCE <http://go/rce>`_).
RCE is metric used for models predicting probability of success (p), i.e. pCTR.
RCE represents the binary `cross entropy <https://en.wikipedia.org/wiki/Cross_entropy>` of
the model compared to a reference straw man model.
Binary cross entropy is defined as:
y = label; p = prediction;
binary cross entropy(example) = - y * log(p) - (1-y) * log(1-p)
Where y in {0, 1}
Cross entropy of a model is defined as:
CE(model) = average(binary cross entropy(example))
Over all the examples we aggregate on.
The straw man model is quite simple, it is a constant predictor, always predicting the average
over the labels.
RCE of a model is defined as:
RCE(model) = 100 * (CE(reference model) - CE(model)) / CE(reference model)
.. note:: Maximizing the likelihood is the same as minimizing the cross entropy or maximizing
the RCE. Since cross entropy is the average minus likelihood for the binary case.
.. note:: Binary cross entropy of an example is non negative, and equal to the
`KL divergence <(https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
#Properties>`
since p is constant, and its entropy is equal to zero.
.. note:: 0% RCE means as good as the straw man model.
100% means always predicts exactly the label. Namely, cross entropy of the model is
always zero. In practice 100% is impossible to achieve due to clipping.
Negative RCE means that the model is doing worse than the straw man.
This usually means an un-calibrated model, namely, the average prediction
is "far" from the average label. Examining NRCE might help identifying if that is
the case.
.. note:: RCE is not a "ratio" in the statistical
`level of measurement sense <https://en.wikipedia.org/wiki/Level_of_measurement>`.
The higher the model's RCE is the harder it is to improve it by an extra point.
For example:
Let CE(model) = 0.5 CE(reference model), then the RCE(model) = 50.
Now take a "twice as good" model:
Let CE(better model) = 0.5 CE(model) = 0.25 CE(reference model),
then the RCE(better model) = 75 and not 100.
.. note:: In order to keep the log function stable, typically p is limited to
lie in [CLAMP_EPSILON, 1-CLAMP_EPSILON],
where CLAMP_EPSILON is some small constant like: 1e-7.
Old implementation used 1e-5 clipping by default, current uses
tf.keras.backend.epsilon()
whose default is 1e-7.
.. note:: Since the reference model prediction is constant (probability),
CE(reference model) = H(average(label))
Where H is the standard
`entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>` function.
.. note:: Must have at least 1 positive and 1 negative sample accumulated,
or RCE will come out as NaN.
"""
def __init__(
self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
):
"""
Args:
from_logits: whether or not predictions are logits or probabilities.
label_smoothing: label smoothing constant.
epsilon: Epsilon fuzz factor used on the predictions probabilities when from_logits is False.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(**kwargs)
self.from_logits = from_logits
self.label_smoothing = label_smoothing
self.epsilon = epsilon
self.kwargs = kwargs
self.mean_label = aggregation.StableMean(**kwargs)
self.binary_cross_entropy = aggregation.StableMean(**kwargs)
if self.from_logits:
self.bce_loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
else:
self.bce_loss_fn = partial(_binary_cross_entropy_with_clipping, epsilon=self.epsilon)
# Used to compute non-accumulated batch metric if `forward` or `__call__` functions are used.
self.batch_metric = copy.deepcopy(self)
def update(
self, predictions: torch.Tensor, target: torch.Tensor, weight: float = 1.0
) -> torch.Tensor:
"""
Update the current rce.
Args:
predictions: Predicted values.
target: Ground truth. Should have same shape as predictions.
weight: The weight to use for the predicted values. Shape should be broadcastable to that of
predictions.
"""
target = _smooth(target, self.label_smoothing)
self.mean_label.update(target, weight)
self.binary_cross_entropy.update(
self.bce_loss_fn(predictions, target, reduction="none"), weight
)
def compute(self) -> torch.Tensor:
"""
Compute and return the accumulated rce.
"""
baseline_mean = self.mean_label.compute()
baseline_ce = _binary_cross_entropy_with_clipping(
baseline_mean, baseline_mean, reduction="mean", epsilon=self.epsilon
)
pred_ce = self.binary_cross_entropy.compute()
return (1.0 - (pred_ce / baseline_ce)) * 100
def reset(self):
"""
Reset the metric to its initial state.
"""
super().reset()
self.mean_label.reset()
self.binary_cross_entropy.reset()
def forward(self, *args, **kwargs):
"""
Serves the dual purpose of both computing the metric on the current batch of inputs but also
add the batch statistics to the overall accumulating metric state.
Input arguments are the exact same as corresponding ``update`` method.
The returned output is the exact same as the output of ``compute``.
"""
self.update(*args, **kwargs)
self.batch_metric.update(*args, **kwargs)
batch_result = self.batch_metric.compute()
self.batch_metric.reset()
return batch_result
class NRCE(RCE):
"""
Calculate the RCE of the normalizes model.
Where the normalized model prediction average is normalized to the average label seen so far.
Namely, the the normalized model prediction:
normalized model prediction(example) = (model prediction(example) * average(label)) /
average(model prediction)
Where the average is over all previously seen examples.
.. note:: average(normalized model prediction) = average(label)
.. note:: NRCE can be misleading since it is oblivious to mis-calibrations.
The common interpretation of NRCE is to measure how good your model could potentially
perform if it was well calibrated.
.. note:: A big gap between NRCE and RCE might indicate a badly calibrated model,
"""
def __init__(
self, from_logits: bool = False, label_smoothing: float = 0, epsilon: float = 1e-7, **kwargs
):
"""
Args:
from_logits: whether or not predictions are logits or probabilities.
label_smoothing: label smoothing constant.
epsilon: Epsilon fuzz factor used on the predictions probabilities when from_logits is False.
It only used when computing the cross entropy but not when normalizing.
**kwargs: Additional parameters supported by all torchmetrics.Metric.
"""
super().__init__(from_logits=False, label_smoothing=0, epsilon=epsilon, **kwargs)
self.nrce_from_logits = from_logits
self.nrce_label_smoothing = label_smoothing
self.mean_prediction = aggregation.StableMean()
# Used to compute non-accumulated batch metric if `forward` or `__call__` functions are used.
self.batch_metric = copy.deepcopy(self)
def update(
self,
predictions: torch.Tensor,
target: torch.Tensor,
weight: Union[float, torch.Tensor] = 1.0,
):
"""
Update the current nrce.
Args:
predictions: Predicted values.
target: Ground truth. Should have same shape as predictions.
weight: The weight to use for the predicted values. Shape should be broadcastable to that of
predictions.
"""
predictions = torch.sigmoid(predictions) if self.nrce_from_logits else predictions
target = _smooth(target, self.nrce_label_smoothing)
self.mean_label.update(target, weight)
self.mean_prediction.update(predictions, weight)
normalizer = self.mean_label.compute() / self.mean_prediction.compute()
predictions = predictions * normalizer
self.binary_cross_entropy.update(
self.bce_loss_fn(predictions, target, reduction="none"), weight
)
def reset(self):
"""
Reset the metric to its initial state.
"""
super().reset()
self.mean_prediction.reset()

0
ml_logging/__init__.py Normal file
View file

View file

@ -0,0 +1,26 @@
"""Sets up logging through absl for training usage.
- Redirects logging to sys.stdout so that severity levels in GCP Stackdriver are accurate.
Usage:
>>> from twitter.ml.logging.absl_logging import logging
>>> logging.info(f"Properly logged as INFO level in GCP Stackdriver.")
"""
import logging as py_logging
import sys
from absl import logging as logging
def setup_absl_logging():
"""Make sure that absl logging pushes to stdout rather than stderr."""
logging.get_absl_handler().python_handler.stream = sys.stdout
formatter = py_logging.Formatter(
fmt="[%(module)s.%(funcName)s:%(lineno)s - %(levelname)s] %(message)s"
)
logging.get_absl_handler().setFormatter(formatter)
logging.set_verbosity(logging.INFO)
setup_absl_logging()

View file

@ -0,0 +1,21 @@
import unittest
from tml.ml_logging.torch_logging import logging
class Testtlogging(unittest.TestCase):
def test_warn_once(self):
with self.assertLogs(level="INFO") as captured_logs:
logging.info("first info")
logging.warning("first warning")
logging.warning("first warning")
logging.info("second info")
self.assertEqual(
captured_logs.output,
[
"INFO:absl:first info",
"WARNING:absl:first warning",
"INFO:absl:second info",
],
)

View file

@ -0,0 +1,61 @@
"""Overrides absl logger to be rank-aware for distributed pytorch usage.
>>> # in-bazel import
>>> from twitter.ml.logging.torch_logging import logging
>>> # out-bazel import
>>> from ml.logging.torch_logging import logging
>>> logging.info(f"This only prints on rank 0 if distributed, otherwise prints normally.")
>>> logging.info(f"This prints on all ranks if distributed, otherwise prints normally.", rank=-1)
"""
import functools
from typing import Optional
from tml.ml_logging.absl_logging import logging as logging
from absl import logging as absl_logging
import torch.distributed as dist
def rank_specific(logger):
"""Ensures that we only override a given logger once."""
if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
return logger
def _if_rank(logger_method, limit: Optional[int] = None):
if limit:
# If we are limiting redundant logs, wrap logging call with a cache
# to not execute if already cached.
def _wrap(_call):
@functools.lru_cache(limit)
def _logger_method(*args, **kwargs):
_call(*args, **kwargs)
return _logger_method
logger_method = _wrap(logger_method)
def _inner(msg, *args, rank: int = 0, **kwargs):
if not dist.is_initialized():
logger_method(msg, *args, **kwargs)
elif dist.get_rank() == rank:
logger_method(msg, *args, **kwargs)
elif rank < 0:
logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
# Register this stack frame with absl logging so that it doesn't trample logging lines.
absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
return _inner
logger.fatal = _if_rank(logger.fatal)
logger.error = _if_rank(logger.error)
logger.warning = _if_rank(logger.warning, limit=1)
logger.info = _if_rank(logger.info)
logger.debug = _if_rank(logger.debug)
logger.exception = _if_rank(logger.exception)
logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
rank_specific(logging)

89
model.py Normal file
View file

@ -0,0 +1,89 @@
"""Wraps servable model in loss and RecapBatch passing to be trainable."""
# flake8: noqa
from typing import Callable
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
class ModelAndLoss(torch.nn.Module):
# Reconsider our approach at a later date: https://ppwwyyxx.com/blog/2022/Loss-Function-Separation/
def __init__(
self,
model,
loss_fn: Callable,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(batch)
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
outputs.update(
{
"loss": losses,
"labels": batch.labels,
"weights": batch.weights,
}
)
# Allow multiple losses.
return losses, outputs
def maybe_shard_model(
model,
device: torch.device,
):
"""Set up and apply DistributedModelParallel to a model if running in a distributed environment.
If in a distributed environment, constructs Topology, sharders, and ShardingPlan, then applies
DistributedModelParallel.
If not in a distributed environment, returns model directly.
"""
if dist.is_initialized():
logging.info("***** Wrapping in DistributedModelParallel *****")
logging.info(f"Model before wrapping: {model}")
model = DistributedModelParallel(
module=model,
device=device,
)
logging.info(f"Model after wrapping: {model}")
return model
def log_sharded_tensor_content(weight_name: str, table_name: str, weight_tensor) -> None:
"""Handy function to log the content of EBC embedding layer.
Only works for single GPU machines.
Args:
weight_name: name of tensor, as defined in model
table_name: name of the EBC table the weight is taken from
weight_tensor: embedding weight tensor
"""
logging.info(f"{weight_name}, {table_name}", rank=-1)
logging.info(f"{weight_tensor.metadata()}", rank=-1)
output_tensor = torch.zeros(*weight_tensor.size(), device=torch.device("cuda:0"))
weight_tensor.gather(out=output_tensor)
logging.info(f"{output_tensor}", rank=-1)

1
optimizers/__init__.py Normal file
View file

@ -0,0 +1 @@
from tml.optimizers.optimizer import compute_lr

82
optimizers/config.py Normal file
View file

@ -0,0 +1,82 @@
"""Optimization configurations for models."""
import typing
import tml.core.config as base_config
import pydantic
class PiecewiseConstant(base_config.BaseConfig):
learning_rate_boundaries: typing.List[int] = pydantic.Field(None)
learning_rate_values: typing.List[float] = pydantic.Field(None)
class LinearRampToConstant(base_config.BaseConfig):
learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
description="Number of steps to ramp this up from zero."
)
class LinearRampToCosine(base_config.BaseConfig):
learning_rate: float
final_learning_rate: float
num_ramp_steps: pydantic.PositiveInt = pydantic.Field(
description="Number of steps to ramp this up from zero."
)
final_num_steps: pydantic.PositiveInt = pydantic.Field(
description="Final number of steps where decay stops."
)
class LearningRate(base_config.BaseConfig):
constant: float = pydantic.Field(None, one_of="lr")
linear_ramp_to_cosine: LinearRampToCosine = pydantic.Field(None, one_of="lr")
linear_ramp_to_constant: LinearRampToConstant = pydantic.Field(None, one_of="lr")
piecewise_constant: PiecewiseConstant = pydantic.Field(None, one_of="lr")
class OptimizerAlgorithmConfig(base_config.BaseConfig):
"""Base class for optimizer configurations."""
lr: float
...
class AdamConfig(OptimizerAlgorithmConfig):
# see https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
lr: float
betas: typing.Tuple[float, float] = [0.9, 0.999]
eps: float = 1e-7 # Numerical stability in denominator.
class SgdConfig(OptimizerAlgorithmConfig):
lr: float
momentum: float = 0.0
class AdagradConfig(OptimizerAlgorithmConfig):
lr: float
eps: float = 0
class OptimizerConfig(base_config.BaseConfig):
learning_rate: LearningRate = pydantic.Field(
None,
description="Constant learning rates",
)
adam: AdamConfig = pydantic.Field(None, one_of="optimizer")
sgd: SgdConfig = pydantic.Field(None, one_of="optimizer")
adagrad: AdagradConfig = pydantic.Field(None, one_of="optimizer")
def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
if optimizer_config.adam is not None:
return optimizer_config.adam
elif optimizer_config.sgd is not None:
return optimizer_config.sgd
elif optimizer_config.adagrad is not None:
return optimizer_config.adagrad
else:
raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")

105
optimizers/optimizer.py Normal file
View file

@ -0,0 +1,105 @@
from typing import Dict, Tuple
import math
import bisect
from tml.optimizers.config import (
LearningRate,
OptimizerConfig,
)
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tml.ml_logging.torch_logging import logging
def compute_lr(lr_config, step):
"""Compute a learning rate."""
if lr_config.constant is not None:
return lr_config.constant
elif lr_config.piecewise_constant is not None:
return lr_config.piecewise_constant.learning_rate_values[
bisect.bisect_right(lr_config.piecewise_constant.learning_rate_boundaries, step)
]
elif lr_config.linear_ramp_to_constant is not None:
slope = (
lr_config.linear_ramp_to_constant.learning_rate
/ lr_config.linear_ramp_to_constant.num_ramp_steps
)
return min(lr_config.linear_ramp_to_constant.learning_rate, slope * step)
elif lr_config.linear_ramp_to_cosine is not None:
cfg = lr_config.linear_ramp_to_cosine
if step < cfg.num_ramp_steps:
slope = cfg.learning_rate / cfg.num_ramp_steps
return slope * step
elif step <= cfg.final_num_steps:
return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
1.0
+ math.cos(
math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
)
)
else:
return cfg.final_learning_rate
else:
raise ValueError(f"No option selected in lr_config, passed {lr_config}")
class LRShim(_LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
def __init__(
self,
optimizer,
lr_dict: Dict[str, LearningRate],
last_epoch=-1,
verbose=False,
):
self.optimizer = optimizer
self.lr_dict = lr_dict
self.group_names = list(self.lr_dict.keys())
num_param_groups = sum(1 for _, _optim in optimizer._optims for _ in _optim.param_groups)
if num_param_groups != len(lr_dict):
raise ValueError(
f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}."
)
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
logging.warn(
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
UserWarning,
)
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]
def get_optimizer_class(optimizer_config: OptimizerConfig):
if optimizer_config.adam is not None:
return torch.optim.Adam
elif optimizer_config.sgd is not None:
return torch.optim.SGD
elif optimizer_config.adagrad is not None:
return torch.optim.Adagrad
def build_optimizer(
model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
"""Builds an optimizer and LR scheduler from an OptimizerConfig.
Note: use this when you want the same optimizer and learning rate schedule for all your parameters.
"""
optimizer_class = get_optimizer_class(optimizer_config)
optimizer = optimizer_class(model.parameters(), **optimizer_config.sgd.dict())
# We're passing everything in as one group here
scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
return optimizer, scheduler

0
projects/__init__.py Normal file
View file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,48 @@
This project is the "heavy ranker" used on the "For You" timeline. This is used to generate the ranking of Tweet after candidate retrieval and light ranker (note the final ordering of the Tweet is not directly the highest -> lowest scoring, because after scoring other heuristics are used).
This model captures the ranking model used for the majority of users of Twitter "For You" timeline in early March 2023. Due to the need to make sure this runs independently from other parts of Twitter codebase, there may be small differences from the production model.
The model receives various features, describing the Tweet and the user whose timeline is being constructed as input (see FEATURES.md for more details). The model outputs multiple binary predictions about how the user will respond if shown the Tweet.
Those are:
"recap.engagement.is_favorited": The probability the user will favorite the Tweet.
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": The probability the user will click into the conversation of this Tweet and reply or Like a Tweet.
"recap.engagement.is_good_clicked_convo_desc_v2": The probability the user will click into the conversation of this Tweet and stay there for at least 2 minutes.
"recap.engagement.is_negative_feedback_v2": The probability the user will react negatively (requesting "show less often" on the Tweet or author, block or mute the Tweet author)
"recap.engagement.is_profile_clicked_and_profile_engaged": The probability the user opens the Tweet author profile and Likes or replies to a Tweet.
"recap.engagement.is_replied": The probability the user replies to the Tweet.
"recap.engagement.is_replied_reply_engaged_by_author": The probability the user replies to the Tweet and this reply is engaged by the Tweet author.
"recap.engagement.is_report_tweet_clicked": The probability the user will click Report Tweet.
"recap.engagement.is_retweeted": The probability the user will ReTweet the Tweet.
"recap.engagement.is_video_playback_50": The probability (for a video Tweet) that the user will watch at least half of the video
For ranking the candidates these predictions are combined into a score by weighting them:
"recap.engagement.is_favorited": 0.5
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": 11* (the maximum prediction from these two "good click" features is used and weighted by 11, the other prediction is ignored).
"recap.engagement.is_good_clicked_convo_desc_v2": 11*
"recap.engagement.is_negative_feedback_v2": -74
"recap.engagement.is_profile_clicked_and_profile_engaged": 12
"recap.engagement.is_replied": 27
"recap.engagement.is_replied_reply_engaged_by_author": 75
"recap.engagement.is_report_tweet_clicked": -369
"recap.engagement.is_retweeted": 1
"recap.engagement.is_video_playback_50": 0.005
We cannot release the real training data due to privacy restrictions. However, we have included a script to generate random data to ensure you can run the model training code.
To try training the model (assuming you have already followed the repo setup instructions and are inside a virtualenv).
Run
$ ./projects/home/recap/scripts/create_random_data.sh
This will create some random data (in $HOME/tmp/recap_local_random_data).
$ ./projects/home/recap/scripts/run_local.sh
This will train the model (for a small number of iterations). Checkpoints and logs will be written to $HOME/tmp/runs/recap_local_debug.
The model training is configured through a yaml file (./projects/home/recap/config/local_prod.yaml).
The model architecture is a parallel masknet (https://arxiv.org/abs/2102.07619).

View file

View file

@ -0,0 +1,54 @@
from tml.core import config as config_mod
import tml.projects.home.recap.data.config as data_config
import tml.projects.home.recap.model.config as model_config
import tml.projects.home.recap.optimizer.config as optimizer_config
from enum import Enum
from typing import Dict, Optional
import pydantic
class TrainingConfig(config_mod.BaseConfig):
save_dir: str = "/tmp/model"
num_train_steps: pydantic.PositiveInt = 1000000
initial_checkpoint_dir: str = pydantic.Field(
None, description="Directory of initial checkpoints", at_most_one_of="initialization"
)
checkpoint_every_n: pydantic.PositiveInt = 1000
checkpoint_max_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Maximum number of checkpoints to keep. Defaults to keeping all."
)
train_log_every_n: pydantic.PositiveInt = 1000
num_eval_steps: int = pydantic.Field(
16384, description="Number of evaluation steps. If < 0 the entire dataset " "will be used."
)
eval_log_every_n: pydantic.PositiveInt = 5000
eval_timeout_in_s: pydantic.PositiveFloat = 60 * 60
gradient_accumulation: int = pydantic.Field(
None, description="Number of replica steps to accumulate gradients."
)
class RecapConfig(config_mod.BaseConfig):
training: TrainingConfig = pydantic.Field(TrainingConfig())
model: model_config.ModelConfig
train_data: data_config.RecapDataConfig
validation_data: Dict[str, data_config.RecapDataConfig]
optimizer: optimizer_config.RecapOptimizerConfig
which_metrics: Optional[str] = pydantic.Field(None, description="which metrics to pick.")
# DANGER DANGER! You might expect validators here to ensure that multi task learning setups are
# the same as the data. Unfortunately, this throws opaque errors when the model configuration is
# invalid. In our judgement, that is a more frequency and worse occurrence than tasks not matching
# the data.
class JobMode(str, Enum):
"""Job modes."""
TRAIN = "train"
EVALUATE = "evaluate"
INFERENCE = "inference"

View file

@ -0,0 +1,114 @@
{
"schema": [
{
"dtype": "int64_list",
"feature_name": "home_recap_2022_discrete__segdense_vals",
"length": 320
},
{
"dtype": "float_list",
"feature_name": "home_recap_2022_cont__segdense_vals",
"length": 6000
},
{
"dtype": "int64_list",
"feature_name": "home_recap_2022_binary__segdense_vals",
"length": 512
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_tweet_detail_dwelled_15_sec",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_profile_clicked_and_profile_engaged",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_replied_reply_engaged_by_author",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_video_playback_50",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_report_tweet_clicked",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_replied",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "meta.author_id",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_negative_feedback_v2",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_retweeted",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_favorited",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "meta.tweet_id",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_good_clicked_convo_desc_v2",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "meta.user_id",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_bookmarked",
"length": 1
},
{
"dtype": "int64_list",
"feature_name": "recap.engagement.is_shared",
"length": 1
},
{
"dtype": "float_list",
"feature_name": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings",
"length": 200
},
{
"dtype": "float_list",
"feature_name": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings",
"length": 200
},
{
"dtype": "float_list",
"feature_name": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings",
"length": 200
}
]
}

View file

@ -0,0 +1,477 @@
training:
num_train_steps: 10
num_eval_steps: 5
checkpoint_every_n: 5
train_log_every_n: 1
eval_log_every_n: 1
save_dir: ${HOME}/tmp/runs/recap_local_debug
eval_timeout_in_s: 7200
model:
backbone:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config:
mask_blocks:
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
- aggregation_size: 1024
input_layer_norm: false
output_size: 1024
reduction_factor: null
mlp:
batch_norm: null
dropout: null
final_layer_activation: true
layer_sizes:
- 2048
use_parallel: true
mlp_config: null
pos_weight: 1.0
featurization_config:
clip_log1p_abs_config: null
double_norm_log_config:
batch_norm_config:
affine: true
momentum: 0.01
clip_magnitude: 5.0
layer_norm_config:
axis: -1
center: true
epsilon: 0.0
scale: true
feature_names_to_concat:
- binary
log1p_abs_config: null
z_score_log_config: null
large_embeddings: null
multi_task_type: share_all
position_debias_config: null
small_embeddings: null
stratifiers: null
tasks:
recap.engagement.is_favorited:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_good_clicked_convo_desc_v2:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_negative_feedback_v2:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout:
rate: 0.1
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_profile_clicked_and_profile_engaged:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_replied:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_replied_reply_engaged_by_author:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_report_tweet_clicked:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout:
rate: 0.2
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_retweeted:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
recap.engagement.is_video_playback_50:
affine_map: null
dcn_config: null
dlrm_config: null
mask_net_config: null
mlp_config:
batch_norm:
affine: false
momentum: 0.1
dropout: null
final_layer_activation: false
layer_sizes:
- 256
- 128
- 1
pos_weight: 1.0
train_data:
global_batch_size: 128
dataset_service_compression: AUTO
inputs: &data_root "${HOME}/tmp/recap_local_random_data/*.gz"
seg_dense_schema: &seg_dense_schema
schema_path: "${TML_BASE}/projects/home/recap/config/home_recap_2022/segdense.json"
renamed_features:
"continuous": "home_recap_2022_cont__segdense_vals"
"binary": "home_recap_2022_binary__segdense_vals"
"discrete": "home_recap_2022_discrete__segdense_vals"
"author_embedding": "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings"
"user_embedding": "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings"
"user_eng_embedding": "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings"
"meta__author_id": "meta.author_id"
"meta__user_id": "meta.user_id"
"meta__tweet_id": "meta.tweet_id"
tasks: &data_tasks
"recap.engagement.is_bookmarked": {}
"recap.engagement.is_favorited": {}
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied": {}
"recap.engagement.is_good_clicked_convo_desc_v2": {}
"recap.engagement.is_negative_feedback_v2": {}
"recap.engagement.is_profile_clicked_and_profile_engaged": {}
"recap.engagement.is_replied": {}
"recap.engagement.is_replied_reply_engaged_by_author": {}
"recap.engagement.is_report_tweet_clicked": {}
"recap.engagement.is_retweeted": {}
"recap.engagement.is_shared": {}
"recap.engagement.is_tweet_detail_dwelled_15_sec": {}
"recap.engagement.is_video_playback_50": {}
preprocess: &preprocess
truncate_and_slice:
continuous_feature_truncation: 2117
binary_feature_truncation: 59
validation_data:
validation: &validation
global_batch_size: &eval_batch_size 128
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks: *data_tasks
preprocess: *preprocess
train:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks: *data_tasks
preprocess: *preprocess
recap.engagement.is_favorited:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_favorited":
pos_downsampling_rate: 0.8387
neg_downsampling_rate: 0.01
evaluation_tasks:
- "recap.engagement.is_favorited"
preprocess: *preprocess
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_good_clicked_convo_desc_favorited_or_replied":
pos_downsampling_rate: 0.9164
neg_downsampling_rate: 0.00195
evaluation_tasks:
- "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied"
preprocess: *preprocess
recap.engagement.is_good_clicked_convo_desc_v2:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_good_clicked_convo_desc_v2":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.00174
evaluation_tasks:
- "recap.engagement.is_good_clicked_convo_desc_v2"
preprocess: *preprocess
recap.engagement.is_negative_feedback_v2:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_negative_feedback_v2":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.00280
evaluation_tasks:
- "recap.engagement.is_negative_feedback_v2"
preprocess: *preprocess
recap.engagement.is_profile_clicked_and_profile_engaged:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_profile_clicked_and_profile_engaged":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.0015
evaluation_tasks:
- "recap.engagement.is_profile_clicked_and_profile_engaged"
preprocess: *preprocess
recap.engagement.is_replied:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_replied":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.005
evaluation_tasks:
- "recap.engagement.is_replied"
preprocess: *preprocess
recap.engagement.is_replied_reply_engaged_by_author:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_replied_reply_engaged_by_author":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.001
evaluation_tasks:
- "recap.engagement.is_replied_reply_engaged_by_author"
preprocess: *preprocess
recap.engagement.is_report_tweet_clicked:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_report_tweet_clicked":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.000014
evaluation_tasks:
- "recap.engagement.is_report_tweet_clicked"
preprocess: *preprocess
recap.engagement.is_retweeted:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_retweeted":
pos_downsampling_rate: 0.9561
neg_downsampling_rate: 0.004
evaluation_tasks:
- "recap.engagement.is_retweeted"
preprocess: *preprocess
recap.engagement.is_video_playback_50:
global_batch_size: *eval_batch_size
inputs: *data_root
seg_dense_schema: *seg_dense_schema
tasks:
<<: *data_tasks
"recap.engagement.is_video_playback_50":
pos_downsampling_rate: 1.0
neg_downsampling_rate: 0.00427
evaluation_tasks:
- "recap.engagement.is_video_playback_50"
preprocess: *preprocess
optimizer:
adam:
beta_1: 0.95
beta_2: 0.999
epsilon: 1.0e-07
multi_task_learning_rates:
backbone_learning_rate:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
tower_learning_rates:
recap.engagement.is_favorited:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0008
num_ramp_steps: 5000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_good_clicked_convo_desc_favorited_or_replied:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 2000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_good_clicked_convo_desc_v2:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0002
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_negative_feedback_v2:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0005
num_ramp_steps: 5000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_profile_clicked_and_profile_engaged:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0003
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_replied:
constant: null
linear_ramp_to_constant:
learning_rate: 0.001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_replied_reply_engaged_by_author:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_report_tweet_clicked:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 3000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_retweeted:
constant: null
linear_ramp_to_constant:
learning_rate: 0.0001
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
recap.engagement.is_video_playback_50:
constant: null
linear_ramp_to_constant:
learning_rate: 0.003
num_ramp_steps: 1000
linear_ramp_to_cosine: null
piecewise_constant: null
single_task_learning_rate: null

View file

View file

@ -0,0 +1,246 @@
import typing
from enum import Enum
from tml.core import config as base_config
import pydantic
class ExplicitDateInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_date and days of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_date: str = pydantic.Field(..., description="Data end date, inclusive.")
days: int = pydantic.Field(..., description="Number of days of data for dataset.")
num_missing_days_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_days_tol days of missing data."
)
class ExplicitDatetimeInputs(base_config.BaseConfig):
"""Arguments to select train/validation data using end_datetime and hours of data."""
data_root: str = pydantic.Field(..., description="Data path prefix.")
end_datetime: str = pydantic.Field(..., description="Data end datetime, inclusive.")
hours: int = pydantic.Field(..., description="Number of hours of data for dataset.")
num_missing_hours_tol: int = pydantic.Field(
0, description="We tolerate <= num_missing_hours_tol hours of missing data."
)
class DdsCompressionOption(str, Enum):
"""The only valid compression option is 'AUTO'"""
AUTO = "AUTO"
class DatasetConfig(base_config.BaseConfig):
inputs: str = pydantic.Field(
None, description="A glob for selecting data.", one_of="date_inputs_format"
)
explicit_datetime_inputs: ExplicitDatetimeInputs = pydantic.Field(
None, one_of="date_inputs_format"
)
explicit_date_inputs: ExplicitDateInputs = pydantic.Field(None, one_of="date_inputs_format")
global_batch_size: pydantic.PositiveInt
num_files_to_keep: pydantic.PositiveInt = pydantic.Field(
None, description="Number of shards to keep."
)
repeat_files: bool = pydantic.Field(
True, description="DEPRICATED. Files are repeated no matter what this is set to."
)
file_batch_size: pydantic.PositiveInt = pydantic.Field(16, description="File batch size")
cache: bool = pydantic.Field(
False,
description="Cache dataset in memory. Careful to only use this when you"
" have enough memory to fit entire dataset.",
)
data_service_dispatcher: str = pydantic.Field(None)
ignore_data_errors: bool = pydantic.Field(
False, description="Whether to ignore tf.data errors. DANGER DANGER, may wedge jobs."
)
dataset_service_compression: DdsCompressionOption = pydantic.Field(
None,
description="Compress the dataset for DDS worker -> training host. Disabled by default and the only valid option is 'AUTO'",
)
# tf.data.Dataset options
examples_shuffle_buffer_size: int = pydantic.Field(1024, description="Size of shuffle buffers.")
map_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of parallel calls."
)
interleave_num_parallel_calls: pydantic.PositiveInt = pydantic.Field(
None, description="Number of shards to interleave."
)
class TruncateAndSlice(base_config.BaseConfig):
# Apply truncation and then slice.
continuous_feature_truncation: pydantic.PositiveInt = pydantic.Field(
None, description="Experimental. Truncates continuous features to this amount for efficiency."
)
binary_feature_truncation: pydantic.PositiveInt = pydantic.Field(
None, description="Experimental. Truncates binary features to this amount for efficiency."
)
continuous_feature_mask_path: str = pydantic.Field(
None, description="Path of mask used to slice input continuous features."
)
binary_feature_mask_path: str = pydantic.Field(
None, description="Path of mask used to slice input binary features."
)
class DataType(str, Enum):
BFLOAT16 = "bfloat16"
BOOL = "bool"
FLOAT32 = "float32"
FLOAT16 = "float16"
UINT8 = "uint8"
class DownCast(base_config.BaseConfig):
# Apply down casting to selected features.
features: typing.Dict[str, DataType] = pydantic.Field(
None, description="Map features to down cast data types."
)
class TaskData(base_config.BaseConfig):
pos_downsampling_rate: float = pydantic.Field(
1.0,
description="Downsampling rate of positives used to generate dataset.",
)
neg_downsampling_rate: float = pydantic.Field(
1.0,
description="Downsampling rate of negatives used to generate dataset.",
)
class SegDenseSchema(base_config.BaseConfig):
schema_path: str = pydantic.Field(..., description="Path to feature config json.")
features: typing.List[str] = pydantic.Field(
[],
description="List of features (in addition to the renamed features) to read from schema path above.",
)
renamed_features: typing.Dict[str, str] = pydantic.Field(
{}, description="Dictionary of renamed features."
)
mask_mantissa_features: typing.Dict[str, int] = pydantic.Field(
{},
description="(experimental) Number of mantissa bits to mask to simulate lower precision data.",
)
class RectifyLabels(base_config.BaseConfig):
label_rectification_window_in_hours: float = pydantic.Field(
3.0, description="overlap time in hours for which to flip labels"
)
served_timestamp_field: str = pydantic.Field(
..., description="input field corresponding to served time"
)
impressed_timestamp_field: str = pydantic.Field(
..., description="input field corresponding to impressed time"
)
label_to_engaged_timestamp_field: typing.Dict[str, str] = pydantic.Field(
..., description="label to the input field corresponding to engagement time"
)
class ExtractFeaturesRow(base_config.BaseConfig):
name: str = pydantic.Field(
...,
description="name of the new field name to be created",
)
source_tensor: str = pydantic.Field(
...,
description="name of the dense tensor to look for the feature",
)
index: int = pydantic.Field(
...,
description="index of the feature in the dense tensor",
)
class ExtractFeatures(base_config.BaseConfig):
extract_feature_table: typing.List[ExtractFeaturesRow] = pydantic.Field(
[],
description="list of features to be extracted with their name, source tensor and index",
)
class DownsampleNegatives(base_config.BaseConfig):
batch_multiplier: int = pydantic.Field(
None,
description="batch multiplier",
)
engagements_list: typing.List[str] = pydantic.Field(
[],
description="engagements with kept positives",
)
num_engagements: int = pydantic.Field(
...,
description="number engagements used in the model, including ones excluded in engagements_list",
)
class Preprocess(base_config.BaseConfig):
truncate_and_slice: TruncateAndSlice = pydantic.Field(None, description="Truncation and slicing.")
downcast: DownCast = pydantic.Field(None, description="Down cast to features.")
rectify_labels: RectifyLabels = pydantic.Field(
None, description="Rectify labels for a given overlap window"
)
extract_features: ExtractFeatures = pydantic.Field(
None, description="Extract features from dense tensors."
)
downsample_negatives: DownsampleNegatives = pydantic.Field(
None, description="Downsample negatives."
)
class Sampler(base_config.BaseConfig):
"""Assumes function is defined in data/samplers.py.
Only use this for quick experimentation.
If samplers are useful, we should sample from upstream data generation.
DEPRICATED, DO NOT USE.
"""
name: str
kwargs: typing.Dict
class RecapDataConfig(DatasetConfig):
seg_dense_schema: SegDenseSchema
tasks: typing.Dict[str, TaskData] = pydantic.Field(
description="Description of individual tasks in this dataset."
)
evaluation_tasks: typing.List[str] = pydantic.Field(
[], description="If specified, lists the tasks we're generating metrics for."
)
preprocess: Preprocess = pydantic.Field(
None, description="Function run in tf.data.Dataset at train/eval, in-graph at inference."
)
sampler: Sampler = pydantic.Field(
None,
description="""DEPRICATED, DO NOT USE. Sampling function for offline experiments.""",
)
@pydantic.root_validator()
def _validate_evaluation_tasks(cls, values):
if values.get("evaluation_tasks") is not None:
for task in values["evaluation_tasks"]:
if task not in values["tasks"]:
raise KeyError(f"Evaluation task {task} must be in tasks. Received {values['tasks']}")
return values

View file

@ -0,0 +1,477 @@
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Dict
import functools
import torch
import tensorflow as tf
from tml.common.batch import DataclassBatch
from tml.projects.home.recap.data.config import RecapDataConfig, TaskData
from tml.projects.home.recap.data import preprocessors
from tml.projects.home.recap.config import JobMode
from tml.projects.home.recap.data.tfe_parsing import get_seg_dense_parse_fn
from tml.projects.home.recap.data.util import (
keyed_jagged_tensor_from_tensors_dict,
sparse_or_dense_tf_to_torch,
)
from absl import logging
import torch.distributed as dist
@dataclass
class RecapBatch(DataclassBatch):
"""Holds features and labels from the Recap dataset."""
continuous_features: torch.Tensor
binary_features: torch.Tensor
discrete_features: torch.Tensor
sparse_features: "KeyedJaggedTensor" # type: ignore[name-defined] # noqa: F821
labels: torch.Tensor
user_embedding: torch.Tensor = None
user_eng_embedding: torch.Tensor = None
author_embedding: torch.Tensor = None
weights: torch.Tensor = None
def __post_init__(self):
if self.weights is None:
self.weights = torch.ones_like(self.labels)
for feature_name, feature_value in self.as_dict().items():
if ("embedding" in feature_name) and (feature_value is None):
setattr(self, feature_name, torch.empty([0, 0]))
def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> RecapBatch:
"""Converts a torch data loader output into `RecapBatch`."""
x = tf.nest.map_structure(functools.partial(sparse_or_dense_tf_to_torch, pin_memory=False), x)
try:
features_in, labels = x
except ValueError:
# For Mode.INFERENCE, we do not expect to recieve labels as part of the input tuple
features_in, labels = x, None
sparse_features = keyed_jagged_tensor_from_tensors_dict({})
if sparse_feature_names:
sparse_features = keyed_jagged_tensor_from_tensors_dict(
{embedding_name: features_in[embedding_name] for embedding_name in sparse_feature_names}
)
user_embedding, user_eng_embedding, author_embedding = None, None, None
if "user_embedding" in features_in:
if sparse_feature_names and "meta__user_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
user_embedding = features_in["user_embedding"]
if "user_eng_embedding" in features_in:
if sparse_feature_names and "meta__user_eng_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
user_eng_embedding = features_in["user_eng_embedding"]
if "author_embedding" in features_in:
if sparse_feature_names and "meta__author_id" in sparse_feature_names:
raise ValueError("Only one source of embedding for user is supported")
else:
author_embedding = features_in["author_embedding"]
return RecapBatch(
continuous_features=features_in["continuous"],
binary_features=features_in["binary"],
discrete_features=features_in["discrete"],
sparse_features=sparse_features,
user_embedding=user_embedding,
user_eng_embedding=user_eng_embedding,
author_embedding=author_embedding,
labels=labels,
weights=features_in.get("weights", None), # Defaults to torch.ones_like(labels)
)
def _chain(param, f1, f2):
"""
Reduce multiple functions into one chained function
_chain(x, f1, f2) -> f2(f1(x))
"""
output = param
fns = [f1, f2]
for f in fns:
output = f(output)
return output
def _add_weights(inputs, tasks: Dict[str, TaskData]):
"""Adds weights based on label sampling for positive and negatives.
This is useful for numeric calibration etc. This mutates inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
return inputs
def get_datetimes(explicit_datetime_inputs):
"""Compute list datetime strings for train/validation data."""
datetime_format = "%Y/%m/%d/%H"
end = datetime.strptime(explicit_datetime_inputs.end_datetime, datetime_format)
dates = sorted(
[
(end - timedelta(hours=i + 1)).strftime(datetime_format)
for i in range(int(explicit_datetime_inputs.hours))
]
)
return dates
def get_explicit_datetime_inputs_files(explicit_datetime_inputs):
"""
Compile list of files for training/validation.
Used with DataConfigs that use the `explicit_datetime_inputs` format to specify data.
For each hour of data, if the directory is missing or empty, we increment a counter to keep
track of the number of missing data hours.
Returns only files with a `.gz` extension.
Args:
explicit_datetime_inputs: An `ExplicitDatetimeInputs` object within a `datasets.DataConfig` object
Returns:
data_files: Sorted list of files to read corresponding to data at the desired datetimes
num_hours_missing: Number of hours that we are missing data
"""
datetimes = get_datetimes(explicit_datetime_inputs)
folders = [os.path.join(explicit_datetime_inputs.data_root, datetime) for datetime in datetimes]
data_files = []
num_hours_missing = 0
for folder in folders:
try:
files = tf.io.gfile.listdir(folder)
if not files:
logging.warning(f"{folder} contained no data files")
num_hours_missing += 1
data_files.extend(
[
os.path.join(folder, filename)
for filename in files
if filename.rsplit(".", 1)[-1].lower() == "gz"
]
)
except tf.errors.NotFoundError as e:
num_hours_missing += 1
logging.warning(f"Cannot find directory {folder}. Missing one hour of data. Error: \n {e}")
return sorted(data_files), num_hours_missing
def _map_output_for_inference(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
if preprocessor:
raise ValueError("No preprocessor should be used at inference time.")
if add_weights:
raise NotImplementedError()
# Add zero weights.
inputs["weights"] = tf.zeros_like(tf.expand_dims(inputs["continuous"][:, 0], -1))
for label in tasks:
del inputs[label]
return inputs
def _map_output_for_train_eval(
inputs, tasks: Dict[str, TaskData], preprocessor: tf.keras.Model = None, add_weights: bool = False
):
if add_weights:
inputs = _add_weights_based_on_sampling_rates(inputs, tasks)
# Warning this has to happen first as it changes the input
if preprocessor:
inputs = preprocessor(inputs)
label_values = tf.squeeze(tf.stack([inputs[label] for label in tasks], axis=1), axis=[-1])
for label in tasks:
del inputs[label]
return inputs, label_values
def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskData]):
"""Adds weights based on label sampling for positive and negatives.
This is useful for numeric calibration etc. This mutates inputs.
Args:
inputs: A dictionary of strings to tensor-like structures.
tasks: A dict of string (label) to `TaskData` specifying inputs.
Returns:
A tuple of features and labels; weights are added to features.
"""
weights = []
for key, task in tasks.items():
label = inputs[key]
float_label = tf.cast(label, tf.float32)
weights.append(
float_label / task.pos_downsampling_rate + (1.0 - float_label) / task.neg_downsampling_rate
)
# Ensure we are batch-major (assumes we batch before this call).
inputs["weights"] = tf.squeeze(tf.transpose(tf.convert_to_tensor(weights)), axis=0)
return inputs
class RecapDataset(torch.utils.data.IterableDataset):
def __init__(
self,
data_config: RecapDataConfig,
dataset_service: Optional[str] = None,
mode: JobMode = JobMode.TRAIN,
compression: Optional[str] = "AUTO",
repeat: bool = False,
vocab_mapper: tf.keras.Model = None,
):
logging.info("***** Labels *****")
logging.info(list(data_config.tasks.keys()))
self._data_config = data_config
self._parse_fn = get_seg_dense_parse_fn(data_config)
self._mode = mode
self._repeat = repeat
self._num_concurrent_iterators = 1
self._vocab_mapper = vocab_mapper
self.dataset_service = dataset_service
preprocessor = None
self._batch_size_multiplier = 1
if data_config.preprocess:
preprocessor = preprocessors.build_preprocess(data_config.preprocess, mode=mode)
if data_config.preprocess.downsample_negatives:
self._batch_size_multiplier = data_config.preprocess.downsample_negatives.batch_multiplier
self._preprocessor = preprocessor
if mode == JobMode.INFERENCE:
if preprocessor is not None:
raise ValueError("Expect no preprocessor at inference time.")
should_add_weights = False
output_map_fn = _map_output_for_inference # (features,)
else:
# Only add weights if there is a reason to! If all weights will
# be equal to 1.0, save bandwidth between DDS and Chief by simply
# relying on the fact that weights default to 1.0 in `RecapBatch`
# WARNING: Weights may still be added as a side effect of a preprocessor
# such as `DownsampleNegatives`.
should_add_weights = any(
[
task_cfg.pos_downsampling_rate != 1.0 or task_cfg.neg_downsampling_rate != 1.0
for task_cfg in data_config.tasks.values()
]
)
output_map_fn = _map_output_for_train_eval # (features, labels)
self._output_map_fn = functools.partial(
output_map_fn,
tasks=data_config.tasks,
preprocessor=preprocessor,
add_weights=should_add_weights,
)
sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
self._tf_dataset = self._create_tf_dataset()
self._init_tensor_spec()
def _init_tensor_spec(self):
def _tensor_spec_to_torch_shape(spec):
if spec.shape is None:
return None
shape = [x if x is not None else -1 for x in spec.shape]
return torch.Size(shape)
self.torch_element_spec = tf.nest.map_structure(
_tensor_spec_to_torch_shape, self._tf_dataset.element_spec
)
def _create_tf_dataset(self):
if hasattr(self, "_tf_dataset"):
raise ValueError("Do not call `_create_tf_dataset` more than once.")
world_size = dist.get_world_size() if dist.is_initialized() else 1
per_replica_bsz = (
self._batch_size_multiplier * self._data_config.global_batch_size // world_size
)
dataset: tf.data.Dataset = self._create_base_tf_dataset(
batch_size=per_replica_bsz,
)
if self._repeat:
logging.info("Repeating dataset")
dataset = dataset.repeat()
if self.dataset_service:
if self._num_concurrent_iterators > 1:
if not self.machines_config:
raise ValueError(
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
)
dataset = dataset_lib.with_auto_tune_budget(
dataset,
machine_config=self.machines_config.chief,
num_concurrent_iterators=self.num_concurrent_iterators,
on_chief=False,
)
self.dataset_id, self.job_name = register_dataset(
dataset=dataset, dataset_service=self.dataset_service, compression=self.compression
)
dataset = distribute_from_dataset_id(
dataset_id=self.dataset_id, # type: ignore[arg-type]
job_name=self.job_name,
dataset_service=self.dataset_service,
compression=self.compression,
)
elif self._num_concurrent_iterators > 1:
if not self.machines_config:
raise ValueError(
"Must supply a machine_config for autotuning in order to use >1 concurrent iterators"
)
dataset = dataset_lib.with_auto_tune_budget(
dataset,
machine_config=self.machines_config.chief,
num_concurrent_iterators=self._num_concurrent_iterators,
on_chief=True,
)
# Vocabulary mapping happens on the training node, not in dds because of size.
if self._vocab_mapper:
dataset = dataset.map(self._vocab_mapper)
return dataset.prefetch(world_size * 2)
def _create_base_tf_dataset(self, batch_size: int):
if self._data_config.inputs:
glob = self._data_config.inputs
filenames = sorted(tf.io.gfile.glob(glob))
elif self._data_config.explicit_datetime_inputs:
num_missing_hours_tol = self._data_config.explicit_datetime_inputs.num_missing_hours_tol
filenames, num_hours_missing = get_explicit_datetime_inputs_files(
self._data_config.explicit_datetime_inputs,
increment="hourly",
)
if num_hours_missing > num_missing_hours_tol:
raise ValueError(
f"We are missing {num_hours_missing} hours of data"
f"more than tolerance {num_missing_hours_tol}."
)
elif self._data_config.explicit_date_inputs:
num_missing_days_tol = self._data_config.explicit_date_inputs.num_missing_days_tol
filenames, num_days_missing = get_explicit_datetime_inputs_files(
self._data_config.explicit_date_inputs,
increment="daily",
)
if num_days_missing > num_missing_days_tol:
raise ValueError(
f"We are missing {num_days_missing} days of data"
f"more than tolerance {num_missing_days_tol}."
)
else:
raise ValueError(
"Must specifiy either `inputs`, `explicit_datetime_inputs`, or `explicit_date_inputs` in data_config"
)
num_files = len(filenames)
logging.info(f"Found {num_files} data files")
if num_files < 1:
raise ValueError("No data files found")
if self._data_config.num_files_to_keep is not None:
filenames = filenames[: self._data_config.num_files_to_keep]
logging.info(f"Retaining only {len(filenames)} files.")
filenames_ds = (
tf.data.Dataset.from_tensor_slices(filenames).shuffle(len(filenames))
# Because of drop_remainder, if our dataset does not fill
# up a batch, it will emit nothing without this repeat.
.repeat(-1)
)
if self._data_config.file_batch_size:
filenames_ds = filenames_ds.batch(self._data_config.file_batch_size)
def per_shard_dataset(filename):
ds = tf.data.TFRecordDataset([filename], compression_type="GZIP")
return ds.prefetch(4)
ds = filenames_ds.interleave(
per_shard_dataset,
block_length=4,
deterministic=False,
num_parallel_calls=self._data_config.interleave_num_parallel_calls
or tf.data.experimental.AUTOTUNE,
)
# Combine functions into one map call to reduce overhead.
map_fn = functools.partial(
_chain,
f1=self._parse_fn,
f2=self._output_map_fn,
)
# Shuffle -> Batch -> Parse is the correct ordering
# Shuffling needs to be performed before batching otherwise there is not much point
# Batching happens before parsing because tf.Example parsing is actually vectorized
# and works much faster overall on batches of data.
ds = (
# DANGER DANGER: there is a default shuffle size here.
ds.shuffle(self._data_config.examples_shuffle_buffer_size)
.batch(batch_size=batch_size, drop_remainder=True)
.map(
map_fn,
num_parallel_calls=self._data_config.map_num_parallel_calls
or tf.data.experimental.AUTOTUNE,
)
)
if self._data_config.cache:
ds = ds.cache()
if self._data_config.ignore_data_errors:
ds = ds.apply(tf.data.experimental.ignore_errors())
options = tf.data.Options()
options.experimental_deterministic = False
ds = ds.with_options(options)
return ds
def _gen(self):
for x in self._tf_dataset:
yield to_batch(x)
def to_dataloader(self) -> Dict[str, torch.Tensor]:
return torch.utils.data.DataLoader(self, batch_size=None)
def __iter__(self):
return iter(self._gen())

View file

@ -0,0 +1,81 @@
import os
import json
from absl import app, flags, logging
import tensorflow as tf
from typing import Dict
from tml.projects.home.recap.data import tfe_parsing
from tml.core import config as tml_config_mod
import tml.projects.home.recap.config as recap_config_mod
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.")
FLAGS = flags.FLAGS
def _generate_random_example(
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
example = {}
for feature_name, feature_spec in tf_example_schema.items():
dtype = feature_spec.dtype
if (dtype == tf.int64) or (dtype == tf.int32):
x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype)
elif (dtype == tf.float32) or (dtype == tf.float64):
x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype)
else:
raise NotImplementedError(f"Unknown type {dtype}")
example[feature_name] = x
return example
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
feature = {}
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
for feature_name, tensor in x.items():
feature[feature_name] = serializers[tensor.dtype](tensor)
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
tf_example_schema = tfe_parsing.create_tf_example_schema(
config.train_data,
seg_dense_schema,
)
record_filename = os.path.join(data_path, "random.tfrecord.gz")
with tf.io.TFRecordWriter(record_filename, "GZIP") as writer:
random_example = _generate_random_example(tf_example_schema)
serialized_example = _serialize_example(random_example)
writer.write(serialized_example)
def _generate_data_main(unused_argv):
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
# Find the path where to put the data
data_path = os.path.dirname(config.train_data.inputs)
logging.info("Putting random data in %s", data_path)
generate_data(data_path, config)
if __name__ == "__main__":
app.run(_generate_data_main)

View file

@ -0,0 +1,206 @@
"""
Preprocessors applied on DDS workers in order to modify the dataset on the fly.
Some of these preprocessors are also applied to the model at serving time.
"""
from tml.projects.home.recap import config as config_mod
from absl import logging
import tensorflow as tf
import numpy as np
class TruncateAndSlice(tf.keras.Model):
"""Class for truncating and slicing."""
def __init__(self, truncate_and_slice_config):
super().__init__()
self._truncate_and_slice_config = truncate_and_slice_config
if self._truncate_and_slice_config.continuous_feature_mask_path:
with tf.io.gfile.GFile(
self._truncate_and_slice_config.continuous_feature_mask_path, "rb"
) as f:
self._continuous_mask = np.load(f).nonzero()[0]
logging.info(f"Slicing {np.sum(self._continuous_mask)} continuous features.")
else:
self._continuous_mask = None
if self._truncate_and_slice_config.binary_feature_mask_path:
with tf.io.gfile.GFile(self._truncate_and_slice_config.binary_feature_mask_path, "rb") as f:
self._binary_mask = np.load(f).nonzero()[0]
logging.info(f"Slicing {np.sum(self._binary_mask)} binary features.")
else:
self._binary_mask = None
def call(self, inputs, training=None, mask=None):
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
if self._truncate_and_slice_config.continuous_feature_truncation:
logging.info("Truncating continuous")
outputs["continuous"] = outputs["continuous"][
:, : self._truncate_and_slice_config.continuous_feature_truncation
]
if self._truncate_and_slice_config.binary_feature_truncation:
logging.info("Truncating binary")
outputs["binary"] = outputs["binary"][
:, : self._truncate_and_slice_config.binary_feature_truncation
]
if self._continuous_mask is not None:
outputs["continuous"] = tf.gather(outputs["continuous"], self._continuous_mask, axis=1)
if self._binary_mask is not None:
outputs["binary"] = tf.gather(outputs["binary"], self._binary_mask, axis=1)
return outputs
class DownCast(tf.keras.Model):
"""Class for Down casting dataset before serialization and transferring to training host.
Depends on the data type and the actual data range, the down casting can be lossless or not.
It is strongly recommended to compare the metrics before and after down casting.
"""
def __init__(self, downcast_config):
super().__init__()
self.config = downcast_config
self._type_map = {
"bfloat16": tf.bfloat16,
"bool": tf.bool,
}
def call(self, inputs, training=None, mask=None):
outputs = tf.nest.pack_sequence_as(inputs, tf.nest.flatten(inputs))
for feature, type_str in self.config.features.items():
assert type_str in self._type_map
if type_str == "bfloat16":
logging.warning(
"Although bfloat16 and float32 have the same number of exponent bits, this down casting is not 100% lossless. Please double check metrics."
)
down_cast_data_type = self._type_map[type_str]
outputs[feature] = tf.cast(outputs[feature], dtype=down_cast_data_type)
return outputs
class RectifyLabels(tf.keras.Model):
"""Class for rectifying labels"""
def __init__(self, rectify_label_config):
super().__init__()
self._config = rectify_label_config
self._window = int(self._config.label_rectification_window_in_hours * 60 * 60 * 1000)
def call(self, inputs, training=None, mask=None):
served_ts_field = self._config.served_timestamp_field
impressed_ts_field = self._config.impressed_timestamp_field
for label, engaged_ts_field in self._config.label_to_engaged_timestamp_field.items():
impressed = inputs[impressed_ts_field]
served = inputs[served_ts_field]
engaged = inputs[engaged_ts_field]
keep = tf.math.logical_and(inputs[label] > 0, impressed - served < self._window)
keep = tf.math.logical_and(keep, engaged - served < self._window)
inputs[label] = tf.where(keep, inputs[label], tf.zeros_like(inputs[label]))
return inputs
class ExtractFeatures(tf.keras.Model):
"""Class for extracting individual features from dense tensors by their index."""
def __init__(self, extract_features_config):
super().__init__()
self._config = extract_features_config
def call(self, inputs, training=None, mask=None):
for row in self._config.extract_feature_table:
inputs[row.name] = inputs[row.source_tensor][:, row.index]
return inputs
class DownsampleNegatives(tf.keras.Model):
"""Class for down-sampling/dropping negatives and updating the weights.
If inputs['fav'] = [1, 0, 0, 0] and inputs['weights'] = [1.0, 1.0, 1.0, 1.0]
inputs are transformed to inputs['fav'] = [1, 0] and inputs['weights'] = [1.0, 3.0]
when batch_multiplier=2 and engagements_list=['fav']
It supports multiple engagements (union/logical_or is used to aggregate engagements), so we don't
drop positives for any engagement.
"""
def __init__(self, downsample_negatives_config):
super().__init__()
self.config = downsample_negatives_config
def call(self, inputs, training=None, mask=None):
labels = self.config.engagements_list
# union of engagements
mask = tf.squeeze(tf.reduce_any(tf.stack([inputs[label] == 1 for label in labels], 1), 1))
n_positives = tf.reduce_sum(tf.cast(mask, tf.int32))
batch_size = tf.cast(tf.shape(inputs[labels[0]])[0] / self.config.batch_multiplier, tf.int32)
negative_weights = tf.math.divide_no_nan(
tf.cast(self.config.batch_multiplier * batch_size - n_positives, tf.float32),
tf.cast(batch_size - n_positives, tf.float32),
)
new_weights = tf.cast(mask, tf.float32) + (1 - tf.cast(mask, tf.float32)) * negative_weights
def _split_by_label_concatenate_and_truncate(input_tensor):
# takes positive examples and concatenate with negative examples and truncate
# DANGER: if n_positives > batch_size down-sampling is incorrect (do not use pb_50)
return tf.concat(
[
input_tensor[mask],
input_tensor[tf.math.logical_not(mask)],
],
0,
)[:batch_size]
if "weights" not in inputs:
# add placeholder so logic below applies even if weights aren't present in inputs
inputs["weights"] = tf.ones([tf.shape(inputs[labels[0]])[0], self.config.num_engagements])
for tensor in inputs:
if tensor == "weights":
inputs[tensor] = inputs[tensor] * tf.reshape(new_weights, [-1, 1])
inputs[tensor] = _split_by_label_concatenate_and_truncate(inputs[tensor])
return inputs
def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN):
"""Builds a preprocess model to apply all preprocessing stages."""
if mode == config_mod.JobMode.INFERENCE:
logging.info("Not building preprocessors for dataloading since we are in Inference mode.")
return None
preprocess_models = []
if preprocess_config.downsample_negatives:
preprocess_models.append(DownsampleNegatives(preprocess_config.downsample_negatives))
if preprocess_config.truncate_and_slice:
preprocess_models.append(TruncateAndSlice(preprocess_config.truncate_and_slice))
if preprocess_config.downcast:
preprocess_models.append(DownCast(preprocess_config.downcast))
if preprocess_config.rectify_labels:
preprocess_models.append(RectifyLabels(preprocess_config.rectify_labels))
if preprocess_config.extract_features:
preprocess_models.append(ExtractFeatures(preprocess_config.extract_features))
if len(preprocess_models) == 0:
raise ValueError("No known preprocessor.")
class PreprocessModel(tf.keras.Model):
def __init__(self, preprocess_models):
super().__init__()
self.preprocess_models = preprocess_models
def call(self, inputs, training=None, mask=None):
outputs = inputs
for model in self.preprocess_models:
outputs = model(outputs, training, mask)
return outputs
if len(preprocess_models) > 1:
logging.warning(
"With multiple preprocessing models, we apply these models in a predefined order. Future works may introduce customized models and orders."
)
return PreprocessModel(preprocess_models)

View file

@ -0,0 +1,129 @@
import functools
import json
from tml.projects.home.recap.data import config as recap_data_config
from absl import logging
import tensorflow as tf
DEFAULTS_MAP = {"int64_list": 0, "float_list": 0.0, "bytes_list": ""}
DTYPE_MAP = {"int64_list": tf.int64, "float_list": tf.float32, "bytes_list": tf.string}
def create_tf_example_schema(
data_config: recap_data_config.SegDenseSchema,
segdense_schema,
):
"""Generate schema for deseralizing tf.Example.
Args:
segdense_schema: List of dicts of segdense features (includes feature_name, dtype, length).
labels: List of strings denoting labels.
Returns:
A dictionary schema suitable for deserializing tf.Example.
"""
segdense_config = data_config.seg_dense_schema
labels = list(data_config.tasks.keys())
used_features = (
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
)
logging.info(used_features)
tfe_schema = {}
for entry in segdense_schema:
feature_name = entry["feature_name"]
if feature_name in used_features:
length = entry["length"]
dtype = entry["dtype"]
if feature_name in labels:
logging.info(f"Label: feature name is {feature_name} type is {dtype}")
tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
)
elif length == -1:
tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
else:
tfe_schema[feature_name] = tf.io.FixedLenFeature(
length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
)
for feature_name in used_features:
if feature_name not in tfe_schema:
raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
return tfe_schema
@functools.lru_cache(1)
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
"""For experimentating with emulating bfloat16 or less precise types."""
return tf.constant((1 << 32) - (1 << mask_length), dtype=tf.int32)
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
"""For experimentating with emulating bfloat16 or less precise types."""
mask: tf.Tensor = make_mantissa_mask(mask_length)
return tf.bitcast(tf.bitwise.bitwise_and(tf.bitcast(tensor, tf.int32), mask), tensor.dtype)
def parse_tf_example(
serialized_example,
tfe_schema,
seg_dense_schema_config,
):
"""Parse serialized tf.Example into dict of tensors.
Args:
serialized_example: Serialized tf.Example to be parsed.
tfe_schema: Dictionary schema suitable for deserializing tf.Example.
Returns:
Dictionary of tensors to be used as model input.
"""
inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)
for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
inputs[new_feature_name] = inputs.pop(old_feature_name)
# This should not actually be used except for experimentation with low precision floats.
if "mask_mantissa_features" in seg_dense_schema_config:
for feature_name, mask_length in seg_dense_schema_config.mask_mantissa_features.items():
inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)
# DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
# at TF level.
# We should not return empty tensors if we dont use embeddings.
# Otherwise, it breaks numpy->pt conversion
renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
for renamed_key in renamed_keys:
if "embedding" in renamed_key and (renamed_key not in inputs):
inputs[renamed_key] = tf.zeros([], tf.float32)
logging.info(f"parsed example and inputs are {inputs}")
return inputs
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
"""Placeholder for seg dense.
In the future, when we use more seg dense variations, we can change this.
"""
with tf.io.gfile.GFile(data_config.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
tf_example_schema = create_tf_example_schema(
data_config,
seg_dense_schema,
)
logging.info("***** TF Example Schema *****")
logging.info(tf_example_schema)
parse = functools.partial(
parse_tf_example,
tfe_schema=tf_example_schema,
seg_dense_schema_config=data_config.seg_dense_schema,
)
return parse

View file

@ -0,0 +1,120 @@
from typing import Mapping, Tuple, Union
import torch
import torchrec
import numpy as np
import tensorflow as tf
def keyed_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
"""
Convert a dictionary of torch tensor to torchrec keyed tensor
Args:
tensor_map:
Returns:
"""
keys = list(tensor_map.keys())
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
# [Batch_size x 1].
values = [
tensor_map[key] if len(tensor_map[key].shape) > 1 else torch.unsqueeze(tensor_map[key], -1)
for key in keys
]
return torchrec.KeyedTensor.from_tensor_list(keys, values)
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if tensor.is_sparse:
x = tensor.coalesce() # Ensure that the indices are ordered.
lengths = torch.bincount(x.indices()[0])
values = x.values()
else:
values = tensor
lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
return values, lengths
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
"""
Convert a torch tensor to torchrec jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x N] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x N]; the
dense_shape of the sparse tensor can be arbitrary.
Args:
tensor: a torch (sparse) tensor.
Returns:
"""
values, lengths = _compute_jagged_tensor_from_tensor(tensor)
return torchrec.JaggedTensor(values=values, lengths=lengths)
def keyed_jagged_tensor_from_tensors_dict(
tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
"""
Convert a dictionary of (sparse) torch tensors to torchrec keyed jagged tensor.
Note: Currently only support shape of [Batch_size] or [Batch_size x 1] for dense tensors.
For sparse tensor the shape of .values() should be [Batch_size] or [Batch_size x 1]; the
dense_shape of the sparse tensor can be arbitrary.
Args:
tensor_map:
Returns:
"""
if not tensor_map:
return torchrec.KeyedJaggedTensor(
keys=[],
values=torch.zeros(0, dtype=torch.int),
lengths=torch.zeros(0, dtype=torch.int),
)
values = []
lengths = []
for tensor in tensor_map.values():
tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
values.append(torch.squeeze(tensor_val))
lengths.append(tensor_len)
values = torch.cat(values, axis=0)
lengths = torch.cat(lengths, axis=0)
return torchrec.KeyedJaggedTensor(
keys=list(tensor_map.keys()),
values=values,
lengths=lengths,
)
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
return tf_tensor._numpy() # noqa
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
tensor = _tf_to_numpy(tensor)
# Pytorch does not support bfloat16, up cast to float32 to keep the same number of bits on exponent
if tensor.dtype.name == "bfloat16":
tensor = tensor.astype(np.float32)
tensor = torch.from_numpy(tensor)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
def sparse_or_dense_tf_to_torch(
tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
if isinstance(tensor, tf.SparseTensor):
tensor = torch.sparse_coo_tensor(
_dense_tf_to_torch(tensor.indices, pin_memory).t(),
_dense_tf_to_torch(tensor.values, pin_memory),
torch.Size(_tf_to_numpy(tensor.dense_shape)),
)
else:
tensor = _dense_tf_to_torch(tensor, pin_memory)
return tensor

View file

@ -0,0 +1,94 @@
from typing import List, Optional
import tml.core.config as base_config
from tml.optimizers import config as optimizer_config
import pydantic
class EmbeddingSnapshot(base_config.BaseConfig):
"""Configuration for Embedding snapshot"""
emb_name: str = pydantic.Field(
..., description="Name of the embedding table from the loaded snapshot"
)
embedding_snapshot_uri: str = pydantic.Field(
..., description="Path to torchsnapshot of the embedding"
)
# https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_configs.EmbeddingBagConfig
class EmbeddingBagConfig(base_config.BaseConfig):
"""Configuration for EmbeddingBag."""
name: str = pydantic.Field(..., description="name of embedding bag")
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
vocab: str = pydantic.Field(
None, description="Directory to parquet files of mapping from entity ID to table index."
)
class EmbeddingOptimizerConfig(base_config.BaseConfig):
learning_rate: optimizer_config.LearningRate = pydantic.Field(
None, description="learning rate scheduler for the EBC"
)
init_learning_rate: float = pydantic.Field(description="initial learning rate for the EBC")
# NB: Only sgd is supported right now and implicitly.
# FBGemm only supports simple exact_sgd which only takes LR as an argument.
class LargeEmbeddingsConfig(base_config.BaseConfig):
"""Configuration for EmbeddingBagCollection.
The tables listed in this config are gathered into a single torchrec EmbeddingBagCollection.
"""
tables: List[EmbeddingBagConfig] = pydantic.Field(..., description="list of embedding tables")
optimizer: EmbeddingOptimizerConfig
tables_to_log: List[str] = pydantic.Field(
None, description="list of embedding table names that we want to log during training"
)
class StratifierConfig(base_config.BaseConfig):
name: str
index: int
value: int
class SmallEmbeddingBagConfig(base_config.BaseConfig):
"""Configuration for SmallEmbeddingBag."""
name: str = pydantic.Field(..., description="name of embedding bag")
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
index: int = pydantic.Field(..., description="index in the discrete tensor to look for")
class SmallEmbeddingBagConfig(base_config.BaseConfig):
"""Configuration for SmallEmbeddingBag."""
name: str = pydantic.Field(..., description="name of embedding bag")
num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
index: int = pydantic.Field(..., description="index in the discrete tensor to look for")
class SmallEmbeddingsConfig(base_config.BaseConfig):
"""Configuration for SmallEmbeddingConfig.
Here we can use discrete features that already are present in our TFRecords generated using
segdense conversion as "home_recap_2022_discrete__segdense_vals" which are available in
the model as "discrete_features", and embed a user-defined set of them with configurable
dimensions and vocabulary sizes.
Compared with LargeEmbedding, this config is for small embedding tables that can fit inside
the model, whereas LargeEmbedding usually is meant to be hydrated outside the model at
serving time due to size (>>1 GB).
This small embeddings table uses the same optimizer as the rest of the model."""
tables: List[SmallEmbeddingBagConfig] = pydantic.Field(
..., description="list of embedding tables"
)

108
projects/home/recap/main.py Normal file
View file

@ -0,0 +1,108 @@
import datetime
import os
from typing import Callable, List, Optional, Tuple
import tensorflow as tf
import tml.common.checkpointing.snapshot as snapshot_lib
from tml.common.device import setup_and_get_device
from tml.core import config as tml_config_mod
import tml.core.custom_training_loop as ctl
from tml.core import debug_training_loop
from tml.core import losses
from tml.core.loss_type import LossType
from tml.model import maybe_shard_model
import tml.projects.home.recap.data.dataset as ds
import tml.projects.home.recap.config as recap_config_mod
import tml.projects.home.recap.optimizer as optimizer_mod
# from tml.projects.home.recap import feature
import tml.projects.home.recap.model as model_mod
import torchmetrics as tm
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
from absl import app, flags, logging
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_bool("debug_loop", False, "Run with debug loop (slow)")
FLAGS = flags.FLAGS
def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):
print("#" * 100)
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
logging.info("Config: %s", config.pretty_print())
device = setup_and_get_device()
# Always enable tensorfloat on supported devices.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
loss_fn = losses.build_multi_task_loss(
loss_type=LossType.BCE_WITH_LOGITS,
tasks=list(config.model.tasks.keys()),
pos_weights=[task.pos_weight for task in config.model.tasks.values()],
)
# Since the prod model doesn't use large embeddings, for now we won't support them.
assert config.model.large_embeddings is None
train_dataset = ds.RecapDataset(
data_config=config.train_data,
dataset_service=data_service_dispatcher,
mode=recap_config_mod.JobMode.TRAIN,
compression=config.train_data.dataset_service_compression,
vocab_mapper=None,
repeat=True,
)
train_iterator = iter(train_dataset.to_dataloader())
torch_element_spec = train_dataset.torch_element_spec
model = model_mod.create_ranking_model(
data_spec=torch_element_spec[0],
config=config,
loss_fn=loss_fn,
device=device,
)
optimizer, scheduler = optimizer_mod.build_optimizer(model, config.optimizer, None)
model = maybe_shard_model(model, device)
datetime_str = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
print(f"{datetime_str}\n", end="")
if FLAGS.debug_loop:
logging.warning("Running debug mode, slow!")
train_mod = debug_training_loop
else:
train_mod = ctl
train_mod.train(
model=model,
optimizer=optimizer,
device=device,
save_dir=config.training.save_dir,
logging_interval=config.training.train_log_every_n,
train_steps=config.training.num_train_steps,
checkpoint_frequency=config.training.checkpoint_every_n,
dataset=train_iterator,
worker_batch_size=config.train_data.global_batch_size,
enable_amp=False,
initial_checkpoint_dir=config.training.initial_checkpoint_dir,
gradient_accumulation=config.training.gradient_accumulation,
scheduler=scheduler,
)
if __name__ == "__main__":
app.run(run)

View file

@ -0,0 +1,7 @@
from tml.projects.home.recap.model.entrypoint import (
create_ranking_model,
sanitize,
unsanitize,
MultiTaskRankingModel,
)
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss

View file

@ -0,0 +1,260 @@
"""Configuration for the main Recap model."""
import enum
from typing import List, Optional, Dict
import tml.core.config as base_config
from tml.projects.home.recap.embedding import config as embedding_config
import pydantic
class DropoutConfig(base_config.BaseConfig):
"""Configuration for the dropout layer."""
rate: pydantic.PositiveFloat = pydantic.Field(
0.1, description="Fraction of inputs to be dropped."
)
class LayerNormConfig(base_config.BaseConfig):
"""Configruation for the layer normalization."""
epsilon: float = pydantic.Field(
1e-3, description="Small float added to variance to avoid dividing by zero."
)
axis: int = pydantic.Field(-1, description="Axis or axes to normalize across.")
center: bool = pydantic.Field(True, description="Whether to add learnable center.")
scale: bool = pydantic.Field(True, description="Whether to add learnable scale.")
class BatchNormConfig(base_config.BaseConfig):
"""Configuration of the batch normalization layer."""
epsilon: pydantic.PositiveFloat = 1e-5
momentum: pydantic.PositiveFloat = 0.9
training_mode_at_inference_time: bool = False
use_renorm: bool = False
center: bool = pydantic.Field(True, description="Whether to add learnable center.")
scale: bool = pydantic.Field(True, description="Whether to add learnable scale.")
class DenseLayerConfig(base_config.BaseConfig):
layer_size: pydantic.PositiveInt
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout config for layer.")
class MlpConfig(base_config.BaseConfig):
"""Configuration for MLP model."""
layer_sizes: List[pydantic.PositiveInt] = pydantic.Field(None, one_of="mlp_layer_definition")
layers: List[DenseLayerConfig] = pydantic.Field(None, one_of="mlp_layer_definition")
class BatchNormConfig(base_config.BaseConfig):
"""Configuration for the batch norm layer."""
affine: bool = pydantic.Field(True, description="Use affine transformation.")
momentum: pydantic.PositiveFloat = pydantic.Field(
0.1, description="Forgetting parameter in moving average."
)
class DoubleNormLogConfig(base_config.BaseConfig):
batch_norm_config: Optional[BatchNormConfig] = pydantic.Field(None)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
)
layer_norm_config: Optional[LayerNormConfig] = pydantic.Field(None)
class Log1pAbsConfig(base_config.BaseConfig):
"""Simple configuration where only the log transform is performed."""
class ClipLog1pAbsConfig(base_config.BaseConfig):
clip_magnitude: pydantic.NonNegativeFloat = pydantic.Field(
3e38, description="Threshold to clip the input values."
)
class ZScoreLogConfig(base_config.BaseConfig):
analysis_path: str
schema_path: str = pydantic.Field(
None,
description="Schema path which feaure statistics are generated with. Can be different from scehma in data config.",
)
clip_magnitude: float = pydantic.Field(
5.0, description="Threshold to clip the normalized input values."
)
use_batch_norm: bool = pydantic.Field(
False, description="Option to use batch normalization on the inputs."
)
use_renorm: bool = pydantic.Field(
False, description="Option to use batch renormalization for trainig and serving consistency."
)
use_bq_stats: bool = pydantic.Field(
False, description="Option to load the partitioned json files from BQ as statistics."
)
class FeaturizationConfig(base_config.BaseConfig):
"""Configuration for featurization."""
log1p_abs_config: Log1pAbsConfig = pydantic.Field(None, one_of="featurization")
clip_log1p_abs_config: ClipLog1pAbsConfig = pydantic.Field(None, one_of="featurization")
z_score_log_config: ZScoreLogConfig = pydantic.Field(None, one_of="featurization")
double_norm_log_config: DoubleNormLogConfig = pydantic.Field(None, one_of="featurization")
feature_names_to_concat: List[str] = pydantic.Field(
["binary"], description="Feature names to concatenate as raw values with continuous features."
)
class DropoutConfig(base_config.BaseConfig):
"""Configuration for the dropout layer."""
rate: pydantic.PositiveFloat = pydantic.Field(
0.1, description="Fraction of inputs to be dropped."
)
class MlpConfig(base_config.BaseConfig):
"""Configuration for MLP model."""
layer_sizes: List[pydantic.PositiveInt]
batch_norm: BatchNormConfig = pydantic.Field(
None, description="Optional batch norm configuration."
)
dropout: DropoutConfig = pydantic.Field(None, description="Optional dropout configuration.")
final_layer_activation: bool = pydantic.Field(
False, description="Whether to include activation on final layer."
)
class DcnConfig(base_config.BaseConfig):
"""Config for DCN model."""
poly_degree: pydantic.PositiveInt
projection_dim: pydantic.PositiveInt = pydantic.Field(
None, description="Factorizes main DCN matmul with projection."
)
parallel_mlp: Optional[MlpConfig] = pydantic.Field(
None, description="Config for the mlp if used. If None, only the cross layers are used."
)
use_parallel: bool = pydantic.Field(True, description="Whether to use parallel DCN.")
output_mlp: Optional[MlpConfig] = pydantic.Field(None, description="Config for the output mlp.")
class MaskBlockConfig(base_config.BaseConfig):
output_size: int
reduction_factor: Optional[pydantic.PositiveFloat] = pydantic.Field(
None, one_of="aggregation_size"
)
aggregation_size: Optional[pydantic.PositiveInt] = pydantic.Field(
None, description="Specify the aggregation size directly.", one_of="aggregation_size"
)
input_layer_norm: bool
class MaskNetConfig(base_config.BaseConfig):
mask_blocks: List[MaskBlockConfig]
mlp: Optional[MlpConfig] = pydantic.Field(None, description="MLP Configuration for parallel")
use_parallel: bool = pydantic.Field(False, description="Whether to use parallel MaskNet.")
class PositionDebiasConfig(base_config.BaseConfig):
"""
Configuration for Position Debias.
"""
max_position: int = pydantic.Field(256, description="Bucket all later positions.")
num_dims: pydantic.PositiveInt = pydantic.Field(
64, description="Number of dimensions in embedding."
)
drop_probability: float = pydantic.Field(0.5, description="Probability of dropping position.")
# Currently it should be 51 based on dataset being tested at the time of writing this model
# However, no default provided here to make sure user of the model is aware of its importance.
position_feature_index: int = pydantic.Field(
description="The index of the position feature in the discrete features"
)
class AffineMap(base_config.BaseConfig):
"""An affine map that scales the logits into the appropriate range."""
scale: float = pydantic.Field(1.0)
bias: float = pydantic.Field(0.0)
class DLRMConfig(base_config.BaseConfig):
bottom_mlp: MlpConfig = pydantic.Field(
...,
description="Bottom mlp, the output to be combined with sparse features and feed to interaction",
)
top_mlp: MlpConfig = pydantic.Field(..., description="Top mlp, generate the final output")
class TaskModel(base_config.BaseConfig):
mlp_config: MlpConfig = pydantic.Field(None, one_of="architecture")
dcn_config: DcnConfig = pydantic.Field(None, one_of="architecture")
dlrm_config: DLRMConfig = pydantic.Field(None, one_of="architecture")
mask_net_config: MaskNetConfig = pydantic.Field(None, one_of="architecture")
affine_map: AffineMap = pydantic.Field(
None,
description="Affine map applied to logits so we can represent a broader range of probabilities.",
)
# DANGER DANGER: not implemented yet.
# loss_weight: float = pydantic.Field(1.0, description="Weight for task in loss.")
pos_weight: float = pydantic.Field(1.0, description="Weight of positive in loss.")
class MultiTaskType(str, enum.Enum):
SHARE_NONE = "share_none" # Tasks are separate.
SHARE_ALL = "share_all" # Tasks share same backbone.
SHARE_PARTIAL = "share_partial" # Tasks share some backbone, but have their own portions.
class ModelConfig(base_config.BaseConfig):
"""Specify model architecture."""
tasks: Dict[str, TaskModel] = pydantic.Field(
description="Specification of architecture per task."
)
large_embeddings: embedding_config.LargeEmbeddingsConfig = pydantic.Field(None)
small_embeddings: embedding_config.SmallEmbeddingsConfig = pydantic.Field(None)
# Not implemented yet.
# multi_task_loss_reduction_fn: str = "mean"
position_debias_config: PositionDebiasConfig = pydantic.Field(
default=None, description="position debias model configuration"
)
featurization_config: FeaturizationConfig = pydantic.Field(None)
multi_task_type: MultiTaskType = pydantic.Field(
MultiTaskType.SHARE_NONE, description="Multi task architecture"
)
backbone: TaskModel = pydantic.Field(None, description="Type of architecture for the backbone.")
stratifiers: List[embedding_config.StratifierConfig] = pydantic.Field(
default=None, description="Discrete features and values to stratify metrics by."
)
@pydantic.root_validator()
def _validate_mtl(cls, values):
if values.get("multi_task_type", None) is None:
return values
elif values["multi_task_type"] in [MultiTaskType.SHARE_ALL, MultiTaskType.SHARE_PARTIAL]:
if values.get("backbone", None) is None:
raise ValueError("Require `backbone` for SHARE_ALL and SHARE_PARTIAL.")
elif values["multi_task_type"] in [
MultiTaskType.SHARE_NONE,
]:
if values.get("backbone", None) is not None:
raise ValueError("Can not have backbone if the share type is SHARE_NONE")
return values

View file

@ -0,0 +1,304 @@
from __future__ import annotations
from absl import logging
import torch
from typing import Optional, Callable, Mapping, Dict, Sequence, TYPE_CHECKING
from tml.projects.home.recap.model import feature_transform
from tml.projects.home.recap.model import config as model_config_mod
from tml.projects.home.recap.model import mlp
from tml.projects.home.recap.model import mask_net
from tml.projects.home.recap.model import numeric_calibration
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss
import tml.projects.home.recap.model.config as model_config_mod
if TYPE_CHECKING:
from tml.projects.home.recap import config as config_mod
from tml.projects.home.recap.data.config import RecapDataConfig
from tml.projects.home.recap.model.config import ModelConfig
def sanitize(task_name):
return task_name.replace(".", "__")
def unsanitize(sanitized_task_name):
return sanitized_task_name.replace("__", ".")
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
""" "Builds a model for a single task"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:
return dcn.Dcn(dcn_config=task.dcn_config, in_features=input_shape)
elif task.mask_net_config:
return mask_net.MaskNet(mask_net_config=task.mask_net_config, in_features=input_shape)
else:
raise ValueError("This should never be reached.")
class MultiTaskRankingModel(torch.nn.Module):
"""Multi-task ranking model."""
def __init__(
self,
input_shapes: Mapping[str, torch.Size],
config: ModelConfig,
data_config: RecapDataConfig,
return_backbone: bool = False,
):
"""Constructor for Multi task learning.
Assumptions made:
1. Tasks specified in data config match model architecture.
These are all validated in config.
"""
super().__init__()
self._config = config
self._data_config = data_config
self._preprocessor = feature_transform.build_features_preprocessor(
config.featurization_config, input_shapes
)
self.return_backbone = return_backbone
self.embeddings = None
self.small_embeddings = None
embedding_dims = 0
if config.large_embeddings:
from large_embeddings.models.learnable_embeddings import LargeEmbeddings
self.embeddings = LargeEmbeddings(large_embeddings_config=config.large_embeddings)
embedding_dims += sum([table.embedding_dim for table in config.large_embeddings.tables])
logging.info(f"Emb dim: {embedding_dims}")
if config.small_embeddings:
self.small_embeddings = SmallEmbedding(config.small_embeddings)
embedding_dims += sum([table.embedding_dim for table in config.small_embeddings.tables])
logging.info(f"Emb dim (with small embeddings): {embedding_dims}")
if "user_embedding" in data_config.seg_dense_schema.renamed_features:
embedding_dims += input_shapes["user_embedding"][-1]
self._user_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["user_embedding"][-1])
else:
self._user_embedding_layer_norm = None
if "user_eng_embedding" in data_config.seg_dense_schema.renamed_features:
embedding_dims += input_shapes["user_eng_embedding"][-1]
self._user_eng_embedding_layer_norm = torch.nn.LayerNorm(
input_shapes["user_eng_embedding"][-1]
)
else:
self._user_eng_embedding_layer_norm = None
if "author_embedding" in data_config.seg_dense_schema.renamed_features:
embedding_dims += input_shapes["author_embedding"][-1]
self._author_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["author_embedding"][-1])
else:
self._author_embedding_layer_norm = None
input_dims = input_shapes["continuous"][-1] + input_shapes["binary"][-1] + embedding_dims
if config.position_debias_config:
self.position_debias_model = PositionDebias(config.position_debias_config)
input_dims += self.position_debias_model.out_features
else:
self.position_debias_model = None
logging.info(f"input dim: {input_dims}")
if config.multi_task_type in [
model_config_mod.MultiTaskType.SHARE_ALL,
model_config_mod.MultiTaskType.SHARE_PARTIAL,
]:
self._backbone = _build_single_task_model(config.backbone, input_dims)
else:
self._backbone = None
_towers: Dict[str, torch.nn.Module] = {}
_calibrators: Dict[str, torch.nn.Module] = {}
_affine_maps: Dict[str, torch.nn.Module] = {}
for task_name, task_architecture in config.tasks.items():
safe_name = sanitize(task_name)
# Complex input dimension calculation.
if config.multi_task_type == model_config_mod.MultiTaskType.SHARE_NONE:
num_inputs = input_dims
elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
num_inputs = self._backbone.out_features
elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
num_inputs = input_dims + self._backbone.out_features
else:
raise ValueError("Unreachable branch of enum.")
# Annoyingly, ModuleDict doesn't allow . inside key names.
_towers[safe_name] = _build_single_task_model(task_architecture, num_inputs)
if task_architecture.affine_map:
affine_map = torch.nn.Linear(1, 1)
affine_map.weight.data = torch.tensor([[task_architecture.affine_map.scale]])
affine_map.bias.data = torch.tensor([task_architecture.affine_map.bias])
_affine_maps[safe_name] = affine_map
else:
_affine_maps[safe_name] = torch.nn.Identity()
_calibrators[safe_name] = numeric_calibration.NumericCalibration(
pos_downsampling_rate=data_config.tasks[task_name].pos_downsampling_rate,
neg_downsampling_rate=data_config.tasks[task_name].neg_downsampling_rate,
)
self._task_names = list(config.tasks.keys())
self._towers = torch.nn.ModuleDict(_towers)
self._affine_maps = torch.nn.ModuleDict(_affine_maps)
self._calibrators = torch.nn.ModuleDict(_calibrators)
self._counter = torch.autograd.Variable(torch.tensor(0), requires_grad=False)
def forward(
self,
continuous_features: torch.Tensor,
binary_features: torch.Tensor,
discrete_features: Optional[torch.Tensor] = None,
sparse_features=None, # Optional[KeyedJaggedTensor]
user_embedding: Optional[torch.Tensor] = None,
user_eng_embedding: Optional[torch.Tensor] = None,
author_embedding: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
weights: Optional[torch.Tensor] = None,
):
concat_dense_features = [
self._preprocessor(continuous_features=continuous_features, binary_features=binary_features)
]
if self.embeddings:
concat_dense_features.append(self.embeddings(sparse_features))
# Twhin embedding layer norms
if self.small_embeddings:
if discrete_features is None:
raise ValueError(
"Forward arg discrete_features is None, but since small_embeddings are used, a Tensor is expected."
)
concat_dense_features.append(self.small_embeddings(discrete_features))
if self._user_embedding_layer_norm:
if user_embedding is None:
raise ValueError(
"Forward arg user_embedding is None, but since Twhin user_embeddings are used by the model, a Tensor is expected."
)
concat_dense_features.append(self._user_embedding_layer_norm(user_embedding))
if self._user_eng_embedding_layer_norm:
if user_eng_embedding is None:
raise ValueError(
"Forward arg user_eng_embedding is None, but since Twhin user_eng_embeddings are used by the model, a Tensor is expected."
)
concat_dense_features.append(self._user_eng_embedding_layer_norm(user_eng_embedding))
if self._author_embedding_layer_norm:
if author_embedding is None:
raise ValueError(
"Forward arg author_embedding is None, but since Twhin author_embeddings are used by the model, a Tensor is expected."
)
concat_dense_features.append(self._author_embedding_layer_norm(author_embedding))
if self.position_debias_model:
if discrete_features is None:
raise ValueError(
"Forward arg discrete_features is None, but since position_debias_model is used, a Tensor is expected."
)
concat_dense_features.append(self.position_debias_model(discrete_features))
if discrete_features is not None and not (self.position_debias_model or self.small_embeddings):
logging.warning("Forward arg discrete_features is passed, but never used.")
concat_dense_features = torch.cat(concat_dense_features, dim=1)
if self._backbone:
if self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
net = self._backbone(concat_dense_features)["output"]
elif self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
net = torch.cat(
[concat_dense_features, self._backbone(concat_dense_features)["output"]], dim=1
)
else:
net = concat_dense_features
backbone_result = net
all_logits = []
all_probabilities = []
all_calibrated_probabilities = []
for task_name in self._task_names:
safe_name = sanitize(task_name)
tower_outputs = self._towers[safe_name](net)
logits = tower_outputs["output"]
scaled_logits = self._affine_maps[safe_name](logits)
probabilities = torch.sigmoid(scaled_logits)
calibrated_probabilities = self._calibrators[safe_name](probabilities)
all_logits.append(scaled_logits)
all_probabilities.append(probabilities)
all_calibrated_probabilities.append(calibrated_probabilities)
results = {
"logits": torch.squeeze(torch.stack(all_logits, dim=1), dim=-1),
"probabilities": torch.squeeze(torch.stack(all_probabilities, dim=1), dim=-1),
"calibrated_probabilities": torch.squeeze(
torch.stack(all_calibrated_probabilities, dim=1), dim=-1
),
}
# Returning the backbone is intended for stitching post-tf conversion
# Leaving this on will ~200x the size of the output
# and could slow things down
if self.return_backbone:
results["backbone"] = backbone_result
return results
def create_ranking_model(
data_spec,
# Used for planner to be batch size aware.
config: config_mod.RecapConfig,
device: torch.device,
loss_fn: Optional[Callable] = None,
data_config=None,
return_backbone=False,
):
if list(config.model.tasks.values())[0].dlrm_config:
raise NotImplementedError()
model = EmbeddingRankingModel(
input_shapes=data_spec,
config=all_config.model,
data_config=all_config.train_data,
)
else:
model = MultiTaskRankingModel(
input_shapes=data_spec,
config=config.model,
data_config=data_config if data_config is not None else config.train_data,
return_backbone=return_backbone,
)
logging.info("***** Model Architecture *****")
logging.info(model)
logging.info("***** Named Parameters *****")
for elem in model.named_parameters():
logging.info(elem[0])
if loss_fn:
logging.info("***** Wrapping in loss *****")
model = ModelAndLoss(
model=model,
loss_fn=loss_fn,
stratifiers=config.model.stratifiers,
)
return model

View file

@ -0,0 +1,122 @@
from typing import Mapping, Sequence, Union
from tml.projects.home.recap.model.config import (
BatchNormConfig,
DoubleNormLogConfig,
FeaturizationConfig,
LayerNormConfig,
)
import torch
def log_transform(x: torch.Tensor) -> torch.Tensor:
"""Safe log transform that works across both negative, zero, and positive floats."""
return torch.sign(x) * torch.log1p(torch.abs(x))
class BatchNorm(torch.nn.Module):
def __init__(self, num_features: int, config: BatchNormConfig):
super().__init__()
self.layer = torch.nn.BatchNorm1d(num_features, affine=config.affine, momentum=config.momentum)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x)
class LayerNorm(torch.nn.Module):
def __init__(self, normalized_shape: Union[int, Sequence[int]], config: LayerNormConfig):
super().__init__()
if config.axis != -1:
raise NotImplementedError
if config.center != config.scale:
raise ValueError(
f"Center and scale must match in torch, received {config.center}, {config.scale}"
)
self.layer = torch.nn.LayerNorm(
normalized_shape, eps=config.epsilon, elementwise_affine=config.center
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer(x)
class Log1pAbs(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return log_transform(x)
class InputNonFinite(torch.nn.Module):
def __init__(self, fill_value: float = 0):
super().__init__()
self.register_buffer(
"fill_value", torch.as_tensor(fill_value, dtype=torch.float32), persistent=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.where(torch.isfinite(x), x, self.fill_value)
class Clamp(torch.nn.Module):
def __init__(self, min_value: float, max_value: float):
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
self.register_buffer(
"min_value", torch.as_tensor(min_value, dtype=torch.float32), persistent=True
)
self.register_buffer(
"max_value", torch.as_tensor(max_value, dtype=torch.float32), persistent=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.clamp(x, min=self.min_value, max=self.max_value)
class DoubleNormLog(torch.nn.Module):
"""Performs a batch norm and clamp on continuous features followed by a layer norm on binary and continuous features."""
def __init__(
self,
input_shapes: Mapping[str, Sequence[int]],
config: DoubleNormLogConfig,
):
super().__init__()
_before_concat_layers = [
InputNonFinite(),
Log1pAbs(),
]
if config.batch_norm_config:
_before_concat_layers.append(
BatchNorm(input_shapes["continuous"][-1], config.batch_norm_config)
)
_before_concat_layers.append(
Clamp(min_value=-config.clip_magnitude, max_value=config.clip_magnitude)
)
self._before_concat_layers = torch.nn.Sequential(*_before_concat_layers)
self.layer_norm = None
if config.layer_norm_config:
last_dim = input_shapes["continuous"][-1] + input_shapes["binary"][-1]
self.layer_norm = LayerNorm(last_dim, config.layer_norm_config)
def forward(
self, continuous_features: torch.Tensor, binary_features: torch.Tensor
) -> torch.Tensor:
x = self._before_concat_layers(continuous_features)
x = torch.cat([x, binary_features], dim=1)
if self.layer_norm:
return self.layer_norm(x)
return x
def build_features_preprocessor(
config: FeaturizationConfig, input_shapes: Mapping[str, Sequence[int]]
):
"""Trivial right now, but we will change in the future."""
return DoubleNormLog(input_shapes, config.double_norm_log_config)

View file

@ -0,0 +1,98 @@
"""MaskNet: Wang et al. (https://arxiv.org/abs/2102.07619)."""
from tml.projects.home.recap.model import config, mlp
import torch
def _init_weights(module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class MaskBlock(torch.nn.Module):
def __init__(
self, mask_block_config: config.MaskBlockConfig, input_dim: int, mask_input_dim: int
) -> None:
super(MaskBlock, self).__init__()
self.mask_block_config = mask_block_config
output_size = mask_block_config.output_size
if mask_block_config.input_layer_norm:
self._input_layer_norm = torch.nn.LayerNorm(input_dim)
else:
self._input_layer_norm = None
if mask_block_config.reduction_factor:
aggregation_size = int(mask_input_dim * mask_block_config.reduction_factor)
elif mask_block_config.aggregation_size is not None:
aggregation_size = mask_block_config.aggregation_size
else:
raise ValueError("Need one of reduction factor or aggregation size.")
self._mask_layer = torch.nn.Sequential(
torch.nn.Linear(mask_input_dim, aggregation_size),
torch.nn.ReLU(),
torch.nn.Linear(aggregation_size, input_dim),
)
self._mask_layer.apply(_init_weights)
self._hidden_layer = torch.nn.Linear(input_dim, output_size)
self._hidden_layer.apply(_init_weights)
self._layer_norm = torch.nn.LayerNorm(output_size)
def forward(self, net: torch.Tensor, mask_input: torch.Tensor):
if self._input_layer_norm:
net = self._input_layer_norm(net)
hidden_layer_output = self._hidden_layer(net * self._mask_layer(mask_input))
return self._layer_norm(hidden_layer_output)
class MaskNet(torch.nn.Module):
def __init__(self, mask_net_config: config.MaskNetConfig, in_features: int):
super().__init__()
self.mask_net_config = mask_net_config
mask_blocks = []
if mask_net_config.use_parallel:
total_output_mask_blocks = 0
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, in_features, in_features))
total_output_mask_blocks += mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
else:
input_size = in_features
for mask_block_config in mask_net_config.mask_blocks:
mask_blocks.append(MaskBlock(mask_block_config, input_size, in_features))
input_size = mask_block_config.output_size
self._mask_blocks = torch.nn.ModuleList(mask_blocks)
total_output_mask_blocks = mask_block_config.output_size
if mask_net_config.mlp:
self._dense_layers = mlp.Mlp(total_output_mask_blocks, mask_net_config.mlp)
self.out_features = mask_net_config.mlp.layer_sizes[-1]
else:
self.out_features = total_output_mask_blocks
self.shared_size = total_output_mask_blocks
def forward(self, inputs: torch.Tensor):
if self.mask_net_config.use_parallel:
mask_outputs = []
for mask_layer in self._mask_blocks:
mask_outputs.append(mask_layer(mask_input=inputs, net=inputs))
# Share the outputs of the MaskBlocks.
all_mask_outputs = torch.cat(mask_outputs, dim=1)
output = (
all_mask_outputs
if self.mask_net_config.mlp is None
else self._dense_layers(all_mask_outputs)["output"]
)
return {"output": output, "shared_layer": all_mask_outputs}
else:
net = inputs
for mask_layer in self._mask_blocks:
net = mask_layer(net=net, mask_input=inputs)
# Share the output of the stacked MaskBlocks.
output = net if self.mask_net_config.mlp is None else self._dense_layers[net]["output"]
return {"output": output, "shared_layer": net}

View file

@ -0,0 +1,58 @@
"""MLP feed forward stack in torch."""
from tml.projects.home.recap.model.config import MlpConfig
import torch
from absl import logging
def _init_weights(module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
class Mlp(torch.nn.Module):
def __init__(self, in_features: int, mlp_config: MlpConfig):
super().__init__()
self._mlp_config = mlp_config
input_size = in_features
layer_sizes = mlp_config.layer_sizes
modules = []
for layer_size in layer_sizes[:-1]:
modules.append(torch.nn.Linear(input_size, layer_size, bias=True))
if mlp_config.batch_norm:
modules.append(
torch.nn.BatchNorm1d(
layer_size, affine=mlp_config.batch_norm.affine, momentum=mlp_config.batch_norm.momentum
)
)
modules.append(torch.nn.ReLU())
if mlp_config.dropout:
modules.append(torch.nn.Dropout(mlp_config.dropout.rate))
input_size = layer_size
modules.append(torch.nn.Linear(input_size, layer_sizes[-1], bias=True))
if mlp_config.final_layer_activation:
modules.append(torch.nn.ReLU())
self.layers = torch.nn.ModuleList(modules)
self.layers.apply(_init_weights)
def forward(self, x: torch.Tensor) -> torch.Tensor:
net = x
for i, layer in enumerate(self.layers):
net = layer(net)
if i == 1: # Share the first (widest?) set of activations for other applications.
shared_layer = net
return {"output": net, "shared_layer": shared_layer}
@property
def shared_size(self):
return self._mlp_config.layer_sizes[-1]
@property
def out_features(self):
return self._mlp_config.layer_sizes[-1]

View file

@ -0,0 +1,67 @@
from typing import Callable, Optional, List
from tml.projects.home.recap.embedding import config as embedding_config_mod
import torch
from absl import logging
class ModelAndLoss(torch.nn.Module):
def __init__(
self,
model,
loss_fn: Callable,
stratifiers: Optional[List[embedding_config_mod.StratifierConfig]] = None,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
straitifiers: mapping of stratifier name and index of discrete features to emit for metrics stratification.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.stratifiers = stratifiers
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(
continuous_features=batch.continuous_features,
binary_features=batch.binary_features,
discrete_features=batch.discrete_features,
sparse_features=batch.sparse_features,
user_embedding=batch.user_embedding,
user_eng_embedding=batch.user_eng_embedding,
author_embedding=batch.author_embedding,
labels=batch.labels,
weights=batch.weights,
)
losses = self.loss_fn(outputs["logits"], batch.labels.float(), batch.weights.float())
if self.stratifiers:
logging.info(f"***** Adding stratifiers *****\n {self.stratifiers}")
outputs["stratifiers"] = {}
for stratifier in self.stratifiers:
outputs["stratifiers"][stratifier.name] = batch.discrete_features[:, stratifier.index]
# In general, we can have a large number of losses returned by our loss function.
if isinstance(losses, dict):
return losses["loss"], {
**outputs,
**losses,
"labels": batch.labels,
"weights": batch.weights,
}
else: # Assume that this is a float.
return losses, {
**outputs,
"loss": losses,
"labels": batch.labels,
"weights": batch.weights,
}

View file

@ -0,0 +1,19 @@
import torch
class NumericCalibration(torch.nn.Module):
def __init__(
self,
pos_downsampling_rate: float,
neg_downsampling_rate: float,
):
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
self.register_buffer(
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
)
def forward(self, probs: torch.Tensor):
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))

View file

@ -0,0 +1 @@
from tml.projects.home.recap.optimizer.optimizer import build_optimizer

View file

@ -0,0 +1,36 @@
"""Optimization configurations for models."""
import typing
import tml.core.config as base_config
import tml.optimizers.config as optimizers_config_mod
import pydantic
class RecapAdamConfig(base_config.BaseConfig):
beta_1: float = 0.9 # Momentum term.
beta_2: float = 0.999 # Exponential weighted decay factor.
epsilon: float = 1e-7 # Numerical stability in denominator.
class MultiTaskLearningRates(base_config.BaseConfig):
tower_learning_rates: typing.Dict[str, optimizers_config_mod.LearningRate] = pydantic.Field(
description="Learning rates for different towers of the model."
)
backbone_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
None, description="Learning rate for backbone of the model."
)
class RecapOptimizerConfig(base_config.BaseConfig):
multi_task_learning_rates: MultiTaskLearningRates = pydantic.Field(
None, description="Multiple learning rates for different tasks.", one_of="lr"
)
single_task_learning_rate: optimizers_config_mod.LearningRate = pydantic.Field(
None, description="Single task learning rates", one_of="lr"
)
adam: RecapAdamConfig = pydantic.Field(one_of="optimizer")

View file

@ -0,0 +1,166 @@
"""Build optimizers and learning rate schedules."""
import bisect
from collections import defaultdict
import functools
import math
import typing
from typing import Optional
import warnings
# from large_embeddings.config import EmbeddingOptimizerConfig
from tml.projects.home.recap import model as model_mod
from tml.optimizers import config
from tml.optimizers import compute_lr
from absl import logging # type: ignore[attr-defined]
import torch
from torchrec.optim import keyed
_DEFAULT_LR = 24601.0 # NaN the model if we're not using the learning rate.
_BACKBONE = "backbone"
_DENSE_EMBEDDINGS = "dense_ebc"
class RecapLRShim(torch.optim.lr_scheduler._LRScheduler):
"""Shim to get learning rates into a LRScheduler.
This adheres to the torch.optim scheduler API and can be plugged anywhere that
e.g. exponential decay can be used.
"""
def __init__(
self,
optimizer,
lr_dict: typing.Dict[str, config.LearningRate],
emb_learning_rate,
last_epoch=-1,
verbose=False,
):
self.optimizer = optimizer
self.lr_dict = lr_dict
self.group_names = list(self.lr_dict.keys())
self.emb_learning_rate = emb_learning_rate
# We handle sparse LR scheduling separately, so only validate LR groups against dense param groups
num_dense_param_groups = sum(
1
for _, _optim in optimizer._optims
for _ in _optim.param_groups
if isinstance(_optim, keyed.KeyedOptimizerWrapper)
)
if num_dense_param_groups != len(lr_dict):
raise ValueError(
f"Optimizer had {len(optimizer.param_groups)}, but config had {len(lr_dict)}."
)
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
UserWarning,
)
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
learning_rates = []
for lr_config in self.lr_dict.values():
learning_rates.append(compute_lr(lr_config, self.last_epoch))
# WARNING: The order of appending is important.
if self.emb_learning_rate:
learning_rates.append(compute_lr(self.emb_learning_rate, self.last_epoch))
return learning_rates
def build_optimizer(
model: torch.nn.Module,
optimizer_config: config.OptimizerConfig,
emb_optimizer_config: None = None, # Optional[EmbeddingOptimizerConfig] = None,
):
"""Builds an optimizer and scheduler.
Args:
model: A torch model, probably with DDP/DMP.
optimizer_config: An OptimizerConfig object that specifies learning rates per tower.
Returns:
A torch.optim instance, and a scheduler instance.
"""
optimizer_fn = functools.partial(
torch.optim.Adam,
lr=_DEFAULT_LR,
betas=(optimizer_config.adam.beta_1, optimizer_config.adam.beta_2),
eps=optimizer_config.adam.epsilon,
maximize=False,
)
if optimizer_config.multi_task_learning_rates:
logging.info("***** Parameter groups for optimization *****")
# Importantly, we preserve insertion order in dictionaries here.
parameter_groups: typing.Dict[str, typing.Dict] = defaultdict(dict)
added_parameters: typing.Set[str] = set()
for task in optimizer_config.multi_task_learning_rates.tower_learning_rates:
for name, parameter in model.named_parameters():
if f".{model_mod.sanitize(task)}." in name:
parameter_groups[task][name] = parameter
logging.info(f"{task}: {name}")
if name in added_parameters:
raise ValueError(f"Parameter {name} matched multiple tasks.")
added_parameters.add(name)
for name, parameter in model.named_parameters():
if name not in added_parameters and "embedding_bags" not in name:
parameter_groups[_BACKBONE][name] = parameter
added_parameters.add(name)
logging.info(f"{_BACKBONE}: {name}")
for name, parameter in model.named_parameters():
if name not in added_parameters and "embedding_bags" in name:
parameter_groups[_DENSE_EMBEDDINGS][name] = parameter
logging.info(f"{_DENSE_EMBEDDINGS}: {name}")
all_learning_rates = optimizer_config.multi_task_learning_rates.tower_learning_rates.copy()
if optimizer_config.multi_task_learning_rates.backbone_learning_rate is not None:
all_learning_rates[
_BACKBONE
] = optimizer_config.multi_task_learning_rates.backbone_learning_rate
if _DENSE_EMBEDDINGS in parameter_groups and emb_optimizer_config:
all_learning_rates[_DENSE_EMBEDDINGS] = emb_optimizer_config.learning_rate.copy()
else:
parameter_groups = dict(model.named_parameters())
all_learning_rates = {"single_task": optimizer_config.single_task_learning_rate}
optimizers = [
keyed.KeyedOptimizerWrapper(param_group, optimizer_fn)
for param_name, param_group in parameter_groups.items()
if param_name != _DENSE_EMBEDDINGS
]
# Making EBC optimizer to be SGD to match fused optimiser
if _DENSE_EMBEDDINGS in parameter_groups:
optimizers.append(
keyed.KeyedOptimizerWrapper(
parameter_groups[_DENSE_EMBEDDINGS],
functools.partial(torch.optim.SGD, lr=_DEFAULT_LR, maximize=False, momentum=False),
)
)
if not parameter_groups.keys() == all_learning_rates.keys():
raise ValueError("Learning rates do not match optimizers")
# If the optimiser is dense, model.fused_optimizer will be empty (but not None)
emb_learning_rate = None
if hasattr(model, "fused_optimizer") and model.fused_optimizer.optimizers:
logging.info(f"Model fused optimiser: {model.fused_optimizer}")
optimizers.append(model.fused_optimizer)
if emb_optimizer_config:
emb_learning_rate = emb_optimizer_config.learning_rate.copy()
else:
raise ValueError("Fused kernel exists, but LR is not set")
logging.info(f"***** Combining optimizers: {optimizers} *****")
optimizer = keyed.CombinedOptimizer(optimizers)
scheduler = RecapLRShim(optimizer, all_learning_rates, emb_learning_rate)
logging.info(f"***** Combined optimizer after init: {optimizer} *****")
return optimizer, scheduler

View file

@ -0,0 +1,499 @@
{
"schema" : [
{
"dtype" : "int64_list",
"feature_name" : "home_recap_2022_discrete__segdense_vals",
"length" : 320
},
{
"dtype" : "float_list",
"feature_name" : "home_recap_2022_cont__segdense_vals",
"length" : 6000
},
{
"dtype" : "int64_list",
"feature_name" : "home_recap_2022_binary__segdense_vals",
"length" : 512
},
{
"dtype" : "bytes_list",
"feature_name" : "home_recap_2022_string__segdense_vals",
"length" : 64
},
{
"dtype" : "bytes_list",
"feature_name" : "home_recap_2022_blob__segdense_vals",
"length" : 64
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_photo_expanded",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_tweet_detail_dwelled_15_sec",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "meta.timestamp",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_replied",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_profile_clicked_and_profile_engaged",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_retweeted",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_profile_clicked",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_shared",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_share_menu_clicked",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_replied_reply_engaged_by_author",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_profile_clicked_and_profile_engaged",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_profile_dwelled_30_sec",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_video_playback_50",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_negative_feedback_v2",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_strong_negative_feedback",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_photo_expanded",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_share_menu_clicked",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_report_tweet_clicked",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_dwell",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_replied",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_clicked",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "meta.author_id",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_weak_negative_feedback",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_favorited",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_weak_negative_feedback",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_negative_feedback_v2",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_profile_dwelled_20_sec",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_quoted",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_profile_dwelled_20_sec",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_retweeted",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_profile_dwelled_10_sec",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_report_tweet_clicked",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_strong_negative_feedback",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_dwelled",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_profile_dwelled_10_sec",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "timelines.meta.candidate_tweet_source_id",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_favorited",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_clicked",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_good_clicked_convo_desc_v2",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "timelines.engagement.dwell_time_ms",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_good_clicked_convo_desc_favorited_or_replied",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_tweet_detail_dwelled_15_sec",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "meta.tweet_id",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_good_clicked_convo_desc_v2",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "meta.user_id",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_good_clicked_convo_desc_favorited_or_replied",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_replied_reply_engaged_by_author",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_bookmarked",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "meta.record_weight/type=recap",
"length" : 1
},
{
"dtype" : "float_list",
"feature_name" : "recap.engagement_predicted.is_video_playback_50",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_shared",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "timelines.meta.prediction_request_id",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_quoted",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_profile_clicked",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "recap.engagement.is_profile_dwelled_30_sec",
"length" : 1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.good_profile_click",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.v3.tweet.text.emoji_tokens",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "entities.inferred_topic_ids",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "semantic_core.media_understanding.high_recall.sensitive.entity_ids",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "user.timelines.twhin_user_engagement_embeddings.twhin_user_engagement_embeddings",
"length" : 200
},
{
"dtype" : "bytes_list",
"feature_name" : "user.interests_discovery.user_topic_relationships.not_interested_topics",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.v2.tweet.media.providers",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.v3.tweet.text.pos_unigrams",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.share_menu_click",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "original_author.twhin.tw_hi_n.author_follow_as_float_tensor",
"length" : 200
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.tweet.media.sticker_ids",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "original_author.timelines.twhin_author_follow_embeddings.twhin_author_follow_embeddings",
"length" : 200
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.good_click",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.v2.tweet.media.classification_labels__key0",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.share",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "semantic_core.media_understanding.high_recall.non_sensitive.entity_ids",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "semantic_core.media_understanding.high_precision.non_sensitive.entity_ids",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.reported",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "tweet.cuad.hashspace.tweet_embedding_tensor",
"length" : 256
},
{
"dtype" : "bytes_list",
"feature_name" : "semantic_core.common.entity_ids",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.negative_feedback_v_2",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.retweet",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.tweet.media.tag_screen_names",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "user.twhin.tw_hi_n.user_follow_as_float_tensor",
"length" : 200
},
{
"dtype" : "float_list",
"feature_name" : "tweetsource.v2.tweet.media.classification_labels__vals",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "in_reply_to_tweet.tweetsource.v4.tweet.text.tokens",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.playback_50",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.reply",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "recap.tweetfeature.url_domains",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.reply_engaged_by_author",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "user.twhin.tw_hi_n.user_engagement_as_float_tensor",
"length" : 200
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.v3.tweet.text.pos_bigrams",
"length" : -1
},
{
"dtype" : "float_list",
"feature_name" : "user.timelines.twhin_user_follow_embeddings.twhin_user_follow_embeddings",
"length" : 200
},
{
"dtype" : "bytes_list",
"feature_name" : "in_reply_to_tweet.tweetsource.v3.tweet.text.emoji_tokens",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.good_click_v_2",
"length" : -1
},
{
"dtype" : "int64_list",
"feature_name" : "user.recommender_systems_research.recap_hashspace_pq_encoded.fav",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "user.interests_discovery.user_topic_relationships.followed_topics",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "semantic_core.media_understanding.high_precision.sensitive.entity_ids",
"length" : -1
},
{
"dtype" : "bytes_list",
"feature_name" : "tweetsource.v3.tweet.text.emoticon_tokens",
"length" : -1
}
]
}

View file

@ -0,0 +1,10 @@
#!/usr/bin/env bash
# Runs from inside venv
rm -rf $HOME/tmp/runs/recap_local_random_data
python -m tml.machines.is_venv || exit 1
export TML_BASE="$(git rev-parse --show-toplevel)"
mkdir -p $HOME/tmp/recap_local_random_data
python projects/home/recap/data/generate_random_data.py --config_path $(pwd)/projects/home/recap/config/local_prod.yaml

View file

@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Runs from inside venv
rm -rf $HOME/tmp/runs/recap_local_debug
mkdir -p $HOME/tmp/runs/recap_local_debug
python -m tml.machines.is_venv || exit 1
export TML_BASE="$(git rev-parse --show-toplevel)"
torchrun \
--standalone \
--nnodes 1 \
--nproc_per_node 1 \
projects/home/recap/main.py \
--config_path $(pwd)/projects/home/recap/config/local_prod.yaml \
$@

28
projects/twhin/README.md Normal file
View file

@ -0,0 +1,28 @@
Twhin in torchrec
Workflow
========
- Build local development images `./scripts/build_images.sh`
- Run with `./scripts/docker_run.sh`
- Iterate in image with `./scripts/idocker.sh`
- Run tests with `./scripts/docker_test.sh`
Dump twhin user follows author prod data to gcs parquet in index-coordinates
===========
## Dump twhin user follows author prod data to gcs parquet in index-coordinates
Assumes the graph has a single partition
```
EXPORT DATA
OPTIONS ( uri = 'gs://follows_tml_01/tweet_eng/2023-01-23/small/edges/0/*',
format = 'PARQUET', compression="ZSTD",
overwrite = TRUE) AS ( (
SELECT
lhs, rhs, rel
FROM
`twttr-twhin-stg.tweet_eng_tml_01.tweet_eng_small_graph`
WHERE
snapshot_date = "2023-01-23" AND rel = 0)
)
```

14
projects/twhin/config.py Normal file
View file

@ -0,0 +1,14 @@
from tml.core.config import base_config
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.models.config import TwhinModelConfig
from tml.core.config.training import RuntimeConfig, TrainingConfig
import pydantic
class TwhinConfig(base_config.BaseConfig):
runtime: RuntimeConfig = pydantic.Field(RuntimeConfig())
training: TrainingConfig = pydantic.Field(TrainingConfig())
model: TwhinModelConfig
train_data: TwhinDataConfig
validation_data: TwhinDataConfig

View file

@ -0,0 +1,72 @@
runtime:
wandb:
key_path: "/var/lib/tss/keys/${USER}/wandb.key"
name: "twhin-test"
entity: "-"
project: "twhin-test"
host: "https://https--wandb--prod--wandb.service.qus1.twitter.biz/"
training:
save_dir: "/tmp/model"
num_train_steps: 100000
checkpoint_every_n: 100000
train_log_every_n: 10
num_eval_steps: 1000
eval_log_every_n: 500
eval_timeout_in_s: 10000
num_epochs: 5
model:
translation_optimizer:
sgd:
lr: 0.05
learning_rate:
constant: 0.05
embeddings:
tables:
- name: user
num_embeddings: 424_241_060
embedding_dim: 128
data_type: fp32
optimizer:
sgd:
lr: 0.01
learning_rate:
constant: 0.01
- name: tweet
num_embeddings: 72_543_984
embedding_dim: 128
data_type: fp32
optimizer:
sgd:
lr: 0.005
learning_rate:
constant: 0.005
relations:
- name: fav
lhs: user
rhs: tweet
operator: translation
- name: reply
lhs: user
rhs: tweet
operator: translation
- name: retweet
lhs: user
rhs: tweet
operator: translation
- name: magic_recs
lhs: user
rhs: tweet
operator: translation
train_data:
data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*"
per_replica_batch_size: 60652
global_negatives: 0
in_batch_negatives: 0
limit: 9990
validation_data:
data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*"
per_replica_batch_size: 60652
global_negatives: 0
in_batch_negatives: 0
limit: 10
offset: 9990

View file

@ -0,0 +1,72 @@
runtime:
wandb:
key_path: "/var/lib/tss/keys/${USER}/wandb.key"
name: "twhin-test"
entity: "-"
project: "twhin-test"
host: "https://https--wandb--prod--wandb.service.qus1.twitter.biz/"
training:
save_dir: "/tmp/model"
num_train_steps: 100000
checkpoint_every_n: 100000
train_log_every_n: 10
num_eval_steps: 1000
eval_log_every_n: 500
eval_timeout_in_s: 10000
num_epochs: 5
model:
translation_optimizer:
sgd:
lr: 0.05
learning_rate:
constant: 0.05
embeddings:
tables:
- name: user
num_embeddings: 424_241
embedding_dim: 4
data_type: fp32
optimizer:
sgd:
lr: 0.01
learning_rate:
constant: 0.01
- name: tweet
num_embeddings: 72_543
embedding_dim: 4
data_type: fp32
optimizer:
sgd:
lr: 0.005
learning_rate:
constant: 0.005
relations:
- name: fav
lhs: user
rhs: tweet
operator: translation
- name: reply
lhs: user
rhs: tweet
operator: translation
- name: retweet
lhs: user
rhs: tweet
operator: translation
- name: magic_recs
lhs: user
rhs: tweet
operator: translation
train_data:
data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*"
per_replica_batch_size: 500
global_negatives: 0
in_batch_negatives: 10
limit: 9990
validation_data:
data_root: "gs://follows_tml_01/tweet_eng/2023-01-23/large/edges/*"
per_replica_batch_size: 500
global_negatives: 0
in_batch_negatives: 10
limit: 10
offset: 9990

View file

@ -0,0 +1,38 @@
version: "3.9"
services:
chief-0:
depends_on:
- datasetworker-0
- datasetworker-1
image: local/torch
command: ['python', 'tml/projects/twhin/toy.py']
volumes: &volumes
- $HOME/workspace/tml:/usr/src/app/tml
- $HOME/.config:/root/.config
environment:
- SLURM_RANK_INDICATOR=chief
- TEMP_SLURM_NUM_READERS=2
- RANK=0
- WORLD_SIZE=1
- MASTER_ADDR=localhost
- MASTER_PORT=2020
datasetworker-0: &worker
image: local/torch
command: ['python', 'tml/projects/twhin/toy.py']
volumes: *volumes
ports:
- "2222:2222"
environment:
- SLURM_RANK_INDICATOR=reader
- TEMP_SLURM_NUM_READERS=2
datasetworker-1:
image: local/torch
command: ['python', 'tml/projects/twhin/toy.py']
volumes: *volumes
ports:
- "2223:2222"
environment:
- SLURM_RANK_INDICATOR=reader
- TEMP_SLURM_NUM_READERS=2
networks:
host:

View file

@ -0,0 +1,72 @@
runtime:
wandb:
key_path: "/var/lib/tss/keys/${USER}/wandb.key"
name: "twhin-test"
entity: "-"
project: "twhin-test"
host: "https://https--wandb--prod--wandb.service.qus1.twitter.biz/"
training:
save_dir: "/tmp/model"
num_train_steps: 100000
checkpoint_every_n: 100000
train_log_every_n: 10
num_eval_steps: 1000
eval_log_every_n: 500
eval_timeout_in_s: 10000
num_epochs: 5
model:
translation_optimizer:
sgd:
lr: 0.05
learning_rate:
constant: 0.05
embeddings:
tables:
- name: user
num_embeddings: 700_410_729
embedding_dim: 64
data_type: fp16
optimizer:
sgd:
lr: 0.01
learning_rate:
constant: 0.01
- name: tweet
num_embeddings: 250_543_984
embedding_dim: 64
data_type: fp16
optimizer:
adagrad:
lr: 0.01
learning_rate:
constant: 0.01
relations:
- name: fav
lhs: user
rhs: tweet
operator: translation
- name: reply
lhs: user
rhs: tweet
operator: translation
- name: retweet
lhs: user
rhs: tweet
operator: translation
- name: magic_recs
lhs: user
rhs: tweet
operator: translation
train_data:
data_root: "/home/-/twhin-sample/*"
per_replica_batch_size: 65536
global_negatives: 100
in_batch_negatives: 100
limit: 9990
validation_data:
data_root: "/home/-/twhin-sample/*"
per_replica_batch_size: 65536
global_negatives: 100
in_batch_negatives: 100
limit: 10
offset: 9990

View file

@ -0,0 +1,14 @@
from tml.core.config import base_config
import pydantic
class TwhinDataConfig(base_config.BaseConfig):
data_root: str
per_replica_batch_size: pydantic.PositiveInt
global_negatives: int
in_batch_negatives: int
limit: pydantic.PositiveInt
offset: pydantic.PositiveInt = pydantic.Field(
None, description="The offset to start reading from."
)

View file

@ -0,0 +1,22 @@
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.projects.twhin.models.config import TwhinModelConfig
from tml.projects.twhin.data.edges import EdgesDataset
def create_dataset(data_config: TwhinDataConfig, model_config: TwhinModelConfig):
tables = model_config.embeddings.tables
table_sizes = {table.name: table.num_embeddings for table in tables}
relations = model_config.relations
pos_batch_size = data_config.per_replica_batch_size
global_negatives = data_config.global_negatives
in_batch_negatives = data_config.in_batch_negatives
return EdgesDataset(
file_pattern=data_config.data_root,
relations=relations,
table_sizes=table_sizes,
global_negatives=global_negatives,
in_batch_negatives=in_batch_negatives,
batch_size=pos_batch_size,
)

View file

@ -0,0 +1,169 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple
from tml.common.batch import DataclassBatch
from tml.reader.dataset import Dataset
from tml.projects.twhin.models.config import Relation
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@dataclass
class EdgeBatch(DataclassBatch):
nodes: KeyedJaggedTensor
labels: torch.Tensor
rels: torch.Tensor
weights: torch.Tensor
class EdgesDataset(Dataset):
rng = np.random.default_rng()
def __init__(
self,
file_pattern: str,
table_sizes: Dict[str, int],
relations: List[Relation],
lhs_column_name: str = "lhs",
rhs_column_name: str = "rhs",
rel_column_name: str = "rel",
**dataset_kwargs
):
self.batch_size = dataset_kwargs["batch_size"]
self.table_sizes = table_sizes
self.num_tables = len(table_sizes)
self.table_names = list(table_sizes.keys())
self.relations = relations
self.relations_t = torch.tensor(
[
[self.table_names.index(relation.lhs), self.table_names.index(relation.rhs)]
for relation in relations
]
)
self.lhs_column_name = lhs_column_name
self.rhs_column_name = rhs_column_name
self.rel_column_name = rel_column_name
self.label_column_name = "label"
super().__init__(file_pattern=file_pattern, **dataset_kwargs)
def pa_to_batch(self, batch: pa.RecordBatch):
lhs = torch.from_numpy(batch.column(self.lhs_column_name).to_numpy())
rhs = torch.from_numpy(batch.column(self.rhs_column_name).to_numpy())
rel = torch.from_numpy(batch.column(self.rel_column_name).to_numpy())
label = torch.from_numpy(batch.column(self.label_column_name).to_numpy())
nodes = self._to_kjt(lhs, rhs, rel)
return EdgeBatch(
nodes=nodes,
rels=rel,
labels=label,
weights=torch.ones(batch.num_rows),
)
def _to_kjt(
self, lhs: torch.Tensor, rhs: torch.Tensor, rel: torch.Tensor
) -> Tuple[KeyedJaggedTensor, List[Tuple[int, int]]]:
"""Process edges that contain lhs index, rhs index, relation index.
Example:
```
tables = ["f0", "f1", "f2", "f3"]
relations = [["f0", "f1"], ["f1", "f2"], ["f1", "f0"], ["f2", "f1"], ["f0", "f2"]]
self.relations_t = torch.Tensor([[0, 1], [1, 2], [1, 0], [2, 1], [0, 2]])
lhs = [1, 6, 3, 1, 8]
rhs = [6, 3, 4, 4, 9]
rel = [0, 2, 1, 3, 4]
This corresponds to the following "edges":
edges = [
{"lhs": 1, "rhs": 6, "relation": ["f0", "f1"]},
{"lhs": 6, "rhs": 3, "relation": ["f1", "f0"]},
{"lhs": 3, "rhs": 4, "relation": ["f1", "f2"]},
{"lhs": 1, "rhs": 4, "relation": ["f2", "f1"]},
{"lhs": 8, "rhs": 9, "relation": ["f0", "f2"]},
]
```
Returns a KeyedJaggedTensor used to look up all embeddings.
Note: We treat the lhs and rhs as though they're separate lookups: `len(lenghts) == 2 * bsz * len(tables)`.
This differs from the DLRM pattern where we have `len(lengths) = bsz * len(tables)`.
For the example above:
```
lookups = tensor([
[0., 1.],
[1., 6.],
[1., 6.],
[0., 3.],
[1., 3.],
[2., 4.],
[2., 1.],
[1., 4.],
[0., 8.],
[2., 9.]
])
kjt = KeyedJaggedTensor(
features=["f0", "f1", "f2"]
values=[
1, 3, 8, # f0
6, 6, 3, 4, # f1
4, 1, 9 # f2
]
lengths=[
1, 0, 0, 1, 0, 0, 0, 0, 1, 0, # f0
0, 1, 1, 0, 1, 0, 0, 1, 0, 0, # f1
0, 0, 0, 0, 0, 1, 1, 0, 0, 1, # f2
)
```
Note:
- values = [values for f0] + [values for f1] + [values for f2]
- lengths are always 0 or 1, and sum(lengths) = len(values) = 2 * bsz
"""
lookups = torch.concat((lhs[:, None], self.relations_t[rel], rhs[:, None]), dim=1)
index = torch.LongTensor([1, 0, 2, 3])
lookups = lookups[:, index].reshape(2 * self.batch_size, 2)
# values is just the row indices into each table, ordered by the table indices
_, indices = torch.sort(lookups[:, 0], dim=0, stable=True)
values = lookups[indices][:, 1].int()
# lengths[table_idx * batch_size + i] == whether the ith lookup is for the table with index table_idx
lengths = torch.arange(self.num_tables)[:, None].eq(lookups[:, 0])
lengths = lengths.reshape(-1).int()
return KeyedJaggedTensor(keys=self.table_names, values=values, lengths=lengths)
def to_batches(self):
ds = super().to_batches()
batch_size = self._dataset_kwargs["batch_size"]
names = [
self.lhs_column_name,
self.rhs_column_name,
self.rel_column_name,
self.label_column_name,
]
for _, batch in enumerate(ds):
# Pass along positive edges
lhs = batch.column(self.lhs_column_name)
rhs = batch.column(self.rhs_column_name)
rel = batch.column(self.rel_column_name)
label = pa.array(np.ones(batch_size, dtype=np.int64))
yield pa.RecordBatch.from_arrays(
arrays=[lhs, rhs, rel, label],
names=names,
)

View file

@ -0,0 +1,6 @@
import pytest
from unittest.mock import Mock
def test_create_dataset():
pass

View file

@ -0,0 +1,66 @@
"""Tests edges dataset functionality."""
from unittest.mock import patch
import os
import tempfile
from tml.projects.twhin.data.edges import EdgesDataset
from tml.projects.twhin.models.config import Relation
from fsspec.implementations.local import LocalFileSystem
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
TABLE_SIZES = {"user": 16, "author": 32}
RELATIONS = [
Relation(name="fav", lhs="user", rhs="author"),
Relation(name="engaged_with_reply", lhs="author", rhs="user"),
]
def test_gen():
import os
import tempfile
from fsspec.implementations.local import LocalFileSystem
import pyarrow as pa
import pyarrow.parquet as pq
lhs = pa.array(np.arange(4))
rhs = pa.array(np.flip(np.arange(4)))
rel = pa.array([0, 1, 0, 0])
names = ["lhs", "rhs", "rel"]
with tempfile.TemporaryDirectory() as tmpdir:
table = pa.Table.from_arrays([lhs, rhs, rel], names=names)
writer = pq.ParquetWriter(
os.path.join(tmpdir, "example.parquet"),
table.schema,
)
writer.write_table(table)
writer.close()
ds = EdgesDataset(
file_pattern=os.path.join(tmpdir, "*"),
table_sizes=TABLE_SIZES,
relations=RELATIONS,
batch_size=4,
)
ds.FS = LocalFileSystem()
dl = ds.dataloader()
batch = next(iter(dl))
# labels should be positive
labels = batch.labels
assert (labels[:4] == 1).sum() == 4
# make sure positive examples are what we expect
kjt_values = batch.nodes.values()
users, authors = torch.split(kjt_values, 4, dim=0)
assert torch.equal(users[:4], torch.tensor([0, 2, 2, 3]))
assert torch.equal(authors[:4], torch.tensor([3, 1, 1, 0]))

View file

@ -0,0 +1,12 @@
chief: &gpu
mem: 1.4Ti
cpu: 24
num_accelerators: 16
accelerator_type: a100
dataset_dispatcher:
mem: 2Gi
cpu: 2
num_dataset_workers: 4
dataset_worker:
mem: 14Gi
cpu: 2

17
projects/twhin/metrics.py Normal file
View file

@ -0,0 +1,17 @@
import torch
import torchmetrics as tm
import tml.core.metrics as core_metrics
def create_metrics(
device: torch.device,
):
metrics = dict()
metrics.update(
{
"AUC": core_metrics.Auc(128),
}
)
metrics = tm.MetricCollection(metrics).to(device)
return metrics

View file

@ -0,0 +1,54 @@
import typing
import enum
from tml.common.modules.embedding.config import LargeEmbeddingsConfig
from tml.core.config import base_config
from tml.optimizers.config import OptimizerConfig
import pydantic
from pydantic import validator
class TwhinEmbeddingsConfig(LargeEmbeddingsConfig):
@validator("tables")
def embedding_dims_match(cls, tables):
embedding_dim = tables[0].embedding_dim
data_type = tables[0].data_type
for table in tables:
assert table.embedding_dim == embedding_dim, "Embedding dimensions for all nodes must match."
assert table.data_type == data_type, "Data types for all nodes must match."
return tables
class Operator(str, enum.Enum):
TRANSLATION = "translation"
class Relation(pydantic.BaseModel):
"""graph relationship properties and operator"""
name: str = pydantic.Field(..., description="Relationship name.")
lhs: str = pydantic.Field(
...,
description="Name of the entity on the left-hand-side of this relation. Must match a table name.",
)
rhs: str = pydantic.Field(
...,
description="Name of the entity on the right-hand-side of this relation. Must match a table name.",
)
operator: Operator = pydantic.Field(
Operator.TRANSLATION, description="Transformation to apply to lhs embedding before dot product."
)
class TwhinModelConfig(base_config.BaseConfig):
embeddings: TwhinEmbeddingsConfig
relations: typing.List[Relation]
translation_optimizer: OptimizerConfig
@validator("relations", each_item=True)
def valid_node_types(cls, relation, values, **kwargs):
table_names = [table.name for table in values["embeddings"].tables]
assert relation.lhs in table_names, f"Invalid lhs node type: {relation.lhs}"
assert relation.rhs in table_names, f"Invalid rhs node type: {relation.rhs}"
return relation

View file

@ -0,0 +1,172 @@
from typing import Callable
import math
from tml.projects.twhin.data.edges import EdgeBatch
from tml.projects.twhin.models.config import TwhinModelConfig
from tml.projects.twhin.data.config import TwhinDataConfig
from tml.common.modules.embedding.embedding import LargeEmbeddings
from tml.optimizers.optimizer import get_optimizer_class
from tml.optimizers.config import get_optimizer_algorithm_config
import torch
from torch import nn
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
class TwhinModel(nn.Module):
def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig):
super().__init__()
self.batch_size = data_config.per_replica_batch_size
self.table_names = [table.name for table in model_config.embeddings.tables]
self.large_embeddings = LargeEmbeddings(model_config.embeddings)
self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
self.num_tables = len(model_config.embeddings.tables)
self.in_batch_negatives = data_config.in_batch_negatives
self.global_negatives = data_config.global_negatives
self.num_relations = len(model_config.relations)
# one bias per relation
self.all_trans_embs = torch.nn.parameter.Parameter(
torch.nn.init.uniform_(torch.empty(self.num_relations, self.embedding_dim))
)
def forward(self, batch: EdgeBatch):
# B x D
trans_embs = self.all_trans_embs.data[batch.rels]
# KeyedTensor
outs = self.large_embeddings(batch.nodes)
# 2B x TD
x = outs.values()
# 2B x T x D
x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)
# 2B x D
x = torch.sum(x, 1)
# B x 2 x D
x = x.reshape(self.batch_size, 2, self.embedding_dim)
# translated
translated = x[:, 1, :] + trans_embs
negs = []
if self.in_batch_negatives:
# construct dot products for negatives via matmul
for relation in range(self.num_relations):
rel_mask = batch.rels == relation
rel_count = rel_mask.sum()
if not rel_count:
continue
# R x D
lhs_matrix = x[rel_mask, 0, :]
rhs_matrix = x[rel_mask, 1, :]
lhs_perm = torch.randperm(lhs_matrix.shape[0])
# repeat until we have enough negatives
lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
lhs_indices = lhs_perm[: self.in_batch_negatives]
sampled_lhs = lhs_matrix[lhs_indices]
rhs_perm = torch.randperm(rhs_matrix.shape[0])
# repeat until we have enough negatives
rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
rhs_indices = rhs_perm[: self.in_batch_negatives]
sampled_rhs = rhs_matrix[rhs_indices]
# RS
negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))
negs.append(negs_lhs)
negs.append(negs_rhs)
# dot product for positives
x = (x[:, 0, :] * translated).sum(-1)
# concat positives and negatives
x = torch.cat([x, *negs])
return {
"logits": x,
"probabilities": torch.sigmoid(x),
}
def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig):
for table in model_config.embeddings.tables:
optimizer_class = get_optimizer_class(table.optimizer)
optimizer_kwargs = get_optimizer_algorithm_config(table.optimizer).dict()
params = [
param
for name, param in model.large_embeddings.ebc.named_parameters()
if (name.startswith(f"embedding_bags.{table.name}"))
]
apply_optimizer_in_backward(
optimizer_class=optimizer_class,
params=params,
optimizer_kwargs=optimizer_kwargs,
)
return model
class TwhinModelAndLoss(torch.nn.Module):
def __init__(
self,
model,
loss_fn: Callable,
data_config: TwhinDataConfig,
device: torch.device,
) -> None:
"""
Args:
model: torch module to wrap.
loss_fn: Function for calculating loss, should accept logits and labels.
"""
super().__init__()
self.model = model
self.loss_fn = loss_fn
self.batch_size = data_config.per_replica_batch_size
self.in_batch_negatives = data_config.in_batch_negatives
self.device = device
def forward(self, batch: "RecapBatch"): # type: ignore[name-defined]
"""Runs model forward and calculates loss according to given loss_fn.
NOTE: The input signature here needs to be a Pipelineable object for
prefetching purposes during training using torchrec's pipeline. However
the underlying model signature needs to be exportable to onnx, requiring
generic python types. see https://pytorch.org/docs/stable/onnx.html#types.
"""
outputs = self.model(batch)
logits = outputs["logits"]
num_negatives = 2 * self.batch_size * self.in_batch_negatives
num_positives = self.batch_size
neg_weight = float(num_positives) / num_negatives
labels = torch.cat([batch.labels.float(), torch.ones(num_negatives).to(self.device)])
weights = torch.cat(
[batch.weights.float(), (torch.ones(num_negatives) * neg_weight).to(self.device)]
)
losses = self.loss_fn(logits, labels, weights)
outputs.update(
{
"loss": losses,
"labels": labels,
"weights": weights,
}
)
# Allow multiple losses.
return losses, outputs

Some files were not shown because too many files have changed in this diff Show more