@@ -51,77 +51,14 @@ def batch_norm(
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
-    if all(
+    if not all(
         [
             isinstance(weight, torch.Tensor),
             isinstance(bias, torch.Tensor),
             isinstance(running_mean, torch.Tensor),
             isinstance(running_var, torch.Tensor),
         ]
     ):
-        if weight is None:
-            weight = 1.0
-
-        if bias is None:
-            bias = 0.0
-
-        if running_mean is None:
-            running_mean = 0.0
-
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale = weight / torch.sqrt(running_var + eps)
-        adjusted_bias = bias - running_mean * adjusted_scale
-        power = torch.ones_like(adjusted_scale)
-        adjusted_scale = to_trt_weights(
-            ctx,
-            adjusted_scale,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SCALE",
-            target=target,
-            source_ir=source_ir,
-        )
-        adjusted_bias = to_trt_weights(
-            ctx,
-            adjusted_bias,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="SHIFT",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        power = to_trt_weights(
-            ctx,
-            power,
-            name,
-            layer_type_name="SCALE",
-            weight_type_name="POWER",
-            target=target,
-            source_ir=source_ir,
-        )
-
-        output_shape = input.shape
-        if len(input.shape) < 4:
-
-            new_shape = (
-                (input.shape[0], input.shape[1], 1, 1)
-                if len(input.shape) == 2
-                else (input.shape[0], input.shape[1], input.shape[2], 1)
-            )
-            input = impl.shuffle.reshape(
-                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-            )
-
-        layer = ctx.net.add_scale_nd(
-            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
-        )
-        set_layer_name(layer, target, name, source_ir)
-        output = layer.get_output(0)
-
-    else:
-
         # We name the weight here according to the state_dict name
         weight = (
             get_trt_tensor(ctx, 1.0, f"{name}_weight")
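As an aside on the fusion mentioned in the comments at the top of this hunk: once batch norm is reduced to a constant per-channel affine `y = x * scale + shift`, it can be absorbed into a preceding convolution by rescaling the conv weights and bias per output channel. A minimal standalone sketch of that algebra (plain PyTorch with made-up shapes, not part of the converter):

```python
import torch
import torch.nn.functional as F

# Folding a per-channel affine (scale, shift) into the preceding conv:
# conv(x, W, b) * scale + shift == conv(x, scale * W, scale * b + shift)
C_in, C_out = 3, 8
x = torch.randn(1, C_in, 16, 16)
W = torch.randn(C_out, C_in, 3, 3)
b = torch.randn(C_out)
scale = torch.randn(C_out)
shift = torch.randn(C_out)

y_ref = F.conv2d(x, W, b, padding=1) * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
y_folded = F.conv2d(x, W * scale.view(-1, 1, 1, 1), b * scale + shift, padding=1)
assert torch.allclose(y_ref, y_folded, atol=1e-4)
```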
@@ -206,6 +143,70 @@ def batch_norm(
             bias_adjusted_reshape,
         )
 
+    else:
+        if weight is None:
+            weight = 1.0
+
+        if bias is None:
+            bias = 0.0
+
+        if running_mean is None:
+            running_mean = 0.0
+
+        if running_var is None:
+            running_var = 1.0
+        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+            weight, bias, running_mean, running_var, eps
+        )
+        power = torch.ones_like(adjusted_scale)
+
+        adjusted_scale = to_trt_weights(
+            ctx,
+            adjusted_scale,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SCALE",
+            target=target,
+            source_ir=source_ir,
+        )
+        adjusted_bias = to_trt_weights(
+            ctx,
+            adjusted_bias,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="SHIFT",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        power = to_trt_weights(
+            ctx,
+            power,
+            name,
+            layer_type_name="SCALE",
+            weight_type_name="POWER",
+            target=target,
+            source_ir=source_ir,
+        )
+
+        output_shape = input.shape
+        if len(input.shape) < 4:
+
+            new_shape = (
+                (input.shape[0], input.shape[1], 1, 1)
+                if len(input.shape) == 2
+                else (input.shape[0], input.shape[1], input.shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
+
+        layer = ctx.net.add_scale_nd(
+            input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
+        )
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)
+
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
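The rank-promotion step in this branch handles BatchNorm1d: rank-2 and rank-3 inputs are padded with trailing unit dimensions to 4-D before `add_scale_nd` is applied, and the output is reshaped back to its original shape afterwards. A small sketch of just that shape logic, using a hypothetical helper name not present in the converter:

```python
from typing import Tuple

def promote_to_4d(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Illustrative helper (not part of the converter): pad BatchNorm1d
    input shapes with trailing 1s so a channel-wise scale runs on a 4-D tensor."""
    if len(shape) == 2:                      # (N, C)
        return (shape[0], shape[1], 1, 1)
    if len(shape) == 3:                      # (N, C, L)
        return (shape[0], shape[1], shape[2], 1)
    return shape                             # already rank >= 4

assert promote_to_4d((8, 16)) == (8, 16, 1, 1)
assert promote_to_4d((8, 16, 32)) == (8, 16, 32, 1)
assert promote_to_4d((8, 16, 32, 32)) == (8, 16, 32, 32)
```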
@@ -224,6 +225,18 @@ def batch_norm(
     return output
 
 
+def batch_norm_constant_folding(
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    running_mean: torch.Tensor,
+    running_var: torch.Tensor,
+    eps: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    adjusted_scale = weight / torch.sqrt(running_var + eps)
+    adjusted_bias = bias - running_mean * adjusted_scale
+    return adjusted_scale, adjusted_bias
+
+
 def native_layer_norm(
     ctx: ConversionContext,
     target: Target,
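The new `batch_norm_constant_folding` helper is the standard inference-time rewrite of batch norm as a per-channel affine: y = (x - mean) / sqrt(var + eps) * weight + bias = x * scale + shift, with scale = weight / sqrt(var + eps) and shift = bias - mean * scale. A quick standalone check of that identity against `torch.nn.functional.batch_norm` (plain PyTorch, arbitrary shapes, independent of the TensorRT converter):

```python
import torch
import torch.nn.functional as F

def batch_norm_constant_folding(weight, bias, running_mean, running_var, eps):
    # Same algebra as the helper added in this commit.
    adjusted_scale = weight / torch.sqrt(running_var + eps)
    adjusted_bias = bias - running_mean * adjusted_scale
    return adjusted_scale, adjusted_bias

C, eps = 8, 1e-5
x = torch.randn(2, C, 16, 16)
weight, bias = torch.randn(C), torch.randn(C)
running_mean, running_var = torch.randn(C), torch.rand(C) + 0.5

scale, shift = batch_norm_constant_folding(weight, bias, running_mean, running_var, eps)
folded = x * scale.view(1, C, 1, 1) + shift.view(1, C, 1, 1)
reference = F.batch_norm(x, running_mean, running_var, weight, bias, training=False, eps=eps)
assert torch.allclose(folded, reference, atol=1e-4)
```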
@@ -303,7 +316,7 @@ def native_group_norm(
         ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )
 
-    axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])
+    axes = get_axes_for_reduce_op(list(range(1 if group == 1 else 2, rank)))
 
     # INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
     # hence causing diverse result. Let TensorRT does no-op for scaling here, and do scaling ourselves later