# File: twitter-algorithm-ml/projects/twhin/data/test_edges.py (Python, 67 lines, 1.7 KiB)

"""Tests edges dataset functionality."""
from unittest.mock import patch
import os
import tempfile
from tml.projects.twhin.data.edges import EdgesDataset
from tml.projects.twhin.models.config import Relation
from fsspec.implementations.local import LocalFileSystem
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
# Table name -> table size; handed to EdgesDataset as `table_sizes` in test_gen.
TABLE_SIZES = dict(user=16, author=32)

# Relations handed to EdgesDataset as `relations` in test_gen. Every lhs/rhs
# value names one of the tables declared in TABLE_SIZES.
RELATIONS = [
    Relation(name=rel_name, lhs=lhs_table, rhs=rhs_table)
    for rel_name, lhs_table, rhs_table in (
        ("fav", "user", "author"),
        ("engaged_with_reply", "author", "user"),
    )
]
def test_gen():
    """Check the first batch EdgesDataset produces from a tiny parquet file.

    Four ``(lhs, rhs, rel)`` edges are written to a parquet file in a
    temporary directory. An ``EdgesDataset`` is built over that directory
    with a local filesystem, and the first batch from its dataloader is
    checked for positive labels and for the expected user / author node ids.
    """
    # NOTE(review): these repeat the module-level imports; they are kept so
    # the test body stays self-contained.
    import os
    import tempfile

    from fsspec.implementations.local import LocalFileSystem
    import pyarrow as pa
    import pyarrow.parquet as pq

    rel_indices = [0, 1, 0, 0]
    # One batch is sized to hold every edge written below.
    num_edges = len(rel_indices)

    lhs = pa.array(np.arange(num_edges))  # 0, 1, 2, 3
    rhs = pa.array(np.flip(np.arange(num_edges)))  # 3, 2, 1, 0
    rel = pa.array(rel_indices)
    table = pa.Table.from_arrays([lhs, rhs, rel], names=["lhs", "rhs", "rel"])

    with tempfile.TemporaryDirectory() as tmpdir:
        # The context manager closes the writer even if write_table raises,
        # so no open handle is left inside the directory being cleaned up.
        with pq.ParquetWriter(
            os.path.join(tmpdir, "example.parquet"),
            table.schema,
        ) as writer:
            writer.write_table(table)

        ds = EdgesDataset(
            file_pattern=os.path.join(tmpdir, "*"),
            table_sizes=TABLE_SIZES,
            relations=RELATIONS,
            batch_size=num_edges,
        )
        # Read the parquet file from local disk.
        ds.FS = LocalFileSystem()

        batch = next(iter(ds.dataloader()))

        # The first `num_edges` labels must all be positive (== 1).
        labels = batch.labels
        assert (labels[:num_edges] == 1).sum() == num_edges

        # Make sure the positive examples are what we expect. The second edge
        # has rel == 1; RELATIONS[1] ("engaged_with_reply") has lhs="author"
        # and rhs="user", which is consistent with that edge contributing its
        # rhs (2) to users and its lhs (1) to authors below.
        kjt_values = batch.nodes.values()
        users, authors = torch.split(kjt_values, num_edges, dim=0)
        assert torch.equal(users[:num_edges], torch.tensor([0, 2, 2, 3]))
        assert torch.equal(authors[:num_edges], torch.tensor([3, 1, 1, 0]))