This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5623deb

afrozenator authored and copybara-github committed
[Mesh-TF] Add is_training as an arg to mtf.dropout
PiperOrigin-RevId: 361088273
1 parent e19130b commit 5623deb

2 files changed: +26 -6 lines changed

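Both files below apply the same two-line pattern at every call site: read the run mode off hparams (defaulting to TRAIN when no mode is set), derive a boolean is_training, and pass it through to mtf.dropout and the mtf.layers attention and feed-forward helpers. A minimal consolidated sketch of the pattern, assuming the updated mtf.dropout(x, is_training, ...) signature this change targets and a TF1-style tf.estimator setup (hparams and x come from the surrounding model code):

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def layer_prepostprocess_dropout(x, hparams):
  """Residual dropout that is disabled outside of training (sketch of the new pattern)."""
  batch_dim = x.shape.dims[0]
  model_dim = x.shape.dims[-1]
  # Fall back to TRAIN when hparams carries no mode, as the diffs below do.
  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  return mtf.dropout(
      x, is_training,
      keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
      noise_shape=mtf.Shape([batch_dim, model_dim]))

The same boolean is also forwarded to mtf.layers.multihead_attention, mtf.layers.masked_local_attention_1d, and mtf.layers.dense_relu_dense so their internal dropout follows the mode as well.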

tensor2tensor/models/mtf_image_transformer.py

Lines changed: 13 additions & 1 deletion
@@ -243,8 +243,10 @@ def import_to_batch_by_length(x, name):
 def layer_prepostprocess_dropout(x, hparams):
   batch_dim = x.shape.dims[0]
   model_dim = x.shape.dims[-1]
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   return mtf.dropout(
-      x,
+      x, is_training,
       keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
       noise_shape=mtf.Shape([batch_dim, model_dim]))


@@ -259,6 +261,8 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
   x = mtf.reshape(
       x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
   # [ self attention - ffn - residual + dropout] x n
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   for layer in range(hparams.num_decoder_layers):
     layer_name = "decoder_layer_%d" % layer
     with tf.variable_scope(layer_name):
@@ -268,6 +272,7 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               memory_w_dim=blocks_w_dim,
               mask_right=True,
               name="self_att"), hparams)
@@ -276,6 +281,7 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
           mtf.layers.dense_relu_dense(
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
               feedforward_dim,
+              is_training,
               hparams.dropout,
               dropout_broadcast_dims=[length_dim]), hparams)


@@ -305,6 +311,8 @@ def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
       batch_dim, num_h_blocks_dim, num_w_blocks_dim,
       blocks_h_dim, blocks_w_dim, model_dim
   ]))
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   # Image Transformer Decoder
   # [ self attention - ffn - residual + dropout] x n
   for layer in range(hparams.num_decoder_layers):
@@ -316,6 +324,7 @@ def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               memory_h_dim=num_h_blocks_dim,
               memory_w_dim=num_w_blocks_dim,
               name="self_att"), hparams)
@@ -336,6 +345,8 @@ def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
   """Image Transformer decoder with local1D masked layers."""
   print(x)
   _, length_dim, model_dim = x.shape.dims
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   for layer in range(hparams.num_decoder_layers):
     layer_name = "decoder_layer_%d" % layer
     with tf.variable_scope(layer_name):
@@ -347,6 +358,7 @@ def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               window_size=hparams.block_length,
               length_per_split=length_per_split,
               name="self_att"), hparams)

tensor2tensor/models/mtf_transformer.py

Lines changed: 13 additions & 5 deletions
@@ -242,6 +242,8 @@ def _mtf_model_fn(self, features, mesh):
     hparams = self._hparams
     extra_losses = []
     targets = tf.to_int32(features["targets"])
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if len(targets.get_shape()) > 2:
       tf.logging.info("targets = %s" % targets)
       targets = tf.squeeze(targets, [2, 3])
@@ -289,7 +291,7 @@ def pad_to_max_length(x):

     def layer_prepostprocess_dropout(x):
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

     (inputs_embedding_var,
@@ -426,10 +428,11 @@ def _feedforward_layer(self, x, layer_type, losses=None):
       ValueError: if hparams make no sense
     """
     hparams = self._hparams
-
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if layer_type == "drd":
       return mtf.layers.dense_relu_dense(
-          x, self.feedforward_dim, dropout=hparams.relu_dropout,
+          x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
           dropout_broadcast_dims=[self.length_dim],
           master_dtype=self.master_dtype,
           slice_dtype=self.slice_dtype)
@@ -493,11 +496,13 @@ def _layer_stack(self,
     """
     hparams = self._hparams
     is_incremental = (step_num is not None)
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     def layer_prepostprocess_dropout(x):
       if is_incremental:
         return x
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
     num_layers = len(layers)
     num_layer_norms = num_layers + 1
@@ -540,6 +545,7 @@ def normalize(x):
              mtf.layers.multihead_attention(
                  normalize(x), None,
                  self_attention_mask, self.kv_dim, self.heads_dim,
+                 is_training,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
@@ -560,6 +566,7 @@ def normalize(x):
              mtf.layers.multihead_attention(
                  normalize(x), encoder_output,
                  encdec_attention_mask, self.kv_dim, self.heads_dim,
+                 is_training,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
@@ -582,7 +589,7 @@ def normalize(x):
          x += layer_prepostprocess_dropout(
              mtf.layers.masked_local_attention_1d(
                  normalize(x),
-                 self.kv_dim, self.heads_dim,
+                 self.kv_dim, self.heads_dim, is_training,
                  window_size=hparams.local_attention_window_size,
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
@@ -601,6 +608,7 @@ def normalize(x):
                  compression_factor=hparams.compression_factor,
                  kv_channels=self.kv_dim,
                  heads=self.heads_dim,
+                 is_training=is_training,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
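A practical consequence of threading is_training through: outside TRAIN mode the flag is False, so the updated mtf.dropout is expected to leave activations untouched no matter what keep_prob is configured. A hypothetical standalone sketch, assuming the mesh_tensorflow graph-construction API (mtf.Graph, mtf.Mesh, mtf.Dimension, mtf.import_tf_tensor) and the pass-through behaviour described above; mesh and dimension names here are purely illustrative:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "sketch_mesh")  # illustrative mesh name
batch_dim = mtf.Dimension("batch", 4)
model_dim = mtf.Dimension("d_model", 8)
x = mtf.import_tf_tensor(
    mesh, tf.ones([4, 8]), shape=mtf.Shape([batch_dim, model_dim]))

mode = tf.estimator.ModeKeys.EVAL
is_training = mode == tf.estimator.ModeKeys.TRAIN  # False for EVAL/PREDICT
# With is_training=False the dropout op should act as a pass-through even
# though keep_prob < 1.0; during TRAIN it drops units as before (assumed
# behaviour of the new signature, not shown by this diff).
y = mtf.dropout(x, is_training, keep_prob=0.9,
                noise_shape=mtf.Shape([batch_dim, model_dim]))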
