twitter-algorithm-ml/reader/dataset.py

"""Dataset to be overwritten that can work with or without distributed reading.
- Override `pa_to_batch` for dataset specific imputation, negative sampling, or coercion to Batch.
- Readers can be colocated or off trainer machines.
"""
import abc
import functools
import random
from typing import Optional
from fsspec.implementations.local import LocalFileSystem
import pyarrow.dataset as pads
import pyarrow as pa
import pyarrow.parquet
import pyarrow.flight
from pyarrow.ipc import IpcWriteOptions
import torch
from tml.common.batch import DataclassBatch
from tml.machines import environment as env
import tml.reader.utils as reader_utils
from tml.common.filesystem import infer_fs
from tml.ml_logging.torch_logging import logging


class _Reader(pa.flight.FlightServerBase):
  """Distributed reader Flight server wrapping a dataset."""

  def __init__(self, location: str, ds: "Dataset"):
    super().__init__(location=location)
    self._location = location
    self._ds = ds

  def do_get(self, _, __):
    # NB: An updated schema (to account for column selection) has to be given to the stream.
    schema = next(iter(self._ds.to_batches())).schema
    batches = self._ds.to_batches()
    return pa.flight.RecordBatchStream(
      data_source=pa.RecordBatchReader.from_batches(
        schema=schema,
        batches=batches,
      ),
      options=IpcWriteOptions(use_threads=True),
    )


class Dataset(torch.utils.data.IterableDataset):
  LOCATION = "grpc://0.0.0.0:2222"

  def __init__(self, file_pattern: str, **dataset_kwargs) -> None:
    """Specify batch size and columns to select for.

    Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.
    """
    self._file_pattern = file_pattern
    self._fs = infer_fs(self._file_pattern)
    self._dataset_kwargs = dataset_kwargs
    logging.info(f"Using dataset_kwargs: {self._dataset_kwargs}")
    self._files = self._fs.glob(self._file_pattern)
    assert len(self._files) > 0, f"No files found at {self._file_pattern}"
    logging.info(f"Found {len(self._files)} files: {', '.join(self._files[:4])}, ...")
    self._schema = pa.parquet.read_schema(self._files[0], filesystem=self._fs)
    self._validate_columns()
  def _validate_columns(self):
    columns = set(self._dataset_kwargs.get("columns", []))
    wrong_columns = columns - set(self._schema.names)
    if wrong_columns:
      raise Exception(f"Specified columns {list(wrong_columns)} not in schema.")

  def serve(self):
    self.reader = _Reader(location=self.LOCATION, ds=self)
    self.reader.serve()
  def _create_dataset(self):
    # Shuffle the file list and read a single randomly-chosen file per pass
    # (file-level shuffling on every repeat).
    return pads.dataset(
      source=random.sample(self._files, len(self._files))[0],
      format="parquet",
      filesystem=self._fs,
      exclude_invalid_files=False,
    )
  def to_batches(self):
    """Reads batches from the dataset; reading settings are controlled by the kwargs passed at init.

    Refer to https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_dataset.

    Performs `drop_remainder` behavior to fix the batch size; this does not shift the data
    distribution because of the data volume and the file-level shuffling on every repeat.
    """
    batch_size = self._dataset_kwargs["batch_size"]
    while True:
      ds = self._create_dataset()
      for batch in ds.to_batches(**self._dataset_kwargs):
        if batch.num_rows < batch_size:
          logging.info(f"Dropping remainder ({batch.num_rows}/{batch_size})")
          break
        yield batch
  @abc.abstractmethod
  def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
    raise NotImplementedError

  def dataloader(self, remote: bool = False):
    if not remote:
      return map(self.pa_to_batch, self.to_batches())
    readers = get_readers(2)
    return map(self.pa_to_batch, reader_utils.roundrobin(*readers))
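

# Example (illustrative only; not part of the original module): a minimal sketch of how a
# concrete subclass might implement `pa_to_batch`, assuming a batch dataclass with
# "user_id" and "label" columns. The `_ExampleBatch`/`_ExampleDataset` names, the field
# names, and the dtypes are hypothetical.
#
#   import dataclasses
#
#   @dataclasses.dataclass
#   class _ExampleBatch(DataclassBatch):
#     user_id: torch.Tensor
#     label: torch.Tensor
#
#   class _ExampleDataset(Dataset):
#     def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch:
#       # Materialize the selected Arrow columns and coerce them to tensors; real
#       # datasets would add imputation or negative sampling here.
#       columns = batch.to_pydict()
#       return _ExampleBatch(
#         user_id=torch.tensor(columns["user_id"], dtype=torch.int64),
#         label=torch.tensor(columns["label"], dtype=torch.float32),
#       )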


GRPC_OPTIONS = [
  ("GRPC_ARG_KEEPALIVE_TIME_MS", 60000),
  ("GRPC_ARG_MIN_RECONNECT_BACKOFF_MS", 2000),
  ("GRPC_ARG_MAX_METADATA_SIZE", 1024 * 1024 * 1024),
]


def get_readers(num_readers_per_worker: int):
  # NB: `num_readers_per_worker` is currently unused here; one reader is created per
  # Flight server address returned by the environment.
  addresses = env.get_flight_server_addresses()
  readers = []
  for worker in addresses:
    logging.info(f"Attempting connection to reader {worker}.")
    client = pa.flight.connect(worker, generic_options=GRPC_OPTIONS)
    client.wait_for_available(60)
    reader = client.do_get(None).to_reader()
    logging.info(f"Connected reader to {worker}.")
    readers.append(reader)
  return readers
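

# Example wiring (illustrative only; not part of the original module). It assumes the
# hypothetical `_ExampleDataset` subclass sketched above and a local parquet glob; a real
# deployment would presumably derive the reader/trainer role from the runtime environment.
#
#   dataset = _ExampleDataset(
#     "/tmp/data/*.parquet", batch_size=1024, columns=["user_id", "label"]
#   )
#   RUN_AS_READER = False  # hypothetical switch for this sketch
#   if RUN_AS_READER:
#     dataset.serve()  # blocks, serving record batches at Dataset.LOCATION
#   else:
#     for training_batch in dataset.dataloader(remote=False):
#       ...  # consume DataclassBatch instances in the training loop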