twitter-algorithm-ml/projects/home/recap/model/entrypoint.py

305 lines
11 KiB
Python

from __future__ import annotations
from absl import logging
import torch
from typing import Optional, Callable, Mapping, Dict, Sequence, TYPE_CHECKING
from tml.projects.home.recap.model import feature_transform
from tml.projects.home.recap.model import config as model_config_mod
from tml.projects.home.recap.model import mlp
from tml.projects.home.recap.model import mask_net
from tml.projects.home.recap.model import numeric_calibration
from tml.projects.home.recap.model.model_and_loss import ModelAndLoss
import tml.projects.home.recap.model.config as model_config_mod
if TYPE_CHECKING:
from tml.projects.home.recap import config as config_mod
from tml.projects.home.recap.data.config import RecapDataConfig
from tml.projects.home.recap.model.config import ModelConfig
def sanitize(task_name: str) -> str:
    """Return ``task_name`` with every "." replaced by "__".

    In this file the result is used as a key for ``torch.nn.ModuleDict``
    containers, which do not accept "." inside key names.
    """
    dot_separated_parts = task_name.split(".")
    return "__".join(dot_separated_parts)
def unsanitize(sanitized_task_name: str) -> str:
    """Return ``sanitized_task_name`` with every "__" replaced by ".".

    Occurrences are matched left to right without overlap, so a run of three
    underscores becomes "." followed by a single "_".
    """
    underscore_separated_parts = sanitized_task_name.split("__")
    return ".".join(underscore_separated_parts)
