Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e19130b

Browse files
T2T Teamcopybara-github
authored andcommitted
Don't overspecify WeightNorm input_spec, match input_spec of wrapped
PiperOrigin-RevId: 359573317
1 parent 9902e88 commit e19130b

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4043,9 +4043,6 @@ def _data_dep_init(self, inputs):
40434043

40444044
def build(self, input_shape=None):
40454045
"""Build `Layer`."""
4046-
input_shape = tf.TensorShape(input_shape).as_list()
4047-
self.input_spec = layers().InputSpec(shape=input_shape)
4048-
40494046
if not self.layer.built:
40504047
self.layer.build(input_shape)
40514048
self.layer.built = False
@@ -4072,6 +4069,7 @@ def build(self, input_shape=None):
40724069
self._compute_weights()
40734070

40744071
self.layer.built = True
4072+
self.input_spec = self.layer.input_spec
40754073

40764074
super(WeightNorm, self).build()
40774075
self.built = True

tensor2tensor/layers/common_layers_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,5 +965,20 @@ def fn_recompute(x):
965965
self.assertAllClose(g1, g2)
966966

967967

968+
class WeightNormTest(tf.test.TestCase):
969+
970+
def testInputSpec(self):
971+
"""Test that WeighNorm does not overspecify the input_spec."""
972+
conv = common_layers.WeightNorm(
973+
tf.keras.layers.Conv1D(filters=8, kernel_size=3))
974+
# Call with one batch size:
975+
conv(tf.zeros([1, 16, 2]))
976+
# Should allow call with another batch size.
977+
conv(tf.zeros([2, 16, 2]))
978+
# Input spec does detect incorrect input feature dim.
979+
with self.assertRaises(ValueError):
980+
conv(tf.zeros([2, 16, 3]))
981+
982+
968983
if __name__ == "__main__":
969984
tf.test.main()

0 commit comments

Comments
 (0)