105 lines
3 KiB
Python
105 lines
3 KiB
Python
"""
|
|
Mixin that requires a transform to munge output dictionary of tensors a
|
|
model produces to a form that the torchmetrics.Metric.update expects.
|
|
|
|
By unifying on our signature for `update`, we can also now use
|
|
torchmetrics.MetricCollection which requires all metrics have
|
|
the same call signature.
|
|
|
|
To use, override this with a transform that munges `outputs`
|
|
into a kwargs dict that the inherited metric.update accepts.
|
|
|
|
Here are two examples of how to extend torchmetrics.SumMetric so that it accepts
|
|
an output dictionary of tensors and munges it to what SumMetric expects (single `value`)
|
|
for its update method.
|
|
|
|
1. Using as a mixin to inherit from or define a new metric class.
|
|
|
|
class Count(MetricMixin, SumMetric):
|
|
def transform(self, outputs):
|
|
return {'value': 1}
|
|
|
|
2. Redefine an existing metric class.
|
|
|
|
SumMetric = prepend_transform(SumMetric, lambda outputs: {'value': 1})
|
|
|
|
"""
|
|
from abc import abstractmethod
|
|
from typing import Callable, Dict, List
|
|
|
|
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
|
|
|
|
import torch
|
|
import torchmetrics
|
|
|
|
|
|
class MetricMixin:
    """Adapter that lets a metric's ``update`` accept a dict of model outputs.

    Subclasses supply ``transform``; ``update`` runs it and forwards the
    resulting dict as keyword arguments to the next ``update`` in the MRO.
    """

    @abstractmethod
    def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
        """Map ``outputs`` to the keyword arguments for the inherited ``update``."""
        ...

    def update(self, outputs: Dict[str, torch.Tensor]):
        """Transform ``outputs`` and pass the result on to the inherited ``update``.

        The call is skipped entirely (returning ``None``) when any tensor in
        the transformed dict has zero elements.
        """
        update_kwargs = self.transform(outputs)
        # Do not try to update if any tensor is empty as a result of stratification.
        has_empty_tensor = any(
            torch.is_tensor(item) and item.nelement() == 0
            for item in update_kwargs.values()
        )
        if has_empty_tensor:
            return
        super().update(**update_kwargs)
|
|
|
|
|
|
class TaskMixin:
    """Cooperative mixin that records a task index on the instance."""

    def __init__(self, task_idx: int = -1, **kwargs):
        """Forward ``kwargs`` up the MRO, then store ``task_idx``.

        Args:
            task_idx: Value saved as ``self._task_idx``. Defaults to -1
                (presumably meaning "no specific task" -- confirm with callers).
            **kwargs: Passed unchanged to the next ``__init__`` in the MRO.
        """
        # Base initialisation runs first, as in the original ordering.
        super().__init__(**kwargs)
        self._task_idx = task_idx
|
|
|
|
|
|
class StratifyMixin:
    """Cooperative mixin that can restrict selected outputs to one stratum."""

    def __init__(self, stratifier=None, **kwargs):
        """Forward ``kwargs`` up the MRO, then store ``stratifier``.

        Args:
            stratifier: Optional object whose ``name`` and ``value`` attributes
                are read by ``maybe_apply_stratification``.
            **kwargs: Passed unchanged to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self._stratifier = stratifier

    def maybe_apply_stratification(
        self, outputs: Dict[str, torch.Tensor], value_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """Keep only the examples whose stratifier feature equals the configured value.

        A shallow copy of ``outputs`` is always returned; the input dict is
        never modified. The copy is returned unfiltered when no stratifier is
        configured, when ``outputs`` has no (or an empty) ``"stratifiers"``
        entry, or when that entry has no value for the stratifier's name.

        Args:
            outputs: Model outputs; ``outputs["stratifiers"]`` is looked up by
                the stratifier's ``name``.
            value_names: Keys of ``outputs`` to filter along dimension 0.

        Returns:
            The copied dict with each entry named in ``value_names`` reduced
            to the selected rows.
        """
        filtered = outputs.copy()
        stratifier = self._stratifier
        if not stratifier:
            return filtered

        per_name = filtered.get("stratifiers")
        if not per_name:
            return filtered

        indicator = per_name.get(stratifier.name)
        if indicator is None:
            return filtered

        # Flatten the comparison so the resulting indices address dimension 0.
        keep = torch.flatten(indicator == stratifier.value)
        row_ids = torch.squeeze(keep.nonzero(), -1)
        for key in value_names:
            filtered[key] = torch.index_select(filtered[key], 0, row_ids)
        return filtered
|
|
|
|
|
|
def prepend_transform(base_metric: torchmetrics.Metric, transform: Callable):
    """Return a new class that combines MetricMixin with ``base_metric``.

    Functionally the same as writing the subclass by hand; it just saves some
    lines of code when no extra class attributes are needed.

    Args:
        base_metric: Metric class to extend; the new class reuses its ``__name__``.
        transform: Callable installed as the new class's ``transform``. It is
            invoked without the instance, i.e. as ``transform(*args, **kwargs)``.

    Returns:
        A class with bases ``(MetricMixin, base_metric)``.
    """

    def transform_method(_self, *args, **kwargs):
        # Drop the bound instance: ``transform`` is a plain function of the outputs.
        return transform(*args, **kwargs)

    bases = (MetricMixin, base_metric)
    namespace = {"transform": transform_method}
    return type(base_metric.__name__, bases, namespace)
|