281 lines
9.7 KiB
Python
281 lines
9.7 KiB
Python
|
"""
|
||
|
Contains RCE metrics.
|
||
|
"""
|
||
|
import copy
|
||
|
from functools import partial
|
||
|
from typing import Union
|
||
|
|
||
|
from tml.metrics import aggregation
|
||
|
|
||
|
import torch
|
||
|
import torchmetrics
|
||
|
|
||
|
|
||
|
def _smooth(
|
||
|
value: torch.Tensor, label_smoothing: Union[float, torch.Tensor]
|
||
|
) -> Union[float, torch.Tensor]:
|
||
|
"""
|
||
|
Smooth given values.
|
||
|
Args:
|
||
|
value: Value to smooth.
|
||
|
label_smoothing: smoothing constant.
|
||
|
Returns: Smoothed values.
|
||
|
"""
|
||
|
return value * (1.0 - label_smoothing) + 0.5 * label_smoothing
|
||
|
|
||
|
|
||
|
def _binary_cross_entropy_with_clipping(
|
||
|
predictions: torch.Tensor,
|
||
|
target: torch.Tensor,
|
||
|
epsilon: Union[float, torch.Tensor],
|
||
|
reduction: str = "none",
|
||
|
) -> torch.Tensor:
|
||
|
"""
|
||
|
Clip Predictions and apply binary cross entropy.
|
||
|
This is done to match the implementation in keras at
|
||
|
https://github.com/keras-team/keras/blob/r2.9/keras/backend.py#L5294-L5300
|
||
|
Args:
|
||
|
predictions: Predicted probabilities.
|
||
|
target: Ground truth.
|
||
|
epsilon: Epsilon fuzz factor used to clip the predictions.
|
||
|
reduction: The reduction method to use.
|
||
|
|
||
|
Returns: Binary cross entropy on the clipped predictions.
|
||
|
|
||
|
"""
|
||
|
predictions = torch.clamp(predictions, epsilon, 1.0 - epsilon)
|
||
|
bce = -target * torch.log(predictions + epsilon)
|
||
|
bce -= (1.0 - target) * torch.log(1.0 - predictions + epsilon)
|
||
|
if reduction == "mean":
|
||
|
return torch.mean(bce)
|
||
|
return bce
|
||
|
|
||
|
|
||
|
class RCE(torchmetrics.Metric):
    """
    Compute the relative cross entropy (`RCE <http://go/rce>`_).

    RCE is a metric used for models predicting a probability of success (p), e.g. pCTR.
    RCE represents the binary `cross entropy <https://en.wikipedia.org/wiki/Cross_entropy>`_
    of the model compared to a reference straw man model.

    Binary cross entropy is defined as:

      y = label; p = prediction;
      binary cross entropy(example) = - y * log(p) - (1-y) * log(1-p)

    Where y in {0, 1}

    Cross entropy of a model is defined as:

      CE(model) = average(binary cross entropy(example))

    Over all the examples we aggregate on.

    The straw man model is quite simple, it is a constant predictor, always predicting the
    average over the labels.

    RCE of a model is defined as:

      RCE(model) = 100 * (CE(reference model) - CE(model)) / CE(reference model)

    .. note:: Maximizing the likelihood is the same as minimizing the cross entropy or
      maximizing the RCE, since cross entropy is the average negative log-likelihood for
      the binary case.

    .. note:: Binary cross entropy of an example is non negative, and equal to the
      `KL divergence
      <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Properties>`_
      since the label of a single example is fixed, and its entropy is equal to zero.

    .. note:: 0% RCE means as good as the straw man model.
      100% means always predicts exactly the label. Namely, cross entropy of the model is
      always zero. In practice 100% is impossible to achieve due to clipping.
      Negative RCE means that the model is doing worse than the straw man.
      This usually means an un-calibrated model, namely, the average prediction
      is "far" from the average label. Examining NRCE might help identifying if that is
      the case.

    .. note:: RCE is not a "ratio" in the statistical
      `level of measurement sense <https://en.wikipedia.org/wiki/Level_of_measurement>`_.
      The higher the model's RCE is the harder it is to improve it by an extra point.

      For example:
        Let CE(model) = 0.5 CE(reference model), then the RCE(model) = 50.
        Now take a "twice as good" model:
        Let CE(better model) = 0.5 CE(model) = 0.25 CE(reference model),
        then the RCE(better model) = 75 and not 100.

    .. note:: In order to keep the log function stable, typically p is limited to
      lie in [CLAMP_EPSILON, 1-CLAMP_EPSILON],
      where CLAMP_EPSILON is some small constant like: 1e-7.
      Old implementation used 1e-5 clipping by default, current uses
      tf.keras.backend.epsilon()
      whose default is 1e-7.

    .. note:: Since the reference model prediction is constant (probability),
      CE(reference model) = H(average(label))

      Where H is the standard
      `entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>`_ function.

    .. note:: Must have at least 1 positive and 1 negative sample accumulated,
      or RCE will not be meaningful (it will come out as NaN).
    """

    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0,
        epsilon: float = 1e-7,
        **kwargs,
    ):
        """
        Args:
          from_logits: whether or not predictions are logits or probabilities.
          label_smoothing: label smoothing constant.
          epsilon: Epsilon fuzz factor used on the predictions probabilities when
            from_logits is False.
          **kwargs: Additional parameters supported by all torchmetrics.Metric.
        """
        super().__init__(**kwargs)
        self.from_logits = from_logits
        self.label_smoothing = label_smoothing
        self.epsilon = epsilon
        self.kwargs = kwargs

        # Accumulators for the (smoothed) labels and for the per-example cross entropy.
        # Both receive the same metric kwargs as this metric.
        self.mean_label = aggregation.StableMean(**kwargs)
        self.binary_cross_entropy = aggregation.StableMean(**kwargs)

        # Logits go through torch's BCE-with-logits; probabilities go through the
        # clipped variant bound to this metric's epsilon.
        if self.from_logits:
            self.bce_loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
        else:
            self.bce_loss_fn = partial(_binary_cross_entropy_with_clipping, epsilon=self.epsilon)

        # Used to compute non-accumulated batch metric if `forward` or `__call__` functions
        # are used. Taken last so the copy carries every attribute set above.
        self.batch_metric = copy.deepcopy(self)

    def update(
        self,
        predictions: torch.Tensor,
        target: torch.Tensor,
        weight: Union[float, torch.Tensor] = 1.0,
    ) -> None:
        """
        Update the current rce.

        Args:
          predictions: Predicted values.
          target: Ground truth. Should have same shape as predictions.
          weight: The weight to use for the predicted values. Shape should be broadcastable
            to that of predictions.
        """
        target = _smooth(target, self.label_smoothing)
        self.mean_label.update(target, weight)
        # The cross entropy is accumulated per element against the smoothed target.
        self.binary_cross_entropy.update(
            self.bce_loss_fn(predictions, target, reduction="none"), weight
        )

    def compute(self) -> torch.Tensor:
        """
        Compute and return the accumulated rce.

        Returns: ``100 * (1 - CE(model) / CE(reference model))`` over everything
          accumulated so far.
        """
        baseline_mean = self.mean_label.compute()

        # The reference model always predicts the mean label, so its cross entropy is the
        # clipped BCE of the mean label against itself.
        baseline_ce = _binary_cross_entropy_with_clipping(
            baseline_mean, baseline_mean, reduction="mean", epsilon=self.epsilon
        )

        pred_ce = self.binary_cross_entropy.compute()

        return (1.0 - (pred_ce / baseline_ce)) * 100

    def reset(self):
        """
        Reset the metric to its initial state.
        """
        super().reset()
        self.mean_label.reset()
        self.binary_cross_entropy.reset()

    def forward(self, *args, **kwargs):
        """
        Serves the dual purpose of both computing the metric on the current batch of inputs
        but also add the batch statistics to the overall accumulating metric state.

        Input arguments are the exact same as corresponding ``update`` method.

        Returns: The metric computed on the current batch only, in the same form as the
          output of ``compute``.
        """
        # Fold the batch into the accumulated state.
        self.update(*args, **kwargs)
        # Score the batch on its own with the scratch copy, then clear that copy so the
        # next call starts from empty state.
        self.batch_metric.update(*args, **kwargs)
        batch_result = self.batch_metric.compute()
        self.batch_metric.reset()
        return batch_result
