twitter-algorithm-ml/core/train_pipeline.py

627 lines
21 KiB
Python

"""
Taken from https://raw.githubusercontent.com/pytorch/torchrec/v0.3.2/torchrec/distributed/train_pipeline.py
with TrainPipelineSparseDist.progress modified to support gradient accumulation.
"""
import abc
from dataclasses import dataclass, field
import logging
from typing import (
Any,
cast,
Dict,
Generic,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
import torch
from torch.autograd.profiler import record_function
from torch.fx.node import Node
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
ShardedModule,
)
from torchrec.distributed.types import Awaitable
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.streamable import Multistreamable, Pipelineable
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
# Type of the input batches pulled from the dataloader iterator; bound to `Pipelineable`.
In = TypeVar("In", bound=Pipelineable)
# Type of the value returned by `TrainPipeline.progress`.
Out = TypeVar("Out")
class TrainPipeline(abc.ABC, Generic[In, Out]):
    """
    Abstract interface of a train pipeline.

    A pipeline is driven by repeatedly calling `progress` with an iterator over
    input batches of type `In`; each call returns a value of type `Out`.
    """

    @abc.abstractmethod
    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        """Advance the pipeline using `dataloader_iter` and return its output."""
def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
    """
    Return `batch` moved to `device` via its own `.to(...)` method.

    `batch` must be a `torch.Tensor` or implement `Pipelineable`; anything else
    fails the assertion. `non_blocking` is forwarded to `.to(...)` unchanged.
    """
    movable_types = (torch.Tensor, Pipelineable)
    assert isinstance(
        batch, movable_types
    ), f"{type(batch)} must implement Pipelineable interface"
    moved = batch.to(device=device, non_blocking=non_blocking)
    return cast(In, moved)
def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
    """
    Make the current CUDA stream wait for `stream`, then record the current
    stream on `batch`.

    Does nothing when `stream` is None. `batch` must be a `torch.Tensor` or
    implement `Multistreamable`; anything else fails the assertion.
    """
    if stream is not None:
        consumer_stream = torch.cuda.current_stream()
        consumer_stream.wait_stream(stream)
        # See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html.
        # PyTorch's caching allocator reuses the memory of freed tensors for newly
        # constructed ones, and by default it only tracks use of a tensor on the
        # stream that created it. A tensor that is also used on other streams has
        # to be announced to the allocator through record_stream; otherwise its
        # memory may be handed out again as soon as the creator stream is done
        # with it, while another stream is still reading it.
        multistream_types = (torch.Tensor, Multistreamable)
        assert isinstance(
            batch, multistream_types
        ), f"{type(batch)} must implement Multistreamable interface"
        batch.record_stream(consumer_stream)
