twitter-algorithm/trust_and_safety_models/toxicity/load_model.py

227 lines
7 KiB
Python
Raw Normal View History

import os
from toxicity_ml_pipeline.settings.default_settings_tox import LOCAL_DIR, MAX_SEQ_LENGTH
try:
from toxicity_ml_pipeline.optim.losses import MaskedBCE
except ImportError:
print('No MaskedBCE loss')
from toxicity_ml_pipeline.utils.helpers import execute_command
import tensorflow as tf
try:
from twitter.cuad.representation.models.text_encoder import TextEncoder
except ModuleNotFoundError:
print("No TextEncoder package")
try:
from transformers import TFAutoModelForSequenceClassification
except ModuleNotFoundError:
print("No HuggingFace package")
LOCAL_MODEL_DIR = os.path.join(LOCAL_DIR, "models")
def reload_model_weights(weights_dir, language, **kwargs):
optimizer = tf.keras.optimizers.Adam(0.01)
model_type = (
"twitter_bert_base_en_uncased_mlm"
if language == "en"
else "twitter_multilingual_bert_base_cased_mlm"
)
model = load(optimizer=optimizer, seed=42, model_type=model_type, **kwargs)
model.load_weights(weights_dir)
return model
def _locally_copy_models(model_type):
if model_type == "twitter_multilingual_bert_base_cased_mlm":
preprocessor = "bert_multi_cased_preprocess_3"
elif model_type == "twitter_bert_base_en_uncased_mlm":
preprocessor = "bert_en_uncased_preprocess_3"
else:
raise NotImplementedError
copy_cmd = """mkdir {local_dir}
gsutil cp -r ...
gsutil cp -r ..."""
execute_command(
copy_cmd.format(model_type=model_type, preprocessor=preprocessor, local_dir=LOCAL_MODEL_DIR)
)
return preprocessor
def load_encoder(model_type, trainable):
try:
model = TextEncoder(
max_seq_lengths=MAX_SEQ_LENGTH,
model_type=model_type,
cluster="gcp",
trainable=trainable,
enable_dynamic_shapes=True,
)
except (OSError, tf.errors.AbortedError) as e:
print(e)
preprocessor = _locally_copy_models(model_type)
model = TextEncoder(
max_seq_lengths=MAX_SEQ_LENGTH,
local_model_path=f"models/{model_type}",
local_preprocessor_path=f"models/{preprocessor}",
cluster="gcp",
trainable=trainable,
enable_dynamic_shapes=True,
)
return model
def get_loss(loss_name, from_logits, **kwargs):
loss_name = loss_name.lower()
if loss_name == "bce":
print("Binary CE loss")
return tf.keras.losses.BinaryCrossentropy(from_logits=from_logits)
if loss_name == "cce":
print("Categorical cross-entropy loss")
return tf.keras.losses.CategoricalCrossentropy(from_logits=from_logits)
if loss_name == "scce":
print("Sparse categorical cross-entropy loss")
return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)
if loss_name == "focal_bce":
gamma = kwargs.get("gamma", 2)
print("Focal binary CE loss", gamma)
return tf.keras.losses.BinaryFocalCrossentropy(gamma=gamma, from_logits=from_logits)
if loss_name == 'masked_bce':
multitask = kwargs.get("multitask", False)
if from_logits or multitask:
raise NotImplementedError
print(f'Masked Binary Cross Entropy')
return MaskedBCE()
if loss_name == "inv_kl_loss":
raise NotImplementedError
raise ValueError(
f"This loss name is not valid: {loss_name}. Accepted loss names: BCE, masked BCE, CCE, sCCE, "
f"Focal_BCE, inv_KL_loss"
)
def _add_additional_embedding_layer(doc_embedding, glorot, seed):
doc_embedding = tf.keras.layers.Dense(768, activation="tanh", kernel_initializer=glorot)(doc_embedding)
doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)
return doc_embedding
def _get_bias(**kwargs):
smart_bias_value = kwargs.get('smart_bias_value', 0)
print('Smart bias init to ', smart_bias_value)
output_bias = tf.keras.initializers.Constant(smart_bias_value)
return output_bias
def load_inhouse_bert(model_type, trainable, seed, **kwargs):
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
encoder = load_encoder(model_type=model_type, trainable=trainable)
doc_embedding = encoder([inputs])["pooled_output"]
doc_embedding = tf.keras.layers.Dropout(rate=0.1, seed=seed)(doc_embedding)
glorot = tf.keras.initializers.glorot_uniform(seed=seed)
if kwargs.get("additional_layer", False):
doc_embedding = _add_additional_embedding_layer(doc_embedding, glorot, seed)
if kwargs.get('content_num_classes', None):
probs = get_last_layer(glorot=glorot, last_layer_name='target_output', **kwargs)(doc_embedding)
second_probs = get_last_layer(num_classes=kwargs['content_num_classes'],
last_layer_name='content_output',
glorot=glorot)(doc_embedding)
probs = [probs, second_probs]
else:
probs = get_last_layer(glorot=glorot, **kwargs)(doc_embedding)
model = tf.keras.models.Model(inputs=inputs, outputs=probs)
return model, False
def get_last_layer(**kwargs):
output_bias = _get_bias(**kwargs)
if 'glorot' in kwargs:
glorot = kwargs['glorot']
else:
glorot = tf.keras.initializers.glorot_uniform(seed=kwargs['seed'])
layer_name = kwargs.get('last_layer_name', 'dense_1')
if kwargs.get('num_classes', 1) > 1:
last_layer = tf.keras.layers.Dense(
kwargs["num_classes"], activation="softmax", kernel_initializer=glorot,
bias_initializer=output_bias, name=layer_name
)
elif kwargs.get('num_raters', 1) > 1:
if kwargs.get('multitask', False):
raise NotImplementedError
last_layer = tf.keras.layers.Dense(
kwargs['num_raters'], activation="sigmoid", kernel_initializer=glorot,
bias_initializer=output_bias, name='probs')
else:
last_layer = tf.keras.layers.Dense(
1, activation="sigmoid", kernel_initializer=glorot,
bias_initializer=output_bias, name=layer_name
)
return last_layer
def load_bertweet(**kwargs):
bert = TFAutoModelForSequenceClassification.from_pretrained(
os.path.join(LOCAL_MODEL_DIR, "bertweet-base"),
num_labels=1,
classifier_dropout=0.1,
hidden_size=768,
)
if "num_classes" in kwargs and kwargs["num_classes"] > 2:
raise NotImplementedError
return bert, True
def load(
optimizer,
seed,
model_type="twitter_multilingual_bert_base_cased_mlm",
loss_name="BCE",
trainable=True,
**kwargs,
):
if model_type == "bertweet-base":
model, from_logits = load_bertweet()
else:
model, from_logits = load_inhouse_bert(model_type, trainable, seed, **kwargs)
pr_auc = tf.keras.metrics.AUC(curve="PR", name="pr_auc", from_logits=from_logits)
roc_auc = tf.keras.metrics.AUC(curve="ROC", name="roc_auc", from_logits=from_logits)
loss = get_loss(loss_name, from_logits, **kwargs)
if kwargs.get('content_num_classes', None):
second_loss = get_loss(loss_name=kwargs['content_loss_name'], from_logits=from_logits)
loss_weights = {'content_output': kwargs['content_loss_weight'], 'target_output': 1}
model.compile(
optimizer=optimizer,
loss={'content_output': second_loss, 'target_output': loss},
loss_weights=loss_weights,
metrics=[pr_auc, roc_auc],
)
else:
model.compile(
optimizer=optimizer,
loss=loss,
metrics=[pr_auc, roc_auc],
)
print(model.summary(), "logits: ", from_logits)
return model