twitter-algorithm-ml/projects/home/recap/model/numeric_calibration.py

20 lines
546 B
Python
Raw Normal View History

import torch
class NumericCalibration(torch.nn.Module):
def __init__(
self,
pos_downsampling_rate: float,
neg_downsampling_rate: float,
):
super().__init__()
# Using buffer to make sure they are on correct device (and not moved every time).
# Will also be part of state_dict.
self.register_buffer(
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
)
def forward(self, probs: torch.Tensor):
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))