def _build_single_task_model(task: model_config_mod.TaskModel, input_shape: int):
""" "Builds a model for a single task"""
if task.mlp_config:
return mlp.Mlp(in_features=input_shape, mlp_config=task.mlp_config)
elif task.dcn_config:
return dcn.Dcn(dcn_config=task.dcn_config, in_features=input_shape)
elif task.mask_net_config:
return mask_net.MaskNet(mask_net_config=task.mask_net_config, in_features=input_shape)
else:
raise ValueError("This should never be reached.")
class MultiTaskRankingModel(torch.nn.Module):
    """Multi-task ranking model.

    The preprocessed continuous/binary features are concatenated with whichever
    optional inputs are configured (large and small embeddings, layer-normed
    user / user-engagement / author embeddings, position-debias features).
    Depending on ``config.multi_task_type`` the result is optionally passed
    through a shared backbone, then fed to one tower per task. Each tower
    output goes through a per-task affine map, a sigmoid and a calibrator.
    """

    def __init__(
        self,
        input_shapes: Mapping[str, torch.Size],
        config: ModelConfig,
        data_config: RecapDataConfig,
        return_backbone: bool = False,
    ):
        """Constructor for Multi task learning.

        Assumptions made:
          1. Tasks specified in data config match model architecture.

        These are all validated in config.

        Args:
          input_shapes: Shapes keyed by feature name. "continuous" and "binary"
            are always read; "user_embedding", "user_eng_embedding" and
            "author_embedding" are read when the matching feature is listed in
            ``data_config.seg_dense_schema.renamed_features``. Only the last
            dimension of each shape is used.
          config: Model config (featurization, embeddings, position debias,
            multi-task type, backbone and per-task architectures).
          data_config: Data config; supplies the dense schema and the per-task
            positive/negative downsampling rates used by the calibrators.
          return_backbone: If True, ``forward`` also returns the tensor that is
            fed to the task towers under the "backbone" key.

        Raises:
          ValueError: If ``config.multi_task_type`` is not one of SHARE_NONE,
            SHARE_ALL or SHARE_PARTIAL (only checked when there is at least one task).
        """
        super().__init__()

        self._config = config
        self._data_config = data_config
        self._preprocessor = feature_transform.build_features_preprocessor(
            config.featurization_config, input_shapes
        )
        self.return_backbone = return_backbone

        # Optional embedding modules; they stay None unless configured below.
        self.embeddings = None
        self.small_embeddings = None
        embedding_dims = 0
        if config.large_embeddings:
            # Imported here, so the module is only required when large embeddings are configured.
            from large_embeddings.models.learnable_embeddings import LargeEmbeddings

            self.embeddings = LargeEmbeddings(large_embeddings_config=config.large_embeddings)
            embedding_dims += sum(table.embedding_dim for table in config.large_embeddings.tables)
            logging.info(f"Emb dim: {embedding_dims}")

        if config.small_embeddings:
            # NOTE(review): `SmallEmbedding` is not imported in the visible part of this file -- confirm.
            self.small_embeddings = SmallEmbedding(config.small_embeddings)
            embedding_dims += sum(table.embedding_dim for table in config.small_embeddings.tables)
            logging.info(f"Emb dim (with small embeddings): {embedding_dims}")

        renamed_features = data_config.seg_dense_schema.renamed_features

        # Each optional dense embedding gets its own LayerNorm over its last dimension
        # and contributes that dimension to the model input width.
        if "user_embedding" in renamed_features:
            embedding_dims += input_shapes["user_embedding"][-1]
            self._user_embedding_layer_norm = torch.nn.LayerNorm(input_shapes["user_embedding"][-1])
        else:
            self._user_embedding_layer_norm = None

        if "user_eng_embedding" in renamed_features:
            embedding_dims += input_shapes["user_eng_embedding"][-1]
            self._user_eng_embedding_layer_norm = torch.nn.LayerNorm(
                input_shapes["user_eng_embedding"][-1]
            )
        else:
            self._user_eng_embedding_layer_norm = None

        if "author_embedding" in renamed_features:
            embedding_dims += input_shapes["author_embedding"][-1]
            self._author_embedding_layer_norm = torch.nn.LayerNorm(
                input_shapes["author_embedding"][-1]
            )
        else:
            self._author_embedding_layer_norm = None

        input_dims = input_shapes["continuous"][-1] + input_shapes["binary"][-1] + embedding_dims

        if config.position_debias_config:
            # NOTE(review): `PositionDebias` is not imported in the visible part of this file -- confirm.
            self.position_debias_model = PositionDebias(config.position_debias_config)
            input_dims += self.position_debias_model.out_features
        else:
            self.position_debias_model = None
        logging.info(f"input dim: {input_dims}")

        # A shared backbone only exists for the two sharing modes.
        if config.multi_task_type in [
            model_config_mod.MultiTaskType.SHARE_ALL,
            model_config_mod.MultiTaskType.SHARE_PARTIAL,
        ]:
            self._backbone = _build_single_task_model(config.backbone, input_dims)
        else:
            self._backbone = None

        _towers: Dict[str, torch.nn.Module] = {}
        _calibrators: Dict[str, torch.nn.Module] = {}
        _affine_maps: Dict[str, torch.nn.Module] = {}

        for task_name, task_architecture in config.tasks.items():
            safe_name = sanitize(task_name)

            # Complex input dimension calculation.
            if config.multi_task_type == model_config_mod.MultiTaskType.SHARE_NONE:
                # Towers see the raw concatenated features.
                num_inputs = input_dims
            elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
                # Towers see only the backbone output.
                num_inputs = self._backbone.out_features
            elif config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
                # Towers see the raw features concatenated with the backbone output.
                num_inputs = input_dims + self._backbone.out_features
            else:
                raise ValueError("Unreachable branch of enum.")

            # Annoyingly, ModuleDict doesn't allow . inside key names.
            _towers[safe_name] = _build_single_task_model(task_architecture, num_inputs)

            if task_architecture.affine_map:
                # 1->1 linear layer initialised to the configured scale and bias.
                affine_map = torch.nn.Linear(1, 1)
                affine_map.weight.data = torch.tensor([[task_architecture.affine_map.scale]])
                affine_map.bias.data = torch.tensor([task_architecture.affine_map.bias])
                _affine_maps[safe_name] = affine_map
            else:
                _affine_maps[safe_name] = torch.nn.Identity()

            _calibrators[safe_name] = numeric_calibration.NumericCalibration(
                pos_downsampling_rate=data_config.tasks[task_name].pos_downsampling_rate,
                neg_downsampling_rate=data_config.tasks[task_name].neg_downsampling_rate,
            )

        # Unsanitized names, in config order; this fixes the task order of the stacked outputs.
        self._task_names = list(config.tasks.keys())
        self._towers = torch.nn.ModuleDict(_towers)
        self._affine_maps = torch.nn.ModuleDict(_affine_maps)
        self._calibrators = torch.nn.ModuleDict(_calibrators)

        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        # torch.autograd.Variable is deprecated; a tensor already has requires_grad=False.
        self._counter = torch.tensor(0)

    def forward(
        self,
        continuous_features: torch.Tensor,
        binary_features: torch.Tensor,
        discrete_features: Optional[torch.Tensor] = None,
        sparse_features=None,  # Optional[KeyedJaggedTensor]
        user_embedding: Optional[torch.Tensor] = None,
        user_eng_embedding: Optional[torch.Tensor] = None,
        author_embedding: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        weights: Optional[torch.Tensor] = None,
    ):
        """Run the model and return per-task logits and probabilities.

        Args:
          continuous_features: Passed to the feature preprocessor.
          binary_features: Passed to the feature preprocessor.
          discrete_features: Required when small embeddings or the position
            debias model are configured; a warning is logged if it is given
            but neither is configured.
          sparse_features: Passed to the large embeddings module when configured.
          user_embedding: Required when the user embedding layer norm exists.
          user_eng_embedding: Required when the user engagement embedding layer norm exists.
          author_embedding: Required when the author embedding layer norm exists.
          labels: Not used by this method.
          weights: Not used by this method.

        Returns:
          A dict with "logits" (after the affine map), "probabilities" and
          "calibrated_probabilities". Each stacks the per-task outputs along
          dim 1, in ``config.tasks`` order, and squeezes the last dimension.
          Presumably (batch, num_tasks) when every tower emits one value per
          example -- TODO confirm. When ``return_backbone`` is set, "backbone"
          holds the tensor that was fed to the towers.

        Raises:
          ValueError: If an input required by a configured component is None,
            or if a backbone exists for an unexpected multi-task type.
        """
        concat_dense_features = [
            self._preprocessor(
                continuous_features=continuous_features, binary_features=binary_features
            )
        ]

        # Optional submodules are compared against None explicitly: a module that
        # defines __len__ or __bool__ must not be skipped by a truthiness test.
        if self.embeddings is not None:
            concat_dense_features.append(self.embeddings(sparse_features))

        if self.small_embeddings is not None:
            if discrete_features is None:
                raise ValueError(
                    "Forward arg discrete_features is None, but since small_embeddings are used, a Tensor is expected."
                )
            concat_dense_features.append(self.small_embeddings(discrete_features))

        # Twhin embedding layer norms
        if self._user_embedding_layer_norm is not None:
            if user_embedding is None:
                raise ValueError(
                    "Forward arg user_embedding is None, but since Twhin user_embeddings are used by the model, a Tensor is expected."
                )
            concat_dense_features.append(self._user_embedding_layer_norm(user_embedding))

        if self._user_eng_embedding_layer_norm is not None:
            if user_eng_embedding is None:
                raise ValueError(
                    "Forward arg user_eng_embedding is None, but since Twhin user_eng_embeddings are used by the model, a Tensor is expected."
                )
            concat_dense_features.append(self._user_eng_embedding_layer_norm(user_eng_embedding))

        if self._author_embedding_layer_norm is not None:
            if author_embedding is None:
                raise ValueError(
                    "Forward arg author_embedding is None, but since Twhin author_embeddings are used by the model, a Tensor is expected."
                )
            concat_dense_features.append(self._author_embedding_layer_norm(author_embedding))

        if self.position_debias_model is not None:
            if discrete_features is None:
                raise ValueError(
                    "Forward arg discrete_features is None, but since position_debias_model is used, a Tensor is expected."
                )
            concat_dense_features.append(self.position_debias_model(discrete_features))

        if (
            discrete_features is not None
            and self.position_debias_model is None
            and self.small_embeddings is None
        ):
            logging.warning("Forward arg discrete_features is passed, but never used.")

        concat_dense_features = torch.cat(concat_dense_features, dim=1)

        if self._backbone is not None:
            if self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_ALL:
                net = self._backbone(concat_dense_features)["output"]
            elif self._config.multi_task_type == model_config_mod.MultiTaskType.SHARE_PARTIAL:
                net = torch.cat(
                    [concat_dense_features, self._backbone(concat_dense_features)["output"]], dim=1
                )
            else:
                # __init__ only builds a backbone for the two modes above; fail loudly
                # instead of leaving `net` unbound.
                raise ValueError("Unreachable branch of enum.")
        else:
            net = concat_dense_features

        backbone_result = net

        all_logits = []
        all_probabilities = []
        all_calibrated_probabilities = []

        for task_name in self._task_names:
            safe_name = sanitize(task_name)
            tower_outputs = self._towers[safe_name](net)
            logits = tower_outputs["output"]
            scaled_logits = self._affine_maps[safe_name](logits)
            probabilities = torch.sigmoid(scaled_logits)
            calibrated_probabilities = self._calibrators[safe_name](probabilities)

            # Note: the "logits" output holds the affine-mapped logits, not the raw tower output.
            all_logits.append(scaled_logits)
            all_probabilities.append(probabilities)
            all_calibrated_probabilities.append(calibrated_probabilities)

        results = {
            "logits": torch.squeeze(torch.stack(all_logits, dim=1), dim=-1),
            "probabilities": torch.squeeze(torch.stack(all_probabilities, dim=1), dim=-1),
            "calibrated_probabilities": torch.squeeze(
                torch.stack(all_calibrated_probabilities, dim=1), dim=-1
            ),
        }

        # Returning the backbone is intended for stitching post-tf conversion
        # Leaving this on will ~200x the size of the output
        # and could slow things down
        if self.return_backbone:
            results["backbone"] = backbone_result

        return results
