twitter-algorithm-ml/common/log_weights.py

91 lines
3.3 KiB
Python

"""For logging model weights."""
import itertools
from typing import Callable, Dict, List, Optional, Union
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DistributedModelParallel
def weights_to_log(
model: torch.nn.Module,
how_to_log: Optional[Union[Callable, Dict[str, Callable]]] = None,
):
"""Creates dict of reduced weights to log to give sense of training.
Args:
model: model to traverse.
how_to_log: if a function, then applies this to every parameter, if a dict
then only applies and logs specified parameters.
"""
if not how_to_log:
return
to_log = dict()
named_parameters = model.named_parameters()
logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
if isinstance(model, DistributedModelParallel):
named_parameters = itertools.chain(
named_parameters, model._dmp_wrapped_module.named_parameters()
)
logging.info(
f"Using dmp parameters: {list(name for name, _ in model._dmp_wrapped_module.named_parameters())}"
)
for param_name, params in named_parameters:
if callable(how_to_log):
how = how_to_log
else:
how = how_to_log.get(param_name) # type: ignore[assignment]
if not how:
continue # type: ignore
to_log[f"model/{how.__name__}/{param_name}"] = how(params.detach()).cpu().numpy()
return to_log
def log_ebc_norms(
model_state_dict,
ebc_keys: List[str],
sample_size: int = 4_000_000,
) -> Dict[str, torch.Tensor]:
"""Logs the norms of the embedding tables as specified by ebc_keys.
As of now, log average norm per rank.
Args:
model_state_dict: model.state_dict()
ebc_keys: list of embedding keys from state_dict to log. Must contain full name,
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
norm_logs = dict()
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
if emb_key in model_state_dict:
emb_weight = model_state_dict[emb_key]
try:
emb_weight_tensor = emb_weight.local_tensor()
except AttributeError as e:
logging.info(e)
emb_weight_tensor = emb_weight
logging.info("Running Tensor.detach()")
emb_weight_tensor = emb_weight_tensor.detach()
sample_mask = torch.randperm(emb_weight_tensor.shape[0])[
: min(sample_size, emb_weight_tensor.shape[0])
]
# WARNING: .cpu() transfer executes malloc that may be the cause of memory leaks
# Change sample_size if the you observe frequent OOM errors or remove weight logging.
norms = emb_weight_tensor[sample_mask].cpu().norm(dim=1).to(torch.float32)
logging.info(f"Norm shape before reduction: {norms.shape}", rank=-1)
norms = norms.mean().to(torch.device(f"cuda:{dist.get_rank()}"))
all_norms = [
torch.zeros(1, dtype=norms.dtype).to(norms.device) for _ in range(dist.get_world_size())
]
dist.all_gather(all_norms, norms)
for idx, norm in enumerate(all_norms):
if norm != -1.0:
norm_logs[f"{emb_key}-norm-{idx}"] = norm
logging.info(f"Norm Logs are {norm_logs}")
return norm_logs