317 lines
11 KiB
Python
317 lines
11 KiB
Python
"""Torch and torchrec specific training and evaluation loops.
|
|
|
|
Features (go/100_enablements):
|
|
- CUDA data-fetch, compute, gradient-push overlap
|
|
- Large learnable embeddings through torchrec
|
|
- On/off-chief evaluation
|
|
- Warmstart/checkpoint management
|
|
- go/dataset-service 0-copy integration
|
|
|
|
"""
|
|
import datetime
|
|
import os
|
|
from typing import Callable, Dict, Iterable, List, Mapping, Optional
|
|
|
|
|
|
from tml.common import log_weights
|
|
import tml.common.checkpointing.snapshot as snapshot_lib
|
|
from tml.core.losses import get_global_loss_detached
|
|
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
|
|
from tml.core.train_pipeline import TrainPipelineSparseDist
|
|
|
|
import tree
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
import torchmetrics as tm
|
|
|
|
|
|
def get_new_iterator(iterable: Iterable):
|
|
"""
|
|
This obtain a new iterator from the iterable. If the iterable uses tf.data.Dataset internally,
|
|
getting a new iterator each N steps will avoid memory leak. To avoid the memory leak
|
|
calling iter(iterable) should return a "fresh" iterator using a fresh
|
|
(new instance of) tf.data.Iterator.
|
|
In particular, iterable can be a torch.utils.data.IterableDataset or a
|
|
torch.utils.data.DataLoader.
|
|
|
|
When using DDS, performing this reset does not change the order in which elements are received
|
|
(excluding elements already prefetched) provided that iter(iterable) internally uses
|
|
a new instance of tf.data.Dataset created by calling from_dataset_id.
|
|
This requirement is satisfied by RecapDataset.
|
|
:param iterable:
|
|
:return:
|
|
"""
|
|
return iter(iterable)
|
|
|
|
|
|
def _get_step_fn(pipeline, data_iterator, training: bool):
|
|
def step_fn():
|
|
# It turns out that model.train() and model.eval() simply switch a single field inside the model
|
|
# class,so it's somewhat safer to wrap in here.
|
|
if training:
|
|
pipeline._model.train()
|
|
else:
|
|
pipeline._model.eval()
|
|
|
|
outputs = pipeline.progress(data_iterator)
|
|
return tree.map_structure(lambda elem: elem.detach(), outputs)
|
|
|
|
return step_fn
|
|
|
|
|
|
@torch.no_grad()
|
|
def _run_evaluation(
|
|
pipeline,
|
|
dataset,
|
|
eval_steps: int,
|
|
metrics: tm.MetricCollection,
|
|
eval_batch_size: int,
|
|
logger=None,
|
|
):
|
|
"""Runs the evaluation loop over all evaluation iterators."""
|
|
dataset = get_new_iterator(dataset)
|
|
step_fn = _get_step_fn(pipeline, dataset, training=False)
|
|
last_time = datetime.datetime.now()
|
|
logging.info(f"Starting {eval_steps} steps of evaluation.")
|
|
for _ in range(eval_steps):
|
|
outputs = step_fn()
|
|
metrics.update(outputs)
|
|
eval_ex_per_s = (
|
|
eval_batch_size * eval_steps / (datetime.datetime.now() - last_time).total_seconds()
|
|
)
|
|
logging.info(f"eval examples_per_s : {eval_ex_per_s}")
|
|
metrics_result = metrics.compute()
|
|
# Resetting at end to release metrics memory not in use.
|
|
# Reset metrics to prevent accumulation between multiple evaluation splits and not report a
|
|
# running average.
|
|
metrics.reset()
|
|
return metrics_result
|
|
|
|
|
|
def train(
|
|
model: torch.nn.Module,
|
|
optimizer: torch.optim.Optimizer,
|
|
device: str,
|
|
save_dir: str,
|
|
logging_interval: int,
|
|
train_steps: int,
|
|
checkpoint_frequency: int,
|
|
dataset: Iterable,
|
|
worker_batch_size: int,
|
|
num_workers: Optional[int] = 0,
|
|
enable_amp: bool = False,
|
|
initial_checkpoint_dir: Optional[str] = None,
|
|
gradient_accumulation: Optional[int] = None,
|
|
logger_initializer: Optional[Callable] = None,
|
|
scheduler: _LRScheduler = None,
|
|
metrics: Optional[tm.MetricCollection] = None,
|
|
parameters_to_log: Optional[Dict[str, Callable]] = None,
|
|
tables_to_log: Optional[List[str]] = None,
|
|
) -> None:
|
|
"""Runs training and eval on the given TrainPipeline
|
|
|
|
Args:
|
|
dataset: data iterator for the training set
|
|
evaluation_iterators: data iterators for the different evaluation sets
|
|
scheduler: optional learning rate scheduler
|
|
output_transform_for_metrics: optional transformation functions to transorm the model
|
|
output and labels into a format the metrics can understand
|
|
"""
|
|
|
|
train_pipeline = TrainPipelineSparseDist(
|
|
model=model,
|
|
optimizer=optimizer,
|
|
device=device,
|
|
enable_amp=enable_amp,
|
|
grad_accum=gradient_accumulation,
|
|
) # type: ignore[var-annotated]
|
|
|
|
# We explicitly initialize optimizer state here so that checkpoint will work properly.
|
|
if hasattr(train_pipeline._optimizer, "init_state"):
|
|
train_pipeline._optimizer.init_state()
|
|
|
|
save_state = {
|
|
"model": train_pipeline._model,
|
|
"optimizer": train_pipeline._optimizer,
|
|
"scaler": train_pipeline._grad_scaler,
|
|
}
|
|
|
|
chosen_checkpoint = None
|
|
checkpoint_handler = snapshot_lib.Snapshot(
|
|
save_dir=save_dir,
|
|
state=save_state,
|
|
)
|
|
|
|
if save_dir:
|
|
chosen_checkpoint = snapshot_lib.get_checkpoint(save_dir=save_dir, missing_ok=True)
|
|
|
|
start_step = 0
|
|
start_walltime = 0.0
|
|
if chosen_checkpoint:
|
|
# Skip restoration and exit if we should be finished.
|
|
chosen_checkpoint_global_step = snapshot_lib.step_from_checkpoint(chosen_checkpoint)
|
|
if not chosen_checkpoint_global_step < dist.get_world_size() * train_steps:
|
|
logging.info(
|
|
"Not restoring and finishing training as latest checkpoint "
|
|
f"{chosen_checkpoint} found "
|
|
f"at global_step ({chosen_checkpoint_global_step}) >= "
|
|
f"train_steps ({dist.get_world_size() * train_steps})"
|
|
)
|
|
return
|
|
logging.info(f"Restoring latest checkpoint from global_step {chosen_checkpoint_global_step}")
|
|
checkpoint_handler.restore(chosen_checkpoint)
|
|
start_step = checkpoint_handler.step
|
|
start_walltime = checkpoint_handler.walltime
|
|
elif initial_checkpoint_dir:
|
|
base, ckpt_step = os.path.split(initial_checkpoint_dir)
|
|
warmstart_handler = snapshot_lib.Snapshot(
|
|
save_dir=base,
|
|
state=save_state,
|
|
)
|
|
ckpt = snapshot_lib.get_checkpoint(save_dir=base, missing_ok=False, global_step=int(ckpt_step))
|
|
logging.info(
|
|
f"Restoring from initial_checkpoint_dir: {initial_checkpoint_dir}, but keeping starting step as 0."
|
|
)
|
|
warmstart_handler.restore(ckpt)
|
|
|
|
train_logger = logger_initializer(mode="train") if logger_initializer else None
|
|
train_step_fn = _get_step_fn(train_pipeline, get_new_iterator(dataset), training=True)
|
|
|
|
# Counting number of parameters in the model directly when creating it.
|
|
nb_param = 0
|
|
for p in model.parameters():
|
|
nb_param += p.numel()
|
|
logging.info(f"Model has {nb_param} parameters")
|
|
|
|
last_time = datetime.datetime.now()
|
|
start_time = last_time
|
|
last_pending_snapshot = None
|
|
for step in range(start_step, train_steps + 1):
|
|
checkpoint_handler.step = step
|
|
outputs = train_step_fn()
|
|
step_done_time = datetime.datetime.now()
|
|
checkpoint_handler.walltime = (step_done_time - start_time).total_seconds() + start_walltime
|
|
|
|
if scheduler:
|
|
scheduler.step()
|
|
|
|
if step % logging_interval == 0:
|
|
interval_time = (step_done_time - last_time).total_seconds()
|
|
steps_per_s = logging_interval / interval_time
|
|
worker_example_per_s = steps_per_s * worker_batch_size
|
|
global_example_per_s = worker_example_per_s * (1 + (num_workers or 0))
|
|
global_step = step
|
|
|
|
log_values = {
|
|
"global_step": global_step,
|
|
"loss": get_global_loss_detached(outputs["loss"]),
|
|
"steps_per_s": steps_per_s,
|
|
"global_example_per_s": global_example_per_s,
|
|
"worker_examples_per_s": worker_example_per_s,
|
|
"active_training_walltime": checkpoint_handler.walltime,
|
|
}
|
|
if parameters_to_log:
|
|
log_values.update(
|
|
log_weights.weights_to_log(
|
|
model=model,
|
|
how_to_log=parameters_to_log,
|
|
)
|
|
)
|
|
log_values = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), log_values)
|
|
|
|
if tables_to_log:
|
|
log_values.update(
|
|
log_weights.log_ebc_norms(
|
|
model_state_dict=train_pipeline._model.state_dict(),
|
|
ebc_keys=tables_to_log,
|
|
)
|
|
)
|
|
if train_logger:
|
|
train_logger.log(log_values, step=global_step)
|
|
log_line = ", ".join(f"{name}: {value}" for name, value in log_values.items())
|
|
logging.info(f"Step: {step}, training. {log_line}")
|
|
last_time = step_done_time
|
|
|
|
# If we just restored, do not save again.
|
|
if checkpoint_frequency and step > start_step and step % checkpoint_frequency == 0:
|
|
if last_pending_snapshot and not last_pending_snapshot.done():
|
|
logging.warning(
|
|
"Begin a new snapshot and the last one hasn't finished. That probably indicates "
|
|
"either you're snapshotting really often or something is wrong. Will now block and "
|
|
"wait for snapshot to finish before beginning the next one."
|
|
)
|
|
last_pending_snapshot.wait()
|
|
last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
|
|
|
|
# Save if we did not just save.
|
|
if checkpoint_frequency and step % checkpoint_frequency != 0:
|
|
# For the final save, wait for the checkpoint to write to make sure the process doesn't finish
|
|
# before its completed.
|
|
last_pending_snapshot = checkpoint_handler.save(global_step=step * dist.get_world_size())
|
|
logging.info(f"Finished training steps: {step}, global_steps: {step * dist.get_world_size()}")
|
|
|
|
if last_pending_snapshot:
|
|
logging.info(f"Waiting for any checkpoints to finish.")
|
|
last_pending_snapshot.wait()
|
|
|
|
|
|
def log_eval_results(
|
|
results,
|
|
eval_logger,
|
|
partition_name: str,
|
|
step: int,
|
|
):
|
|
results = tree.map_structure(lambda elem: torch.as_tensor(elem).cpu(), results)
|
|
logging.info(f"Step: {step}, evaluation ({partition_name}).")
|
|
for metric_name, metric_value in results.items():
|
|
logging.info(f"\t{metric_name}: {metric_value:1.4e}")
|
|
|
|
if eval_logger:
|
|
eval_logger.log(results, step=step, commit=True)
|
|
|
|
|
|
def only_evaluate(
|
|
model: torch.nn.Module,
|
|
optimizer: torch.optim.Optimizer,
|
|
device: str,
|
|
save_dir: str,
|
|
num_train_steps: int,
|
|
dataset: Iterable,
|
|
eval_batch_size: int,
|
|
num_eval_steps: int,
|
|
eval_timeout_in_s: int,
|
|
eval_logger: Callable,
|
|
partition_name: str,
|
|
metrics: Optional[tm.MetricCollection] = None,
|
|
):
|
|
logging.info(f"Evaluating on partition {partition_name}.")
|
|
logging.info("Computing metrics:")
|
|
logging.info(metrics)
|
|
eval_pipeline = TrainPipelineSparseDist(model, optimizer, device) # type: ignore[var-annotated]
|
|
save_state = {
|
|
"model": eval_pipeline._model,
|
|
"optimizer": eval_pipeline._optimizer,
|
|
}
|
|
checkpoint_handler = snapshot_lib.Snapshot(
|
|
save_dir=save_dir,
|
|
state=save_state,
|
|
)
|
|
for checkpoint_path in snapshot_lib.checkpoints_iterator(save_dir, timeout=eval_timeout_in_s):
|
|
checkpoint_handler.restore(checkpoint_path)
|
|
step = checkpoint_handler.step
|
|
dataset = get_new_iterator(dataset)
|
|
results = _run_evaluation(
|
|
pipeline=eval_pipeline,
|
|
dataset=dataset,
|
|
eval_steps=num_eval_steps,
|
|
eval_batch_size=eval_batch_size,
|
|
metrics=metrics,
|
|
)
|
|
log_eval_results(results, eval_logger, partition_name, step=step)
|
|
rank = dist.get_rank() if dist.is_initialized() else 0
|
|
if rank == 0:
|
|
snapshot_lib.mark_done_eval(checkpoint_path, partition_name)
|
|
if step >= num_train_steps:
|
|
return
|