twitter-algorithm-ml/core/metrics.py

149 lines
4.9 KiB
Python

"""Common metrics that also support multi task.
We assume multi task models will output [task_idx, ...] predictions
"""
from typing import Any, Dict
from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin
import torch
import torchmetrics as tm
def probs_and_labels(
    outputs: Dict[str, torch.Tensor],
    task_idx: int,
) -> Dict[str, torch.Tensor]:
    """Extract the ``preds``/``target`` pair from a model output dict.

    Args:
        outputs: Mapping that must contain the keys ``"probabilities"`` and
            ``"labels"``.
        task_idx: Column to select from both tensors via ``[:, task_idx]``.
            A negative value disables the selection and the tensors are
            used as they are.

    Returns:
        A dict with ``"preds"`` (the probabilities) and ``"target"`` (the
        labels converted with ``.int()``).
    """
    probabilities, labels = outputs["probabilities"], outputs["labels"]

    # Negative index: no per-task column selection.
    if task_idx < 0:
        return {"preds": probabilities, "target": labels.int()}

    return {
        "preds": probabilities[:, task_idx],
        "target": labels[:, task_idx].int(),
    }
class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric):
    """Sum metric fed with the label values (optionally a single task column)."""

    def transform(self, outputs):
        """Build the ``value`` argument from ``outputs["labels"]``.

        Stratification is applied to the ``"labels"`` entry first; when
        ``self._task_idx`` is non-negative only that column is kept.
        Presumably the returned dict is forwarded to ``SumMetric.update`` by
        ``MetricMixin`` -- confirm against the mixin.
        """
        stratified = self.maybe_apply_stratification(outputs, ["labels"])
        labels = stratified["labels"]
        # A negative task index means "use the labels as they are".
        return {"value": labels if self._task_idx < 0 else labels[:, self._task_idx]}
class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    """Mean metric fed with the label values (optionally a single task column)."""

    def transform(self, outputs):
        """Build the ``value`` argument from ``outputs["labels"]``.

        Stratification is applied to the ``"labels"`` entry first; when
        ``self._task_idx`` is non-negative only that column is kept.
        Presumably the returned dict is forwarded to ``MeanMetric.update`` by
        ``MetricMixin`` -- confirm against the mixin.
        """
        stratified = self.maybe_apply_stratification(outputs, ["labels"])
        labels = stratified["labels"]
        # A negative task index means "use the labels as they are".
        return {"value": labels if self._task_idx < 0 else labels[:, self._task_idx]}
class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    """Mean metric fed with the predicted probabilities (optionally one task column)."""

    def transform(self, outputs):
        """Build the ``value`` argument from ``outputs["probabilities"]``.

        Stratification is applied to the ``"probabilities"`` entry first; when
        ``self._task_idx`` is non-negative only that column is kept.
        Presumably the returned dict is forwarded to ``MeanMetric.update`` by
        ``MetricMixin`` -- confirm against the mixin.
        """
        stratified = self.maybe_apply_stratification(outputs, ["probabilities"])
        probabilities = stratified["probabilities"]
        # A negative task index means "use the probabilities as they are".
        if self._task_idx < 0:
            return {"value": probabilities}
        return {"value": probabilities[:, self._task_idx]}
class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision):
    """``torchmetrics`` precision fed with this module's probabilities and labels."""

    def transform(self, outputs):
        """Stratify probabilities and labels, then map them to ``preds``/``target``."""
        stratified_keys = ["probabilities", "labels"]
        return probs_and_labels(
            self.maybe_apply_stratification(outputs, stratified_keys),
            self._task_idx,
        )
class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall):
    """``torchmetrics`` recall fed with this module's probabilities and labels."""

    def transform(self, outputs):
        """Stratify probabilities and labels, then map them to ``preds``/``target``."""
        stratified_keys = ["probabilities", "labels"]
        return probs_and_labels(
            self.maybe_apply_stratification(outputs, stratified_keys),
            self._task_idx,
        )
class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC):
    """``torchmetrics`` AUROC fed with this module's probabilities and labels."""

    def transform(self, outputs):
        """Stratify probabilities and labels, then map them to ``preds``/``target``."""
        stratified_keys = ["probabilities", "labels"]
        return probs_and_labels(
            self.maybe_apply_stratification(outputs, stratified_keys),
            self._task_idx,
        )
