twitter-algorithm-ml/projects/home/recap/data/generate_random_data.py

82 lines
2.6 KiB
Python
Raw Permalink Normal View History

import os
import json
from absl import app, flags, logging
import tensorflow as tf
from typing import Dict
from tml.projects.home.recap.data import tfe_parsing
from tml.core import config as tml_config_mod
import tml.projects.home.recap.config as recap_config_mod
flags.DEFINE_string("config_path", None, "Path to hyperparameters for model.")
flags.DEFINE_integer("n_examples", 100, "Numer of examples to generate.")
FLAGS = flags.FLAGS
def _generate_random_example(
tf_example_schema: Dict[str, tf.io.FixedLenFeature]
) -> Dict[str, tf.Tensor]:
example = {}
for feature_name, feature_spec in tf_example_schema.items():
dtype = feature_spec.dtype
if (dtype == tf.int64) or (dtype == tf.int32):
x = tf.experimental.numpy.random.randint(0, high=10, size=feature_spec.shape, dtype=dtype)
elif (dtype == tf.float32) or (dtype == tf.float64):
x = tf.random.uniform(shape=[feature_spec.shape], dtype=dtype)
else:
raise NotImplementedError(f"Unknown type {dtype}")
example[feature_name] = x
return example
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes:
feature = {}
serializers = {tf.float32: _float_feature, tf.int64: _int64_feature}
for feature_name, tensor in x.items():
feature[feature_name] = serializers[tensor.dtype](tensor)
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
def generate_data(data_path: str, config: recap_config_mod.RecapConfig):
with tf.io.gfile.GFile(config.train_data.seg_dense_schema.schema_path, "r") as f:
seg_dense_schema = json.load(f)["schema"]
tf_example_schema = tfe_parsing.create_tf_example_schema(
config.train_data,
seg_dense_schema,
)
record_filename = os.path.join(data_path, "random.tfrecord.gz")
with tf.io.TFRecordWriter(record_filename, "GZIP") as writer:
random_example = _generate_random_example(tf_example_schema)
serialized_example = _serialize_example(random_example)
writer.write(serialized_example)
def _generate_data_main(unused_argv):
config = tml_config_mod.load_config_from_yaml(recap_config_mod.RecapConfig, FLAGS.config_path)
# Find the path where to put the data
data_path = os.path.dirname(config.train_data.inputs)
logging.info("Putting random data in %s", data_path)
generate_data(data_path, config)
if __name__ == "__main__":
app.run(_generate_data_main)