@@ -242,6 +242,8 @@ def _mtf_model_fn(self, features, mesh):
     hparams = self._hparams
     extra_losses = []
     targets = tf.to_int32(features["targets"])
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if len(targets.get_shape()) > 2:
       tf.logging.info("targets = %s" % targets)
       targets = tf.squeeze(targets, [2, 3])
@@ -289,7 +291,7 @@ def pad_to_max_length(x):
 
     def layer_prepostprocess_dropout(x):
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
 
     (inputs_embedding_var,
@@ -426,10 +428,11 @@ def _feedforward_layer(self, x, layer_type, losses=None):
       ValueError: if hparams make no sense
     """
     hparams = self._hparams
-
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if layer_type == "drd":
       return mtf.layers.dense_relu_dense(
-          x, self.feedforward_dim, dropout=hparams.relu_dropout,
+          x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
           dropout_broadcast_dims=[self.length_dim],
           master_dtype=self.master_dtype,
           slice_dtype=self.slice_dtype)
@@ -493,11 +496,13 @@ def _layer_stack(self,
     """
     hparams = self._hparams
     is_incremental = (step_num is not None)
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     def layer_prepostprocess_dropout(x):
       if is_incremental:
         return x
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
     num_layers = len(layers)
     num_layer_norms = num_layers + 1
@@ -540,6 +545,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), None,
                 self_attention_mask, self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -560,6 +566,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), encoder_output,
                 encdec_attention_mask, self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -582,7 +589,7 @@ def normalize(x):
         x += layer_prepostprocess_dropout(
             mtf.layers.masked_local_attention_1d(
                 normalize(x),
-                self.kv_dim, self.heads_dim,
+                self.kv_dim, self.heads_dim, is_training,
                 window_size=hparams.local_attention_window_size,
                 master_dtype=self.master_dtype,
                 slice_dtype=self.slice_dtype,
@@ -601,6 +608,7 @@ def normalize(x):
                 compression_factor=hparams.compression_factor,
                 kv_channels=self.kv_dim,
                 heads=self.heads_dim,
+                is_training=is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
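
Every hunk above applies the same pattern: derive an is_training flag from hparams.mode (defaulting to TRAIN when the attribute is absent) and thread it into the dropout and attention calls so they can act as no-ops outside of training. Below is a minimal, self-contained sketch of that gating, using plain TF1-style ops rather than the mesh-tensorflow calls from the diff; the helper name, the stand-in hparams object, and the keep_prob value are illustrative assumptions, not taken from this change.

import tensorflow as tf

def dropout_if_training(x, is_training, keep_prob):
  # Only apply dropout when the estimator mode is TRAIN; otherwise pass the
  # activations through unchanged so EVAL / PREDICT stay deterministic.
  if not is_training or keep_prob >= 1.0:
    return x
  return tf.nn.dropout(x, keep_prob=keep_prob)  # TF1-style keep_prob argument

# Same derivation as in the diff, with TRAIN as the default when hparams has
# no "mode" attribute. _FakeHParams is a hypothetical stand-in for hparams.
class _FakeHParams(object):
  layer_prepostprocess_dropout = 0.1

hparams = _FakeHParams()
mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
is_training = mode == tf.estimator.ModeKeys.TRAIN

x = tf.ones([2, 4])
y = dropout_if_training(
    x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout)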