34 lines
640 B
Python
34 lines
640 B
Python
from contextlib import contextmanager
|
|
import datetime
|
|
import os
|
|
from unittest.mock import patch
|
|
|
|
import torch.distributed as dist
|
|
from tml.ml_logging.torch_logging import logging
|
|
|
|
|
|
# Environment variables describing a one-process job (rank 0 of world size 1)
# rendezvousing on localhost:29501. `mock_pg` patches these into `os.environ`
# before initialising the process group. All values are strings because
# `os.environ` only accepts string values.
MOCK_ENV = {
    "LOCAL_RANK": "0",
    "WORLD_SIZE": "1",
    "LOCAL_WORLD_SIZE": "1",
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "29501",
    "RANK": "0",
}
|
|
|
|
|
|
@contextmanager
def mock_pg():
    """Run the managed block inside a temporary ``gloo`` process group.

    While the context is active, ``os.environ`` is patched with ``MOCK_ENV``
    and a default ``torch.distributed`` process group is initialised with the
    ``gloo`` backend. On exit (normal or via an exception) the process group
    is destroyed exactly once and the environment patch is undone.

    Yields:
        None. The context manager is used purely for its setup/teardown.

    Raises:
        Any exception raised by ``dist.init_process_group`` or by the managed
        block is propagated unchanged.
    """
    with patch.dict(os.environ, MOCK_ENV):
        # Initialise *outside* the try block: if initialisation itself fails
        # there is no process group to tear down, and attempting to destroy
        # one would raise and hide the real initialisation error.
        dist.init_process_group(
            backend="gloo",
            # First positional argument of timedelta is days, so this is a
            # one-day timeout (same value as the previous `timedelta(1)`).
            timeout=datetime.timedelta(days=1),
        )
        try:
            yield
        finally:
            # Single teardown path. Destroying in both an `except` and a
            # `finally` clause would call this twice on the error path; the
            # second call fails because the default group is already gone,
            # masking the exception raised by the managed block.
            dist.destroy_process_group()
|