110 lines
2.9 KiB
Python
110 lines
2.9 KiB
Python
|
"""Loss functions -- including multi task ones."""
|
||
|
|
||
|
import typing
|
||
|
|
||
|
from tml.core.loss_type import LossType
|
||
|
from tml.ml_logging.torch_logging import logging
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def _maybe_warn(reduction: str):
|
||
|
"""
|
||
|
Warning for reduction different than mean.
|
||
|
"""
|
||
|
if reduction != "mean":
|
||
|
logging.warn(
|
||
|
f"For the same global_batch_size, the gradient in DDP is guaranteed to be equal,"
|
||
|
f"to the gradient without DDP only for mean reduction. If you need this property for"
|
||
|
f"the provided reduction {reduction}, it needs to be implemented."
|
||
|
)
|
||
|
|
||
|
|
||
|
def build_loss(
|
||
|
loss_type: LossType,
|
||
|
reduction="mean",
|
||
|
):
|
||
|
_maybe_warn(reduction)
|
||
|
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
||
|
|
||
|
def loss_fn(logits, labels):
|
||
|
return f(logits, labels.type_as(logits), reduction=reduction)
|
||
|
|
||
|
return loss_fn
|
||
|
|
||
|
|
||
|
def get_global_loss_detached(local_loss, reduction="mean"):
|
||
|
"""
|
||
|
Perform all_reduce to obtain the global loss function using the provided reduction.
|
||
|
:param local_loss: The local loss of the current rank.
|
||
|
:param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
|
||
|
:return: The reduced & detached global loss.
|
||
|
"""
|
||
|
if reduction != "mean":
|
||
|
logging.warn(
|
||
|
f"The reduction used in this function should be the same as the one used by "
|
||
|
f"the DDP model. By default DDP uses mean, So ensure that DDP is appropriately"
|
||
|
f"modified for reduction {reduction}."
|
||
|
)
|
||
|
|
||
|
if reduction not in ["mean", "sum"]:
|
||
|
raise ValueError(f"Reduction {reduction} is currently unsupported.")
|
||
|
|
||
|
global_loss = local_loss.detach()
|
||
|
|
||
|
if reduction == "mean":
|
||
|
global_loss.div_(torch.distributed.get_world_size())
|
||
|
|
||
|
torch.distributed.all_reduce(global_loss)
|
||
|
return global_loss
|
||
|
|
||
|
|
||
|
def build_multi_task_loss(
|
||
|
loss_type: LossType,
|
||
|
tasks: typing.List[str],
|
||
|
task_loss_reduction="mean",
|
||
|
global_reduction="mean",
|
||
|
pos_weights=None,
|
||
|
):
|
||
|
_maybe_warn(global_reduction)
|
||
|
_maybe_warn(task_loss_reduction)
|
||
|
f = _LOSS_TYPE_TO_FUNCTION[loss_type]
|
||
|
|
||
|
loss_reduction_fns = {
|
||
|
"mean": torch.mean,
|
||
|
"sum": torch.sum,
|
||
|
"min": torch.min,
|
||
|
"max": torch.max,
|
||
|
"median": torch.median,
|
||
|
}
|
||
|
|
||
|
def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
|
||
|
if pos_weights is None:
|
||
|
torch_weights = torch.ones([len(tasks)])
|
||
|
else:
|
||
|
torch_weights = torch.tensor(pos_weights)
|
||
|
|
||
|
losses = {}
|
||
|
for task_idx, task in enumerate(tasks):
|
||
|
task_logits = logits[:, task_idx]
|
||
|
label = labels[:, task_idx].type_as(task_logits)
|
||
|
|
||
|
loss = f(
|
||
|
task_logits,
|
||
|
label,
|
||
|
reduction=task_loss_reduction,
|
||
|
pos_weight=torch_weights[task_idx],
|
||
|
weight=weights[:, task_idx],
|
||
|
)
|
||
|
losses[f"loss/{task}"] = loss
|
||
|
|
||
|
losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
|
||
|
return losses
|
||
|
|
||
|
return loss_fn
|
||
|
|
||
|
|
||
|
_LOSS_TYPE_TO_FUNCTION = {
|
||
|
LossType.BCE_WITH_LOGITS: torch.nn.functional.binary_cross_entropy_with_logits
|
||
|
}
|