130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
|
import functools
|
||
|
import json
|
||
|
|
||
|
from tml.projects.home.recap.data import config as recap_data_config
|
||
|
|
||
|
from absl import logging
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
# Default fill value for each tf.Example feature kind, keyed by the schema's "dtype" string.
DEFAULTS_MAP = dict(int64_list=0, float_list=0.0, bytes_list="")
|
||
|
# TensorFlow dtype used to parse each tf.Example feature kind, keyed by the schema's "dtype" string.
DTYPE_MAP = dict(int64_list=tf.int64, float_list=tf.float32, bytes_list=tf.string)
|
||
|
|
||
|
|
||
|
def create_tf_example_schema(
    data_config: recap_data_config.RecapDataConfig,
    segdense_schema,
):
    """Generate schema for deserializing tf.Example.

    Only schema entries whose feature name is actually consumed are kept. A name
    is consumed if it is a configured feature, the source name of a renamed
    feature, or a task label.

    Args:
        data_config: Recap data config. Its ``seg_dense_schema`` provides
            ``features``, ``renamed_features`` and ``schema_path``; the keys of
            its ``tasks`` mapping are treated as label feature names.
        segdense_schema: List of dicts of segdense features, each with
            ``feature_name``, ``dtype`` (a key of ``DTYPE_MAP``) and ``length``
            (``-1`` selects a variable-length feature for non-label entries).

    Returns:
        A dictionary schema suitable for deserializing tf.Example.

    Raises:
        ValueError: If a used feature or label has no entry in ``segdense_schema``.
    """
    segdense_config = data_config.seg_dense_schema
    labels = list(data_config.tasks.keys())
    used_features = (
        segdense_config.features + list(segdense_config.renamed_features.values()) + labels
    )
    logging.info(used_features)

    # Sets give O(1) membership tests while scanning the full schema listing.
    used_feature_set = set(used_features)
    label_set = set(labels)

    tfe_schema = {}
    for entry in segdense_schema:
        feature_name = entry["feature_name"]
        if feature_name not in used_feature_set:
            continue

        length = entry["length"]
        dtype = entry["dtype"]

        if feature_name in label_set:
            logging.info(f"Label: feature name is {feature_name} type is {dtype}")
            # Labels use a scalar default, unlike the per-element list default below.
            # NOTE(review): presumably label entries have length 1 — TODO confirm.
            tfe_schema[feature_name] = tf.io.FixedLenFeature(
                length, DTYPE_MAP[dtype], DEFAULTS_MAP[dtype]
            )
        elif length == -1:
            tfe_schema[feature_name] = tf.io.VarLenFeature(DTYPE_MAP[dtype])
        else:
            tfe_schema[feature_name] = tf.io.FixedLenFeature(
                length, DTYPE_MAP[dtype], [DEFAULTS_MAP[dtype]] * length
            )

    # Fail loudly if anything the model needs is absent from the schema file.
    for feature_name in used_features:
        if feature_name not in tfe_schema:
            raise ValueError(f"{feature_name} missing from schema: {segdense_config.schema_path}.")
    return tfe_schema
|
||
|
|
||
|
|
||
|
@functools.lru_cache(1)
def make_mantissa_mask(mask_length: int) -> tf.Tensor:
    """Build an int32 bit mask that clears the lowest ``mask_length`` bits.

    For experimenting with emulating bfloat16 or less precise types. AND-ing
    a float32 bit pattern with this mask zeroes its lowest ``mask_length`` bits.
    For ``mask_length <= 23`` those are the least significant mantissa bits.

    Args:
        mask_length: Number of low-order bits to clear, in ``[0, 31]``.

    Returns:
        A scalar ``tf.int32`` constant.

    Raises:
        ValueError: If ``mask_length`` is outside ``[0, 31]``.
    """
    if not 0 <= mask_length <= 31:
        raise ValueError(f"mask_length must be in [0, 31], got {mask_length}.")
    # The intended bit pattern is 0xFFFFFFFF with the low `mask_length` bits cleared,
    # i.e. (1 << 32) - (1 << mask_length) read as an unsigned 32-bit value. That
    # unsigned value exceeds the int32 range for mask_length < 31. Its two's-complement
    # int32 equivalent, -(1 << mask_length), has the same bits and is always in range.
    # NOTE(review): lru_cache retains the returned tensor. If this is first called while
    # a graph is being traced, the cached tensor belongs to that graph — confirm it is
    # never reused from a different graph.
    return tf.constant(-(1 << mask_length), dtype=tf.int32)
