twitter-algorithm-ml/common/run_training.py

110 lines
3.5 KiB
Python
Raw Normal View History

import os
import subprocess
import sys
from typing import Optional
from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined]
from twitter.ml.tensorflow.experimental.distributed import utils
import torch
import torch.distributed.run
def is_distributed_worker():
world_size = os.environ.get("WORLD_SIZE", None)
rank = os.environ.get("RANK", None)
return world_size is not None and rank is not None
def maybe_run_training(
train_fn,
module_name,
nproc_per_node: Optional[int] = None,
num_nodes: Optional[int] = None,
set_python_path_in_subprocess: bool = False,
is_chief: Optional[bool] = False,
**training_kwargs,
):
"""Wrapper function for single node, multi-GPU Pytorch training.
If the necessary distributed Pytorch environment variables
(WORLD_SIZE, RANK) have been set, then this function executes
`train_fn(**training_kwargs)`.
Otherwise, this function calls torchrun and points at the calling module
`module_name`. After this call, the necessary environment variables are set
and training will commence.
Args:
train_fn: The function that is responsible for training
module_name: The name of the module that this function was called from;
used to indicate torchrun entrypoint.
nproc_per_node: Number of workers per node; supported values.
num_nodes: Number of nodes, otherwise inferred from environment.
is_chief: If process is running on chief.
set_python_path_in_subprocess: A bool denoting whether to set PYTHONPATH.
"""
machines = utils.machine_from_env()
if num_nodes is None:
num_nodes = 1
if machines.num_workers:
num_nodes += machines.num_workers
if is_distributed_worker():
# world_size, rank, etc are set; assuming any other env vars are set (checks to come)
# start the actual training!
train_fn(**training_kwargs)
else:
if nproc_per_node is None:
if torch.cuda.is_available():
nproc_per_node = torch.cuda.device_count()
else:
nproc_per_node = machines.chief.num_accelerators
# Rejoin all arguments to send back through torchrec
# this is a temporary measure, will replace the os.system call
# with torchrun API calls
args = list(f"--{key}={val}" for key, val in training_kwargs.items())
cmd = [
"--nnodes",
str(num_nodes),
]
if nproc_per_node:
cmd.extend(["--nproc_per_node", str(nproc_per_node)])
if num_nodes > 1:
cluster_resolver = utils.cluster_resolver()
backend_address = cluster_resolver.cluster_spec().task_address("chief", 0)
cmd.extend(
[
"--rdzv_backend",
"c10d",
"--rdzv_id",
backend_address,
]
)
# Set localhost on chief because of https://github.com/pytorch/pytorch/issues/79388
if is_chief:
cmd.extend(["--rdzv_endpoint", "localhost:2222"])
else:
cmd.extend(["--rdzv_endpoint", backend_address])
else:
cmd.append("--standalone")
cmd.extend(
[
str(module_name),
*args,
]
)
logging.info(f"""Distributed running with cmd: '{" ".join(cmd)}'""")
# Call torchrun on this module; will spawn new processes and re-run this
# function, eventually calling "train_fn". The following line sets the PYTHONPATH to accommodate
# bazel stubbing for the main binary.
if set_python_path_in_subprocess:
subprocess.run(["torchrun"] + cmd, env={**os.environ, "PYTHONPATH": ":".join(sys.path)})
else:
torch.distributed.run.main(cmd)