29 changes: 17 additions & 12 deletions TensorFlow/LanguageModeling/BERT/modeling.py
@@ -137,7 +137,7 @@ def __init__(self,
token_type_ids=None,
use_one_hot_embeddings=False,
scope=None,
compute_type=tf.float32):
compute_type=tf.bfloat16):#tf.float32):
"""Constructor for BertModel.

Args:
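
A sketch of how the changed `compute_type` default would typically be exercised at a call site, assuming the upstream `BertConfig`/`BertModel` signatures; `modeling` refers to the file modified in this diff, and the placeholder tensors are illustrative only:

```python
import tensorflow as tf
import modeling  # the modeling.py modified in this diff

# Hypothetical call site: choose the compute dtype per model instance instead
# of relying on the edited constructor default.
bert_config = modeling.BertConfig(vocab_size=30522)      # other fields keep their defaults
input_ids = tf.placeholder(tf.int32, shape=[None, 128])
input_mask = tf.placeholder(tf.int32, shape=[None, 128])
segment_ids = tf.placeholder(tf.int32, shape=[None, 128])

model = modeling.BertModel(
    config=bert_config,
    is_training=True,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=segment_ids,
    use_one_hot_embeddings=False,
    compute_type=tf.bfloat16)
```
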
@@ -223,7 +223,7 @@ def __init__(self,
initializer_range=config.initializer_range,
do_return_all_layers=True)

self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.float32)
self.sequence_output = tf.cast(self.all_encoder_layers[-1], tf.bfloat16)#tf.float32)
# The "pooler" converts the encoded sequence tensor of shape
# [batch_size, seq_length, hidden_size] to a tensor of shape
# [batch_size, hidden_size]. This is necessary for segment-level
@@ -392,7 +392,7 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):

def create_initializer(initializer_range=0.02):
"""Creates a `truncated_normal_initializer` with the given range."""
return tf.truncated_normal_initializer(stddev=initializer_range)
return tf.truncated_normal_initializer(stddev=initializer_range, dtype=tf.bfloat16)


def embedding_lookup(input_ids,
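
The `dtype` added to `create_initializer` pairs with the explicit `dtype` the hunks below pass to `tf.get_variable`, so the embedding tables are created and initialized in bfloat16 rather than float32. A minimal sketch of that combined pattern, mirroring the calls in this diff, with illustrative values standing in for `vocab_size` and `hidden_size`:

```python
import tensorflow as tf

# Illustrative only: [vocab_size, hidden_size] would normally come from BertConfig.
embedding_table = tf.get_variable(
    name="word_embeddings",
    shape=[30522, 768],
    initializer=tf.truncated_normal_initializer(stddev=0.02, dtype=tf.bfloat16),
    dtype=tf.bfloat16)
```
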
@@ -427,11 +427,12 @@ def embedding_lookup(input_ids,
embedding_table = tf.get_variable(
name=word_embedding_name,
shape=[vocab_size, embedding_size],
initializer=create_initializer(initializer_range))
initializer=create_initializer(initializer_range),
dtype=tf.bfloat16)

flat_input_ids = tf.reshape(input_ids, [-1])
flat_input_ids = tf.reshape(input_ids, [-1])#, tf.uint8)
if use_one_hot_embeddings:
one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size,dtype=tf.bfloat16)
output = tf.matmul(one_hot_input_ids, embedding_table)
else:
output = tf.gather(embedding_table, flat_input_ids)
@@ -440,6 +441,7 @@

output = tf.reshape(output,
input_shape[0:-1] + [input_shape[-1] * embedding_size])
#output = tf.cast(output, tf.bfloat16)
return (output, embedding_table)


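Why the one-hot lookup in `embedding_lookup` above also needs an explicit `dtype`: `tf.one_hot` defaults to float32, and `tf.matmul` requires both operands to share a dtype, so the one-hot matrix has to match the bfloat16 embedding table. A self-contained sketch with stand-in values:

```python
import tensorflow as tf

ids = tf.constant([1, 3, 2])                               # stand-in token ids
table = tf.ones([10, 4], dtype=tf.bfloat16)                # stand-in embedding table
one_hot = tf.one_hot(ids, depth=10, dtype=tf.bfloat16)     # [3, 10], matches the table
embeddings = tf.matmul(one_hot, table)                     # [3, 4], bfloat16
```
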
@@ -489,32 +491,35 @@ def embedding_postprocessor(input_tensor,
width = input_shape[2]

output = input_tensor

if use_token_type:
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
token_type_table = tf.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
initializer=create_initializer(initializer_range))
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
initializer=create_initializer(initializer_range),
dtype=tf.bfloat16)
flat_token_type_ids = tf.reshape(token_type_ids, [-1])#, tf.uint8)
if use_one_hot_embeddings:
# This vocab will be small so we always do one-hot here, since it is
# always faster for a small vocabulary.
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size, dtype=tf.bfloat16)
token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
else:
token_type_embeddings = tf.gather(token_type_table, flat_token_type_ids)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
#token_type_embeddings = tf.cast(token_type_embeddings, tf.bfloat16)
output += token_type_embeddings

if use_position_embeddings:
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
initializer=create_initializer(initializer_range),
dtype=tf.bfloat16)
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
@@ -553,7 +558,7 @@ def create_attention_mask_from_input_mask(from_tensor, to_mask):
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
to_mask = tf.cast(to_mask, dtype=tf.float32)
to_mask = tf.cast(to_mask, dtype=tf.bfloat16)

from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
batch_size = from_shape[0]
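
The mask cast matters because the mask is later combined with the bfloat16 attention scores; in upstream BERT's `attention_layer` it becomes an additive penalty on the logits. A rough sketch of that downstream interaction, with stand-in shapes rather than code from this PR:

```python
import tensorflow as tf

attention_scores = tf.zeros([2, 4, 4], dtype=tf.bfloat16)  # stand-in attention logits
attention_mask = tf.ones([2, 4, 4], dtype=tf.bfloat16)     # mask after the cast above
adder = (1.0 - attention_mask) * -10000.0                  # large negative where masked out
attention_scores += adder                                  # dtypes agree, no implicit float32 mix
```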