class TrainPipelineBase(TrainPipeline[In, Out]):
    """
    Two-stage training pipeline. Each stage is a CUDA stream: the current (default)
    stream and `self._memcpy_stream`. On every iteration `self._memcpy_stream` copies
    the upcoming input to the target device while the default stream runs forward,
    backward and the optimizer update for the batch already on the device.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
    ) -> None:
        """
        Store the model, optimizer and device, and create the copy stream.

        The copy stream is only created for a "cuda" device; otherwise it is None.
        """
        self._model = model
        self._optimizer = optimizer
        self._device = device
        self._memcpy_stream: Optional[torch.cuda.streams.Stream] = (
            torch.cuda.Stream() if device.type == "cuda" else None
        )
        self._cur_batch: Optional[In] = None
        self._connected = False

    def _connect(self, dataloader_iter: Iterator[In]) -> None:
        """Pull the first batch and start copying it to the device."""
        first_batch = next(dataloader_iter)
        self._cur_batch = first_batch
        with torch.cuda.stream(self._memcpy_stream):
            self._cur_batch = _to_device(first_batch, self._device, non_blocking=True)
        self._connected = True

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        """
        Run one iteration on the batch already in flight and prefetch the next one.

        The model is called with the current batch and must return a
        `(losses, output)` pair; `output` is returned. When the model is in
        training mode, gradients are zeroed before the forward pass, the sum of
        `losses` over dim 0 is back-propagated, and the optimizer is stepped.
        """
        if not self._connected:
            self._connect(dataloader_iter)

        # Fetch the batch for the following iteration before doing any work.
        with record_function("## next_batch ##"):
            upcoming_batch = next(dataloader_iter)

        active_batch = self._cur_batch
        assert active_batch is not None

        if self._model.training:
            with record_function("## zero_grad ##"):
                self._optimizer.zero_grad()

        with record_function("## wait_for_batch ##"):
            _wait_for_batch(active_batch, self._memcpy_stream)

        with record_function("## forward ##"):
            losses, output = self._model(active_batch)

        if self._model.training:
            with record_function("## backward ##"):
                torch.sum(losses, dim=0).backward()

        # Start the device copy of the next batch on the copy stream.
        self._cur_batch = upcoming_batch
        with record_function("## copy_batch_to_gpu ##"):
            with torch.cuda.stream(self._memcpy_stream):
                self._cur_batch = _to_device(
                    upcoming_batch, self._device, non_blocking=True
                )

        # Parameter update.
        if self._model.training:
            with record_function("## optimizer ##"):
                self._optimizer.step()

        return output
class Tracer(torch.fx.Tracer):
    """
    `torch.fx.Tracer` that additionally treats every `ShardedModule`, and every
    module whose qualified name is listed in `leaf_modules`, as a leaf.
    """

    # Buffers are not proxied during tracing: some models currently mutate buffer
    # values, which causes errors during tracing. If those models can be rewritten
    # to not do that, this override can likely be removed.
    proxy_buffer_attributes = False

    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
        """Remember the qualified names that must be kept as leaves (default: none)."""
        super().__init__()
        self._leaf_modules: List[str] = [] if leaf_modules is None else leaf_modules

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True for sharded or explicitly listed modules, else defer to the base class."""
        forced_leaf = (
            isinstance(m, ShardedModule) or module_qualified_name in self._leaf_modules
        )
        return forced_leaf or super().is_leaf_module(m, module_qualified_name)
@dataclass
class TrainPipelineContext:
    """State shared between a train pipeline and the `PipelinedForward`s it installs."""

    # Input-dist requests, keyed by `PipelinedForward.name`; filled by
    # `_start_data_dist` and consumed in `PipelinedForward.__call__`.
    # pyre-ignore [4]
    input_dist_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
    # Per-module contexts returned by `module.create_context()`, keyed the same way.
    module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
    # Callables applied in order to each sparse feature's `id_score_list_features`
    # in `PipelinedForward.__call__`.
    # NOTE(review): no code visible here appends to this list -- confirm who populates it.
    # pyre-ignore [4]
    feature_processor_forwards: List[Any] = field(default_factory=list)
@dataclass
class ArgInfo:
    """
    Describes how one argument of a pipelined `forward()` call is obtained from the
    input batch. `input_attrs` and `is_getitems` are parallel lists, one entry per
    access step, applied in order starting from the batch.
    """

    # attributes of input batch, e.g. batch.attr1.attr2 call
    # will produce ["attr1", "attr2"]
    input_attrs: List[str]
    # batch[attr1].attr2 will produce [True, False]
    is_getitems: List[bool]
    # name for kwarg of pipelined forward() call or None
    # for a positional arg
    name: Optional[str]
