20 lines
546 B
Python
20 lines
546 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
class NumericCalibration(torch.nn.Module):
|
||
|
def __init__(
|
||
|
self,
|
||
|
pos_downsampling_rate: float,
|
||
|
neg_downsampling_rate: float,
|
||
|
):
|
||
|
super().__init__()
|
||
|
|
||
|
# Using buffer to make sure they are on correct device (and not moved every time).
|
||
|
# Will also be part of state_dict.
|
||
|
self.register_buffer(
|
||
|
"ratio", torch.as_tensor(neg_downsampling_rate / pos_downsampling_rate), persistent=True
|
||
|
)
|
||
|
|
||
|
def forward(self, probs: torch.Tensor):
|
||
|
return probs * self.ratio / (1.0 - probs + (self.ratio * probs))
|