def create_ranking_model(
    data_spec,
    # Used for planner to be batch size aware.
    config: config_mod.RecapConfig,
    device: torch.device,
    loss_fn: Optional[Callable] = None,
    data_config=None,
    return_backbone=False,
):
    """Build a ``MultiTaskRankingModel`` from ``config``, optionally wrapped with a loss.

    Args:
      data_spec: Passed to the model as ``input_shapes``.
      config: Top-level config. ``config.model`` configures the model, and
        ``config.train_data`` is the data config used when ``data_config`` is None.
      device: Not used by this function.
      loss_fn: If truthy, the model is wrapped in ``ModelAndLoss`` together with
        ``config.model.stratifiers``.
      data_config: Optional data config overriding ``config.train_data``.
      return_backbone: Forwarded to ``MultiTaskRankingModel``.

    Returns:
      The ``MultiTaskRankingModel``, or a ``ModelAndLoss`` wrapping it when
      ``loss_fn`` is given.

    Raises:
      NotImplementedError: If the first task in ``config.model.tasks`` has a
        truthy ``dlrm_config``.
      IndexError: If ``config.model.tasks`` is empty.
    """
    # Only the first task is inspected to decide whether the DLRM path was requested.
    first_task = list(config.model.tasks.values())[0]
    if first_task.dlrm_config:
        raise NotImplementedError("Ranking models built from dlrm_config are not implemented.")

    model = MultiTaskRankingModel(
        input_shapes=data_spec,
        config=config.model,
        data_config=data_config if data_config is not None else config.train_data,
        return_backbone=return_backbone,
    )

    logging.info("***** Model Architecture *****")
    logging.info(model)

    logging.info("***** Named Parameters *****")
    for param_name, _ in model.named_parameters():
        logging.info(param_name)

    if loss_fn:
        logging.info("***** Wrapping in loss *****")
        model = ModelAndLoss(
            model=model,
            loss_fn=loss_fn,
            stratifiers=config.model.stratifiers,
        )

    return model