class PipelinedForward:
    """
    Callable installed as the `forward` of a pipelined `ShardedModule`.

    Instead of running input dist itself, it picks up the request that
    `_start_data_dist` stored in the shared `TrainPipelineContext` under this
    module's name, waits for it, and hands the result to
    `module.compute_and_output_dist`.
    """

    def __init__(
        self,
        name: str,
        args: List[ArgInfo],
        module: ShardedModule,
        context: TrainPipelineContext,
        dist_stream: Optional[torch.cuda.streams.Stream],
    ) -> None:
        """
        Args:
            name: key under which this module's request and context are stored.
            args: how each forward argument is extracted from the input batch.
            module: the sharded module whose forward is being replaced.
            context: pipeline context holding requests and module contexts.
            dist_stream: stream the input dist runs on, or None.
        """
        self._name = name
        self._args = args
        self._module = module
        self._context = context
        self._dist_stream = dist_stream

    # pyre-ignore [2, 24]
    def __call__(self, *input, **kwargs) -> Awaitable:
        """
        Wait for this module's pending input dist and run compute + output dist.

        The call arguments are ignored; the data comes from the pipeline context.
        """
        pipeline_ctx = self._context
        key = self._name

        assert key in pipeline_ctx.input_dist_requests
        request = pipeline_ctx.input_dist_requests[key]
        assert isinstance(request, Awaitable)

        with record_function("## wait_sparse_data_dist ##"):
            # Finish the wait on the dist stream, in case the wait() call itself
            # triggers delayed stream scheduling.
            with torch.cuda.stream(self._dist_stream):
                data = request.wait()

        # Hand both the input dist result and the module context over to the
        # current stream.
        if self._dist_stream is not None:
            torch.cuda.current_stream().wait_stream(self._dist_stream)
            consumer_stream = torch.cuda.current_stream()

            assert isinstance(
                data, (torch.Tensor, Multistreamable)
            ), f"{type(data)} must implement Multistreamable interface"
            # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
            data.record_stream(consumer_stream)

            module_ctx = pipeline_ctx.module_contexts[key]
            module_ctx.record_stream(consumer_stream)

        fp_forwards = pipeline_ctx.feature_processor_forwards
        if fp_forwards:
            with record_function("## feature_processor ##"):
                for sparse_feature in data:
                    if sparse_feature.id_score_list_features is None:
                        continue
                    # Chain every registered feature processor over the features.
                    for fp_forward in fp_forwards:
                        sparse_feature.id_score_list_features = fp_forward(
                            sparse_feature.id_score_list_features
                        )

        return self._module.compute_and_output_dist(
            pipeline_ctx.module_contexts[key], data
        )

    @property
    def name(self) -> str:
        """Key of this module in the pipeline context."""
        return self._name

    @property
    def args(self) -> List[ArgInfo]:
        """Descriptions of how the forward arguments are read from the batch."""
        return self._args
def _start_data_dist(
    pipelined_modules: List[ShardedModule],
    batch: In,
    context: TrainPipelineContext,
) -> None:
    """
    Launch input dist for every pipelined module on `batch`.

    Clears the requests and module contexts held by `context`, then for each
    module rebuilds its forward arguments from `batch` (as described by the
    `ArgInfo`s of its `PipelinedForward`), creates a fresh module context and
    stores the `input_dist` result under the forward's name. Finally each stored
    awaitable is replaced by the result of its `wait()`.
    """
    requests = context.input_dist_requests
    module_contexts = context.module_contexts
    requests.clear()
    module_contexts.clear()

    for module in pipelined_modules:
        forward = module.forward
        assert isinstance(forward, PipelinedForward)

        # Rebuild the arguments for the module's input_dist. Each ArgInfo was
        # produced by tracing the whole model in _rewrite_model: an access step
        # with is_getitem=True is an item lookup, otherwise an attribute lookup.
        positional = []
        keyword = {}
        for arg_info in forward.args:
            if not arg_info.input_attrs:
                positional.append(None)
                continue
            value = batch
            for attr, is_getitem in zip(arg_info.input_attrs, arg_info.is_getitems):
                value = value[attr] if is_getitem else getattr(value, attr)
            if arg_info.name:
                keyword[arg_info.name] = value
            else:
                positional.append(value)

        # Start input distribution.
        module_ctx = module.create_context()
        module_contexts[forward.name] = module_ctx
        requests[forward.name] = module.input_dist(module_ctx, *positional, **keyword)

    # Call wait on the first awaitable in the input dist for the tensor splits.
    for key, awaitable in requests.items():
        requests[key] = awaitable.wait()
def _get_node_args_helper(
    # pyre-ignore
    arguments,
    num_found: int,
    feature_processor_arguments: Optional[List[Node]] = None,
) -> Tuple[List[ArgInfo], int]:
    """
    Build one `ArgInfo` per entry of `arguments` (the args or kwarg values of a node)
    and count how many of them could be resolved.

    An argument counts as found when it is None, or when walking back through
    `getattr` / `getitem` calls (and, if given, feature processor nodes, which are
    skipped) ends at a placeholder node. Returns the `ArgInfo` list together with
    `num_found` increased by the number of arguments found.
    """
    arg_info_list = [ArgInfo([], [], None) for _ in arguments]

    for arg, arg_info in zip(arguments, arg_info_list):
        if arg is None:
            num_found += 1
            continue

        # Walk from the argument back towards the graph input.
        while isinstance(arg, torch.fx.Node):
            node = arg
            if node.op == "placeholder":
                num_found += 1
                break
            if (
                feature_processor_arguments is not None
                and node in feature_processor_arguments
            ):
                # Skip over this feature processor node.
                arg = node.args[0]
                continue
            if node.op != "call_function":
                break

            target = node.target
            # pyre-ignore[16]
            if target.__module__ == "builtins" and target.__name__ == "getattr":
                via_getitem = False
            # pyre-ignore[16]
            elif target.__module__ == "_operator" and target.__name__ == "getitem":
                via_getitem = True
            else:
                break

            # Walking backwards, so the access steps are prepended.
            arg_info.input_attrs.insert(0, node.args[1])
            arg_info.is_getitems.insert(0, via_getitem)
            arg = node.args[0]

    return arg_info_list, num_found
