92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
|
from dataclasses import dataclass
|
||
|
from typing import Tuple
|
||
|
|
||
|
from tml.common.batch import DataclassBatch
|
||
|
from tml.common.testing_utils import mock_pg
|
||
|
from tml.core import train_pipeline
|
||
|
|
||
|
import torch
|
||
|
from torchrec.distributed import DistributedModelParallel
|
||
|
|
||
|
|
||
|
@dataclass
class MockDataclassBatch(DataclassBatch):
  """Minimal batch used by the pipeline tests below.

  Produced by create_batch; consumed by MockModule.forward.
  """

  # Dense model inputs; create_batch fills this with shape (bsz, 10) floats.
  continuous_features: torch.Tensor
  # Binary targets for BCEWithLogitsLoss; create_batch fills this with shape (bsz, 1) floats.
  labels: torch.Tensor
|
||
|
|
||
|
|
||
|
class MockModule(torch.nn.Module):
  """Tiny stand-in model: one linear layer plus a BCE-with-logits loss.

  forward returns a ``(loss, predictions)`` pair, which is the contract the
  train pipeline under test expects from a model.
  """

  def __init__(self) -> None:
    super().__init__()
    # Maps the 10 continuous features of a batch to a single logit.
    self.model = torch.nn.Linear(10, 1)
    self.loss_fn = torch.nn.BCEWithLogitsLoss()

  def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = self.model(batch.continuous_features)
    return (self.loss_fn(logits, batch.labels), logits)
|
||
|
|
||
|
|
||
|
def create_batch(bsz: int):
  """Build a random MockDataclassBatch of size ``bsz``.

  Features are uniform in [0, 1); labels are Bernoulli-sampled 0/1 floats,
  matching what BCEWithLogitsLoss expects.
  """
  features = torch.rand(bsz, 10).float()
  targets = torch.bernoulli(torch.empty(bsz, 1).uniform_(0, 1)).float()
  return MockDataclassBatch(continuous_features=features, labels=targets)
|
||
|
|
||
|
|
||
|
def test_sparse_pipeline():
  """TrainPipelineSparseDist with grad_accum=2 steps the optimizer every 2nd call.

  Every progress() call sees the identical batch, so two consecutive calls
  within one accumulation window run against unchanged weights and must
  produce identical losses, while losses across an optimizer step must differ.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  # A generator expression is already an iterator — no iter() wrapper needed.
  # +2 slack because the pipeline pre-fetches ahead of the steps it reports.
  dataloader = (example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      grad_accum=2,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  # Extract plain Python floats so the asserts below compare scalars rather
  # than size-1 numpy arrays (whose truthiness is a fragile special case).
  losses = [float(elem.detach()) for elem in results]

  # Gradients are accumulated: losses within each window (0th/1st, 2nd/3rd, ...)
  # are computed against identical weights and must match.
  for first, second in zip(losses[::2], losses[1::2]):
    assert first == second, losses

  # Weights do update between windows: losses straddling a step must change.
  for first, second in zip(losses[1::2], losses[2::2]):
    assert first != second, losses
|
||
|
|
||
|
|
||
|
def test_amp():
  """With enable_amp, the pipeline's reported losses come out in bfloat16.

  CPU autocast uses bfloat16; grad scaling is disabled because it is not
  supported on CPU.
  """
  device = torch.device("cpu")
  model = MockModule().to(device)

  steps = 8
  example = create_batch(1)
  # A generator expression is already an iterator — no iter() wrapper needed.
  # +2 slack because the pipeline pre-fetches ahead of the steps it reports.
  dataloader = (example for _ in range(steps + 2))

  results = []
  with mock_pg():
    d_model = DistributedModelParallel(model)
    pipeline = train_pipeline.TrainPipelineSparseDist(
      model=d_model,
      optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
      device=device,
      enable_amp=True,
      # Not supported on CPU.
      enable_grad_scaling=False,
    )
    for _ in range(steps):
      results.append(pipeline.progress(dataloader))

  results = [elem.detach() for elem in results]
  # Every loss produced under autocast(cpu) should be bfloat16.
  for value in results:
    assert value.dtype == torch.bfloat16
|