class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    """
    Sampled AUC estimate: compares randomly drawn positive scores against
    randomly drawn negative scores and averages how often the positive wins.

    Based on:
    https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/util.py#L420

    NOTE(review): ``transform`` does not call ``maybe_apply_stratification``
    and does not read ``self._task_idx``, unlike the other metrics in this
    file -- confirm that this is intended.
    """

    def __init__(self, num_samples, **kwargs):
        """
        Args:
            num_samples: Number of (positive, negative) score pairs drawn per
                call to ``transform``.
            **kwargs: Forwarded unchanged to the parent constructors.
        """
        super().__init__(**kwargs)
        self.num_samples = num_samples

    def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Draw ``num_samples`` positive/negative pairs and compare their scores.

        Args:
            outputs: Mapping that must contain ``"logits"`` and ``"labels"``.
                Entries whose label equals 1 are positives, entries whose
                label equals 0 are negatives.

        Returns:
            ``{"value": t}`` where ``t`` is a boolean tensor with
            ``num_samples`` entries, true where the sampled positive score is
            strictly greater than the sampled negative score. When there are
            no positives or no negatives, ``t`` is an empty boolean tensor.
        """
        scores, labels = outputs["logits"], outputs["labels"]
        pos_scores = scores[labels == 1]
        neg_scores = scores[labels == 0]
        num_pos, num_neg = len(pos_scores), len(neg_scores)

        # torch.randint raises when its upper bound is 0, so a batch without
        # positives or without negatives used to crash here. No pair can be
        # formed in that case; contribute nothing instead.
        if num_pos == 0 or num_neg == 0:
            return {"value": torch.empty(0, dtype=torch.bool, device=scores.device)}

        # Sampling with replacement, independently for both sides.
        pos_idx = torch.randint(num_pos, (self.num_samples,))
        neg_idx = torch.randint(num_neg, (self.num_samples,))
        return {"value": pos_scores[pos_idx] > neg_scores[neg_idx]}
class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    """
    The ranks of all positives, where rank 1 is the highest score.

    Based on:
    https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L73
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Return the 1-based rank of every entry with a non-zero label.

        Reads ``outputs["logits"]`` and ``outputs["labels"]``; labels are
        reordered by descending logit before the non-zero positions are taken.
        Assumes both tensors are 1-D and aligned -- TODO confirm with callers.
        """
        logits = outputs["logits"]
        labels = outputs["labels"]
        descending_order = logits.sort(descending=True).indices
        labels_by_score = labels[descending_order]
        zero_based_positions = labels_by_score.nonzero(as_tuple=True)[0]
        # Positions are 0-based; ranks are reported starting from 1.
        return {"value": zero_based_positions + 1}
class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    """
    The reciprocal of the ranks of all positives, where rank 1 is the highest score.

    Based on:
    https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L74
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Return ``1 / rank`` for every entry with a non-zero label.

        Reads ``outputs["logits"]`` and ``outputs["labels"]``; labels are
        reordered by descending logit before the non-zero positions are taken.
        Assumes both tensors are 1-D and aligned -- TODO confirm with callers.
        """
        logits = outputs["logits"]
        labels = outputs["labels"]
        descending_order = logits.sort(descending=True).indices
        labels_by_score = labels[descending_order]
        # Positions are 0-based; ranks are reported starting from 1.
        ranks = labels_by_score.nonzero(as_tuple=True)[0] + 1
        numerator = torch.ones_like(ranks)
        return {"value": torch.div(numerator, ranks)}
class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric):
    """
    The fraction of positives that rank in the top K among their negatives.
    Note that this is basically precision@k.

    Based on:
    https://github.com/facebookresearch/PyTorch-BigGraph/blob/a11ff0eb644b7e4cb569067c280112b47f40ef62/torchbiggraph/eval.py#L75
    """

    def __init__(self, k: int, **kwargs):
        """
        Args:
            k: Rank cut-off; a positive counts as a hit when its 1-based rank
                is at most ``k``.
            **kwargs: Forwarded unchanged to the parent constructors.
        """
        super().__init__(**kwargs)
        self.k = k

    def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Return 1.0 for each positive ranked within the top ``k``, else 0.0.

        Reads ``outputs["logits"]`` and ``outputs["labels"]``; labels are
        reordered by descending logit before the non-zero positions are taken.
        Assumes both tensors are 1-D and aligned -- TODO confirm with callers.
        """
        logits = outputs["logits"]
        labels = outputs["labels"]
        descending_order = logits.sort(descending=True).indices
        labels_by_score = labels[descending_order]
        # Positions are 0-based; ranks are reported starting from 1.
        ranks = labels_by_score.nonzero(as_tuple=True)[0] + 1
        within_top_k = ranks <= self.k
        return {"value": within_top_k.float()}