|
||
|
|
||
|
|
||
|
def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor:
    """Zero the lowest ``mask_length`` bits of every element's bit pattern.

    For experimenting with emulating bfloat16 or less precise types. The
    tensor is reinterpreted as int32 and AND-ed with ``make_mantissa_mask``.
    It is then reinterpreted back to its original dtype.
    NOTE(review): presumably intended for float32 inputs — TODO confirm.
    """
    bit_mask = make_mantissa_mask(mask_length)
    source_dtype = tensor.dtype
    raw_bits = tf.bitcast(tensor, tf.int32)
    kept_bits = tf.bitwise.bitwise_and(raw_bits, bit_mask)
    return tf.bitcast(kept_bits, source_dtype)
|
||
|
|
||
|
|
||
|
def parse_tf_example(
    serialized_example,
    tfe_schema,
    seg_dense_schema_config,
):
    """Parse serialized tf.Example into dict of tensors.

    Args:
        serialized_example: Serialized tf.Example(s) to be parsed.
        tfe_schema: Dictionary schema suitable for deserializing tf.Example.
        seg_dense_schema_config: Seg-dense schema config.
            ``renamed_features`` maps each new name to the name used in
            ``tfe_schema``. The optional ``mask_mantissa_features`` maps a
            feature name to the number of low bits to clear.

    Returns:
        Dictionary of tensors to be used as model input.
    """
    inputs = tf.io.parse_example(serialized=serialized_example, features=tfe_schema)

    for new_feature_name, old_feature_name in seg_dense_schema_config.renamed_features.items():
        inputs[new_feature_name] = inputs.pop(old_feature_name)

    # This should not actually be used except for experimentation with low precision floats.
    # The config is read through attributes, so look the attribute up directly.
    # `"name" in config` does not test attribute presence on attribute-style objects:
    # e.g. pydantic models iterate over (field, value) pairs, so the test is always
    # False and masking would be silently skipped.
    mask_mantissa_features = getattr(seg_dense_schema_config, "mask_mantissa_features", None) or {}
    for feature_name, mask_length in mask_mantissa_features.items():
        inputs[feature_name] = mask_mantissa(inputs[feature_name], mask_length)

    # DANGER DANGER: This default seems really scary, and it's only here because it has to be visible
    # at TF level.
    # We should not return empty tensors if we dont use embeddings.
    # Otherwise, it breaks numpy->pt conversion
    # NOTE(review): the rename loop above assigns every key of renamed_features into
    # `inputs`, so this fallback appears unreachable — confirm before relying on it.
    for renamed_key in seg_dense_schema_config.renamed_features:
        if "embedding" in renamed_key and renamed_key not in inputs:
            inputs[renamed_key] = tf.zeros([], tf.float32)

    logging.info(f"parsed example and inputs are {inputs}")
    return inputs
|
||
|
|
||
|
|
||
|
def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig):
    """Build the parse function for seg-dense tf.Example records.

    Steps:
      1. Load the JSON file at ``data_config.seg_dense_schema.schema_path``.
      2. Turn its ``"schema"`` entry into a tf.Example feature spec via
         ``create_tf_example_schema``.
      3. Bind that spec and the seg-dense config into ``parse_tf_example``.

    Placeholder for seg dense: when more seg dense variations are used, this is
    where the choice between them can be made.

    Args:
        data_config: Recap data config.

    Returns:
        A callable that takes serialized tf.Example(s) and returns the dict of
        parsed tensors.
    """
    seg_dense_config = data_config.seg_dense_schema

    with tf.io.gfile.GFile(seg_dense_config.schema_path, "r") as schema_file:
        schema_entries = json.load(schema_file)["schema"]

    feature_spec = create_tf_example_schema(data_config, schema_entries)

    logging.info("***** TF Example Schema *****")
    logging.info(feature_spec)

    return functools.partial(
        parse_tf_example,
        tfe_schema=feature_spec,
        seg_dense_schema_config=seg_dense_config,
    )
|