Skip to content

Commit 2500ff3

Browse files
authored
Update train.py
1 parent bb33737 commit 2500ff3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

references/classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def main(args):
233233
if args.bias_weight_decay is not None:
234234
custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
235235
if args.transformer_embedding_decay is not None:
236-
for key in ["class_token", "position_embedding", "relative_position_bias"]:
236+
for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
237237
custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
238238
parameters = utils.set_weight_decay(
239239
model,

0 commit comments

Comments
 (0)