106 lines
3.4 KiB
Python
106 lines
3.4 KiB
Python
from typing import Dict, Tuple
|
|
import math
|
|
import bisect
|
|
|
|
from tml.optimizers.config import (
|
|
LearningRate,
|
|
OptimizerConfig,
|
|
)
|
|
|
|
import torch
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
|
|
def compute_lr(lr_config, step):
    """Compute the learning rate for ``step`` from a learning-rate config.

    The schedules on ``lr_config`` are checked in the order ``constant``,
    ``piecewise_constant``, ``linear_ramp_to_constant``, ``linear_ramp_to_cosine``;
    the first one that is not None is used.

    Args:
      lr_config: learning-rate config (e.g. a ``LearningRate``) carrying the optional
        schedule attributes listed above.
      step: step to evaluate the schedule at (``LRShim`` passes its ``last_epoch``).

    Returns:
      The learning rate for ``step``.

    Raises:
      ValueError: if none of the schedules is set on ``lr_config``.
    """
    if lr_config.constant is not None:
        return lr_config.constant

    if lr_config.piecewise_constant is not None:
        cfg = lr_config.piecewise_constant
        # bisect_right: a step equal to a boundary already uses the next value.
        # Presumably len(learning_rate_values) == len(learning_rate_boundaries) + 1 (TODO confirm).
        return cfg.learning_rate_values[bisect.bisect_right(cfg.learning_rate_boundaries, step)]

    if lr_config.linear_ramp_to_constant is not None:
        cfg = lr_config.linear_ramp_to_constant
        if cfg.num_ramp_steps <= 0:
            # No ramp configured: use the target rate immediately instead of dividing by zero.
            return cfg.learning_rate
        slope = cfg.learning_rate / cfg.num_ramp_steps
        return min(cfg.learning_rate, slope * step)

    if lr_config.linear_ramp_to_cosine is not None:
        cfg = lr_config.linear_ramp_to_cosine
        if step < cfg.num_ramp_steps:
            # Linear warm-up from 0 to learning_rate.
            slope = cfg.learning_rate / cfg.num_ramp_steps
            return slope * step
        if step < cfg.final_num_steps:
            # Cosine decay from learning_rate (at num_ramp_steps) to final_learning_rate
            # (at final_num_steps). `<` rather than `<=`: at step == final_num_steps the
            # formula evaluates to exactly final_learning_rate anyway, and this avoids a
            # 0/0 when final_num_steps == num_ramp_steps.
            return cfg.final_learning_rate + (cfg.learning_rate - cfg.final_learning_rate) * 0.5 * (
                1.0
                + math.cos(
                    math.pi * (step - cfg.num_ramp_steps) / (cfg.final_num_steps - cfg.num_ramp_steps)
                )
            )
        return cfg.final_learning_rate

    raise ValueError(f"No option selected in lr_config, passed {lr_config}")
|
|
|
|
|
|
class LRShim(_LRScheduler):
    """Shim to get learning rates into a LRScheduler.

    This adheres to the torch.optim scheduler API and can be plugged anywhere that
    e.g. exponential decay can be used.

    Each value of ``lr_dict`` is the schedule for one optimizer param group, in
    iteration order; the rate for a group is ``compute_lr(lr_config, self.last_epoch)``.
    """

    def __init__(
        self,
        optimizer,
        lr_dict: Dict[str, LearningRate],
        last_epoch=-1,
        verbose=False,
    ):
        """Create the scheduler.

        Args:
          optimizer: a ``torch.optim`` optimizer, or a combined optimizer exposing
            ``_optims`` as ``(name, optimizer)`` pairs.
          lr_dict: one learning-rate config per param group; keys are group names.
          last_epoch: forwarded to ``_LRScheduler``.
          verbose: forwarded to ``_LRScheduler``.

        Raises:
          ValueError: if the number of param groups differs from ``len(lr_dict)``.
        """
        self.optimizer = optimizer
        self.lr_dict = lr_dict
        self.group_names = list(self.lr_dict.keys())

        num_param_groups = self._count_param_groups(optimizer)
        if num_param_groups != len(lr_dict):
            raise ValueError(
                f"Optimizer had {num_param_groups} param groups, but config had {len(lr_dict)}."
            )

        super().__init__(optimizer, last_epoch, verbose)

    @staticmethod
    def _count_param_groups(optimizer) -> int:
        """Count param groups, looking through the sub-optimizers of a combined optimizer."""
        sub_optimizers = getattr(optimizer, "_optims", None)
        if sub_optimizers is None:
            # Plain torch.optim optimizer (e.g. the one built by build_optimizer).
            return len(optimizer.param_groups)
        return sum(1 for _, sub_optimizer in sub_optimizers for _ in sub_optimizer.param_groups)

    def get_lr(self):
        """Return the per-group learning rates for the current ``last_epoch``."""
        # torch's scheduler sets this flag while `step()` is running.
        if not self._get_lr_called_within_step:
            # No extra positional args: logger-style APIs treat them as %-format args
            # and this message has no placeholders (assumes stdlib/absl-style `logging`).
            logging.warning(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
            )
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        """Evaluate every group's schedule at ``self.last_epoch``."""
        return [compute_lr(lr_config, self.last_epoch) for lr_config in self.lr_dict.values()]
|
|
|
|
|
|
def get_optimizer_class(optimizer_config: OptimizerConfig):
    """Return the torch optimizer class selected by ``optimizer_config``.

    Sub-configs are checked in the order ``adam``, ``sgd``, ``adagrad``; the first
    one that is not None determines the class.

    Raises:
      ValueError: if none of ``adam``, ``sgd`` or ``adagrad`` is set. Without this the
        function would return None and callers would fail later with an opaque error.
    """
    if optimizer_config.adam is not None:
        return torch.optim.Adam
    if optimizer_config.sgd is not None:
        return torch.optim.SGD
    if optimizer_config.adagrad is not None:
        return torch.optim.Adagrad
    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
|
|
|
|
|
|
def build_optimizer(
    model: torch.nn.Module, optimizer_config: OptimizerConfig
) -> Tuple[Optimizer, _LRScheduler]:
    """Builds an optimizer and LR scheduler from an OptimizerConfig.

    Note: use this when you want the same optimizer and learning rate schedule for all your parameters.

    The optimizer class comes from ``get_optimizer_class``. Its keyword arguments come
    from the sub-config that selected it (``adam``, ``sgd`` or ``adagrad``). All model
    parameters form a single param group scheduled by ``optimizer_config.learning_rate``.

    Returns:
      ``(optimizer, scheduler)``.

    Raises:
      ValueError: if no optimizer sub-config matches the selected optimizer class.
    """
    optimizer_class = get_optimizer_class(optimizer_config)
    # Take hyperparameters from the sub-config matching the chosen class; reading
    # `optimizer_config.sgd` unconditionally fails when adam or adagrad is selected.
    sub_configs = {
        torch.optim.Adam: optimizer_config.adam,
        torch.optim.SGD: optimizer_config.sgd,
        torch.optim.Adagrad: optimizer_config.adagrad,
    }
    sub_config = sub_configs.get(optimizer_class)
    if sub_config is None:
        raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
    optimizer = optimizer_class(model.parameters(), **sub_config.dict())
    # We're passing everything in as one group here
    scheduler = LRShim(optimizer, lr_dict={"ALL_PARAMS": optimizer_config.learning_rate})
    return optimizer, scheduler
|