|
||
|
|
||
|
|
||
|
class NRCE(RCE):
    """
    Calculate the RCE of the normalized model.

    Where the normalized model prediction average is normalized to the average label seen
    so far. Namely, the normalized model prediction:

      normalized model prediction(example) = (model prediction(example) * average(label)) /
        average(model prediction)

    Where the average is over all previously seen examples.

    .. note:: average(normalized model prediction) = average(label)

    .. note:: NRCE can be misleading since it is oblivious to mis-calibrations.
      The common interpretation of NRCE is to measure how good your model could potentially
      perform if it was well calibrated.

    .. note:: A big gap between NRCE and RCE might indicate a badly calibrated model.
    """

    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0,
        epsilon: float = 1e-7,
        **kwargs,
    ):
        """
        Args:
          from_logits: whether or not predictions are logits or probabilities.
          label_smoothing: label smoothing constant.
          epsilon: Epsilon fuzz factor used on the predictions probabilities when
            from_logits is False. It is only used when computing the cross entropy but not
            when normalizing.
          **kwargs: Additional parameters supported by all torchmetrics.Metric.
        """
        # The parent is configured for plain probabilities with no smoothing: `update`
        # below applies the sigmoid and the label smoothing itself, because predictions
        # must be probabilities before they can be normalized.
        super().__init__(from_logits=False, label_smoothing=0, epsilon=epsilon, **kwargs)
        self.nrce_from_logits = from_logits
        self.nrce_label_smoothing = label_smoothing
        # Built with the same kwargs as the label / cross-entropy accumulators created by
        # the parent, so all three accumulators are configured identically.
        self.mean_prediction = aggregation.StableMean(**kwargs)

        # Used to compute non-accumulated batch metric if `forward` or `__call__` functions
        # are used. Re-created here so the copy also carries the attributes added above.
        self.batch_metric = copy.deepcopy(self)

    def update(
        self,
        predictions: torch.Tensor,
        target: torch.Tensor,
        weight: Union[float, torch.Tensor] = 1.0,
    ) -> None:
        """
        Update the current nrce.

        Args:
          predictions: Predicted values.
          target: Ground truth. Should have same shape as predictions.
          weight: The weight to use for the predicted values. Shape should be broadcastable
            to that of predictions.
        """
        predictions = torch.sigmoid(predictions) if self.nrce_from_logits else predictions

        target = _smooth(target, self.nrce_label_smoothing)
        self.mean_label.update(target, weight)

        self.mean_prediction.update(predictions, weight)

        # Both means already include the current batch at this point.
        normalizer = self.mean_label.compute() / self.mean_prediction.compute()

        # Rescale so the average prediction matches the average label.
        predictions = predictions * normalizer

        self.binary_cross_entropy.update(
            self.bce_loss_fn(predictions, target, reduction="none"), weight
        )

    def reset(self):
        """
        Reset the metric to its initial state.
        """
        super().reset()
        self.mean_prediction.reset()
|