def _get_node_args(
    node: Node, feature_processor_nodes: Optional[List[Node]] = None
) -> Tuple[List[ArgInfo], int]:
    """
    Describe all arguments of `node`: positional ones first, then keyword ones.

    Returns the combined `ArgInfo` list and the number of arguments that were
    resolved. `feature_processor_nodes` is only passed on for the positional
    arguments.
    """
    positional_infos, found = _get_node_args_helper(
        node.args, 0, feature_processor_nodes
    )
    keyword_infos, found = _get_node_args_helper(node.kwargs.values(), found)

    # Attach the keyword name to each keyword ArgInfo.
    for kwarg_name, info in zip(node.kwargs, keyword_infos):
        info.name = kwarg_name

    return positional_infos + keyword_infos, found
def _get_unsharded_module_names_helper(
    model: torch.nn.Module,
    path: str,
    unsharded_module_names: Set[str],
) -> bool:
    """
    Recursively collect the children of `model` that hold no sharded module.

    A child is "sharded" when it is a `ShardedModule` or has one somewhere below
    it. If `model` has at least one sharded child, the names of its remaining
    children (prefixed with `path`) are added to `unsharded_module_names`.

    Returns True when `model` has at least one sharded child.
    """
    children = list(model.named_children())

    sharded_children = set()
    for name, child in children:
        if isinstance(child, ShardedModule) or _get_unsharded_module_names_helper(
            child,
            path + name + ".",
            unsharded_module_names,
        ):
            sharded_children.add(name)

    if not sharded_children:
        return False

    unsharded_module_names.update(
        path + name for name, _ in children if name not in sharded_children
    )
    return True
def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]:
    """
    Returns a list of the top-level modules that do not contain any sharded
    sub-modules.
    """
    collected: Set[str] = set()
    _get_unsharded_module_names_helper(model, "", collected)
    return list(collected)
