121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
|
from typing import Mapping, Tuple, Union
|
||
|
import torch
|
||
|
import torchrec
|
||
|
import numpy as np
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
def keyed_tensor_from_tensors_dict(
    tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedTensor":
    """Pack a mapping of dense torch tensors into one ``torchrec.KeyedTensor``.

    Args:
      tensor_map: feature name -> dense torch tensor. The batch dimension is expected
        to be the first dimension of every tensor.

    Returns:
      The ``torchrec.KeyedTensor`` produced by ``KeyedTensor.from_tensor_list`` from the
      mapping's keys (in iteration order) and their tensors, where every rank-0/rank-1
      tensor has first been given a trailing dimension of size 1.
    """
    feature_names = []
    feature_tensors = []
    for name, tensor in tensor_map.items():
        # A tensor shaped [Batch_size] would prevent KeyedTensor from finding the right
        # batch size, so promote it to [Batch_size x 1] before packing.
        if tensor.dim() <= 1:
            tensor = tensor.unsqueeze(-1)
        feature_names.append(name)
        feature_tensors.append(tensor)
    return torchrec.KeyedTensor.from_tensor_list(feature_names, feature_tensors)
|
||
|
|
||
|
|
||
|
def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a torch tensor into the ``(values, lengths)`` pair describing a jagged tensor.

    Sparse COO input: the tensor is coalesced, ``values`` are its non-zero entries in
    coalesced order, and ``lengths[i]`` counts the non-zeros whose first index is ``i``.
    One length is produced for every row of the first dimension, including rows with no
    non-zeros. These lengths are int64 (``torch.bincount``'s output dtype).

    Dense input: the tensor itself is returned as ``values`` and every row of the first
    dimension gets a length of 1 (int32, on the tensor's device).

    Args:
      tensor: a dense or sparse COO torch tensor.

    Returns:
      ``(values, lengths)`` as described above.
    """
    if tensor.is_sparse:
        x = tensor.coalesce()  # Ensure that the indices are ordered.
        # Without minlength, bincount stops at the largest row index actually present, so
        # trailing rows that have no non-zeros would be missing from `lengths` and the
        # lengths would no longer line up with the batch.
        lengths = torch.bincount(x.indices()[0], minlength=x.shape[0])
        values = x.values()
    else:
        values = tensor
        lengths = torch.ones(tensor.shape[0], dtype=torch.int32, device=tensor.device)
    return values, lengths
|
||
|
|
||
|
|
||
|
def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedTensor":
    """Wrap a dense or sparse torch tensor as a ``torchrec.JaggedTensor``.

    Note: for dense tensors only shapes [Batch_size] or [Batch_size x N] are currently
    supported. For a sparse tensor, ``.values()`` should have shape [Batch_size] or
    [Batch_size x N]; its dense_shape can be arbitrary.

    Args:
      tensor: a torch (sparse) tensor.

    Returns:
      A ``torchrec.JaggedTensor`` whose values and lengths come from
      ``_compute_jagged_tensor_from_tensor``. A dense input gives each row a length of 1;
      a sparse input gets the per-row non-zero counts.
    """
    jagged_values, jagged_lengths = _compute_jagged_tensor_from_tensor(tensor)
    jagged = torchrec.JaggedTensor(lengths=jagged_lengths, values=jagged_values)
    return jagged
|
||
|
|
||
|
|
||
|
def keyed_jagged_tensor_from_tensors_dict(
    tensor_map: Mapping[str, torch.Tensor]
) -> "torchrec.KeyedJaggedTensor":
    """Convert a dictionary of (sparse) torch tensors to a ``torchrec.KeyedJaggedTensor``.

    Note: for dense tensors only shapes [Batch_size] or [Batch_size x 1] are currently
    supported. For a sparse tensor, ``.values()`` should be 1-D or have a trailing
    dimension of 1; its dense_shape can be arbitrary.

    Args:
      tensor_map: feature name -> dense or sparse COO torch tensor. The mapping's keys
        become the KJT keys, in the mapping's iteration order.

    Returns:
      A ``torchrec.KeyedJaggedTensor``. Its ``values`` are each key's values concatenated
      in key order, and its ``lengths`` are each key's per-row lengths concatenated in the
      same order. An empty mapping yields a KJT with no keys and empty int32 ``values`` and
      ``lengths``.
    """
    if not tensor_map:
        return torchrec.KeyedJaggedTensor(
            keys=[],
            values=torch.zeros(0, dtype=torch.int),
            lengths=torch.zeros(0, dtype=torch.int),
        )

    values = []
    lengths = []
    for tensor in tensor_map.values():
        tensor_val, tensor_len = _compute_jagged_tensor_from_tensor(tensor)
        # squeeze turns [Batch_size x 1] into [Batch_size]. However, a single-element values
        # tensor (batch of 1, or a sparse feature with exactly one non-zero) would be
        # squeezed to 0-d, and torch.cat refuses 0-d tensors; atleast_1d restores shape [1].
        values.append(torch.atleast_1d(torch.squeeze(tensor_val)))
        lengths.append(tensor_len)

    return torchrec.KeyedJaggedTensor(
        keys=list(tensor_map.keys()),
        values=torch.cat(values, dim=0),
        lengths=torch.cat(lengths, dim=0),
    )
|
||
|
|
||
|
|
||
|
def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray:
    """Return the numpy array obtained from the tensor's private ``_numpy`` accessor.

    The private accessor is used instead of the public ``.numpy()``, presumably to skip
    the copy the public method makes (NOTE(review): confirm this was the intent).
    """
    numpy_accessor = tf_tensor._numpy  # noqa
    return numpy_accessor()
|
||
|
|
||
|
|
||
|
def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Tensor:
    """Convert a dense TF tensor to a CPU torch tensor through numpy.

    Args:
      tensor: the dense TF tensor to convert.
      pin_memory: if True, return a copy of the result in pinned (page-locked) host memory.

    Returns:
      A torch tensor with the same values. bfloat16 inputs are returned as float32.
    """
    array = _tf_to_numpy(tensor)
    # numpy has no native bfloat16; the array presumably carries TF's extension bfloat16
    # dtype, which torch.from_numpy cannot ingest. Up-cast to float32, which keeps the same
    # number of exponent bits.
    if array.dtype.name == "bfloat16":
        array = array.astype(np.float32)

    torch_tensor = torch.from_numpy(array)
    return torch_tensor.pin_memory() if pin_memory else torch_tensor
|
||
|
|
||
|
|
||
|
def sparse_or_dense_tf_to_torch(
    tensor: Union[tf.Tensor, tf.SparseTensor], pin_memory: bool
) -> torch.Tensor:
    """Convert a dense TF tensor or a ``tf.SparseTensor`` to the corresponding torch tensor.

    Args:
      tensor: a dense TF tensor or a ``tf.SparseTensor``.
      pin_memory: forwarded to the dense conversion of the tensor (dense case), or of its
        indices and values (sparse case). Whether the resulting sparse COO tensor is itself
        pinned is not established here. NOTE(review): confirm.

    Returns:
      A dense torch tensor, or a torch sparse COO tensor whose size is the
      ``tf.SparseTensor``'s dense_shape.
    """
    if not isinstance(tensor, tf.SparseTensor):
        return _dense_tf_to_torch(tensor, pin_memory)

    # The indices are transposed because torch.sparse_coo_tensor takes them as
    # [ndim, nnz], whereas tf.SparseTensor stores them as [nnz, ndim].
    coo_indices = _dense_tf_to_torch(tensor.indices, pin_memory).t()
    coo_values = _dense_tf_to_torch(tensor.values, pin_memory)
    coo_size = torch.Size(_tf_to_numpy(tensor.dense_shape))
    return torch.sparse_coo_tensor(coo_indices, coo_values, coo_size)
|