def _rewrite_model(  # noqa C901
    model: torch.nn.Module,
    context: TrainPipelineContext,
    dist_stream: Optional[torch.cuda.streams.Stream],
) -> List[ShardedModule]:
    """
    Replace the `forward` of every pipelinable sharded module with a `PipelinedForward`.

    The model (unwrapped from `DistributedModelParallel` if needed) is traced, and
    each sharded module call whose arguments can all be resolved back to the graph
    input gets a `PipelinedForward` bound to `context` and `dist_stream`.

    Returns the sharded modules that were rewritten, in graph order.
    """
    # Get underlying nn.Module
    if isinstance(model, DistributedModelParallel):
        model = model.module

    # Collect sharded modules and grouped feature processors by qualified name.
    named_modules = list(model.named_modules())
    sharded_modules = {
        name: m for name, m in named_modules if isinstance(m, ShardedModule)
    }
    fp_modules = {
        name: m for name, m in named_modules if isinstance(m, BaseGroupedFeatureProcessor)
    }

    # Trace the model, keeping modules without sharded sub-modules as leaves.
    tracer = Tracer(leaf_modules=_get_unsharded_module_names(model))
    graph = tracer.trace(model)

    # Find the feature processor nodes.
    feature_processor_nodes = [
        node
        for node in graph.nodes
        if node.op == "call_module" and node.target in fp_modules
    ]

    # Select sharded modules, which are top-level in the forward call graph,
    # i.e. which don't have input transformations, i.e.
    # rely only on 'builtins.getattr'.
    pipelined = []
    for node in graph.nodes:
        if node.op != "call_module" or node.target not in sharded_modules:
            continue

        total_num_args = len(node.args) + len(node.kwargs)
        if total_num_args == 0:
            continue

        arg_info_list, num_found = _get_node_args(node, feature_processor_nodes)
        if num_found != total_num_args:
            continue

        logger.info(f"Module '{node.target}'' will be pipelined")
        sharded = sharded_modules[node.target]
        sharded.forward = PipelinedForward(
            node.target,
            arg_info_list,
            sharded,
            context,
            dist_stream,
        )
        pipelined.append(sharded)

    return pipelined
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
    """
    This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
    forward and backward. This helps hide the all2all latency while preserving the
    training forward / backward ordering.

    stage 3: forward, backward - uses default CUDA stream
    stage 2: ShardedModule.input_dist() - uses data_dist CUDA stream
    stage 1: device transfer - uses memcpy CUDA stream

    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
    To be considered a top-level module, a module can only depend on 'getattr' calls on
    input.

    Input model must be symbolically traceable with the exception of `ShardedModule` and
    `DistributedDataParallel` modules.

    With `grad_accum=N`, gradients are accumulated over windows of N training calls
    to `progress`: they are zeroed at the start of the first call of a window and the
    optimizer is stepped at the end of the last one.
    """

    # Maps id(model) -> id(pipeline) of the pipeline whose context and dist stream
    # are currently installed in that model's `PipelinedForward`s.
    synced_pipeline_id: Dict[int, int] = {}

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        enable_amp: bool = False,
        enable_grad_scaling: bool = True,
        grad_accum: Optional[int] = None,
    ) -> None:
        """
        Args:
            model: model to train; called as `model(batch)` and expected to return
                a `(losses, output)` pair.
            optimizer: optimizer that is zeroed and stepped by `progress`.
            device: target device; the memcpy and data-dist streams are only
                created when its type is "cuda".
            enable_amp: run the forward pass under `torch.autocast` with bfloat16.
            enable_grad_scaling: enable the `GradScaler`; only takes effect
                together with `enable_amp`.
            grad_accum: number of training steps to accumulate gradients over,
                or None to step the optimizer on every call. Must be a positive int.
        """
        self._model = model
        self._optimizer = optimizer
        self._device = device
        self._enable_amp = enable_amp
        # NOTE: Pending upstream feedback, but two flags because we can run AMP without CUDA but cannot scale gradients without CUDA.
        # Background on gradient/loss scaling
        # https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling
        # https://pytorch.org/docs/stable/amp.html#gradient-scaling
        self._enable_grad_scaling = enable_grad_scaling
        self._grad_scaler = torch.cuda.amp.GradScaler(
            enabled=self._enable_amp and self._enable_grad_scaling
        )
        logging.info(f"Amp is enabled: {self._enable_amp}")

        # use two data streams to support two concurrent batches
        if device.type == "cuda":
            self._memcpy_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
            self._data_dist_stream: Optional[torch.cuda.streams.Stream] = torch.cuda.Stream()
        else:
            if self._enable_amp:
                logging.warning("Amp is enabled, but no CUDA available")
            self._memcpy_stream: Optional[torch.cuda.streams.Stream] = None
            self._data_dist_stream: Optional[torch.cuda.streams.Stream] = None

        # Batches in flight: i is consumed by the next forward, i+1 has its
        # input dist started, i+2 is being copied to the device.
        self._batch_i: Optional[In] = None
        self._batch_ip1: Optional[In] = None
        self._batch_ip2: Optional[In] = None
        self._connected = False
        self._context = TrainPipelineContext()
        self._pipelined_modules: List[ShardedModule] = []

        # Number of completed `progress` calls made in training mode.
        self._progress_calls = 0
        if grad_accum is not None:
            assert isinstance(grad_accum, int) and grad_accum > 0
        self._grad_accum = grad_accum

    def _connect(self, dataloader_iter: Iterator[In]) -> None:
        """
        Fill the pipeline: copy the first two batches to the device, rewrite the
        model for pipelining and start input dist for the first batch.
        """
        # batch 1
        with torch.cuda.stream(self._memcpy_stream):
            batch_i = next(dataloader_iter)
            self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
            # Try to pipeline input data dist.
            self._pipelined_modules = _rewrite_model(
                self._model, self._context, self._data_dist_stream
            )

        with torch.cuda.stream(self._data_dist_stream):
            _wait_for_batch(batch_i, self._memcpy_stream)
            _start_data_dist(self._pipelined_modules, batch_i, self._context)

        # batch 2
        with torch.cuda.stream(self._memcpy_stream):
            batch_ip1 = next(dataloader_iter)
            self._batch_ip1 = batch_ip1 = _to_device(
                batch_ip1, self._device, non_blocking=True
            )

        self._connected = True
        self.__class__.synced_pipeline_id[id(self._model)] = id(self)

    def progress(self, dataloader_iter: Iterator[In]) -> Out:
        """
        Run one pipelined iteration and return the model's `output`.

        Every call pulls one new batch from `dataloader_iter` (the first call
        pulls two more while connecting), runs forward on the oldest batch in
        flight and starts input dist for the following one. In training mode the
        summed loss is back-propagated through the gradient scaler.

        NOTE: This method has been updated to perform gradient accumulation.
        If `_grad_accum` is set, the loss is divided by `_grad_accum` before the
        backward pass (the loss reported in `output` stays unnormalized), and
        training calls are grouped into windows of `_grad_accum` calls: gradients
        are zeroed at the start of the first call of a window and the optimizer
        is stepped at the end of the last call of that window. If `_grad_accum`
        is None, gradients are zeroed and the optimizer is stepped on every call.
        """
        if self._grad_accum is None:
            should_reset_optimizer = True
            should_step_optimizer = True
        else:
            # `_progress_calls` is the zero-based index of this training call.
            # The reset must happen on the call right after a step, otherwise
            # part of a window is discarded or stale gradients leak into the
            # next window.
            should_reset_optimizer = self._progress_calls % self._grad_accum == 0
            should_step_optimizer = (self._progress_calls + 1) % self._grad_accum == 0

        if not self._connected:
            self._connect(dataloader_iter)
        elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
            # Another pipeline used this model last; re-point its forwards at us.
            self._sync_pipeline()
            self.__class__.synced_pipeline_id[id(self._model)] = id(self)

        if self._model.training and should_reset_optimizer:
            with record_function("## zero_grad ##"):
                self._optimizer.zero_grad()

        with record_function("## copy_batch_to_gpu ##"):
            with torch.cuda.stream(self._memcpy_stream):
                batch_ip2 = next(dataloader_iter)
                self._batch_ip2 = batch_ip2 = _to_device(
                    batch_ip2, self._device, non_blocking=True
                )
        batch_i = cast(In, self._batch_i)
        batch_ip1 = cast(In, self._batch_ip1)

        with record_function("## wait_for_batch ##"):
            _wait_for_batch(batch_i, self._data_dist_stream)

        # Forward
        with record_function("## forward ##"):
            # if using multiple streams (ie. CUDA), create an event in default stream
            # before starting forward pass
            if self._data_dist_stream:
                event = torch.cuda.current_stream().record_event()
            if self._enable_amp:
                # conditionally apply the model to the batch in the autocast context
                # it appears that `enabled=self._enable_amp` should handle this,
                # but it does not.
                with torch.autocast(
                    device_type=self._device.type,
                    dtype=torch.bfloat16,
                    enabled=self._enable_amp,
                ):
                    (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
            else:
                (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))

        # Data Distribution
        with record_function("## sparse_data_dist ##"):
            with torch.cuda.stream(self._data_dist_stream):
                _wait_for_batch(batch_ip1, self._memcpy_stream)
                # Ensure event in default stream has been called before
                # starting data dist
                if self._data_dist_stream:
                    # pyre-ignore [61]: Local variable `event` is undefined, or not always defined
                    self._data_dist_stream.wait_event(event)
                _start_data_dist(self._pipelined_modules, batch_ip1, self._context)

        if self._model.training:
            # Backward
            with record_function("## backward ##"):
                # Loss is normalized by the number of accumulation steps.
                # The reported loss in `output['loss']` remains the unnormalized value.
                if self._grad_accum is not None:
                    losses = losses / self._grad_accum
                self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()

            if should_step_optimizer:
                # Update
                with record_function("## optimizer ##"):
                    self._grad_scaler.step(self._optimizer)
                    self._grad_scaler.update()

        self._batch_i = batch_ip1
        self._batch_ip1 = batch_ip2

        if self._model.training:
            self._progress_calls += 1

        return output

    def _sync_pipeline(self) -> None:
        """
        Syncs `PipelinedForward` for sharded modules with context and dist stream of the
        current train pipeline. Used when switching between train pipelines for the same
        model.
        """
        for module in self._pipelined_modules:
            module.forward._context = self._context
            module.forward._dist_stream = self._data_dist_stream