Skip to content

Commit fe73508

Browse files
committed
Address the comments
1 parent 7f05316 commit fe73508

File tree

2 files changed

+85
-67
lines changed

2 files changed

+85
-67
lines changed

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,18 @@ def construct_refit_mapping_from_weight_name_map(
9191
) -> dict[Any, Any]:
9292
engine_weight_map = {}
9393
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
94+
# Add more constant folding converters here
9495
if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
9596
# Batch Norm Layer
9697
params = {}
9798
for w in sd_weight_name:
9899
params[w.split(".")[-1]] = state_dict[w].cuda()
99-
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
100-
shift = params["bias"] - params["running_mean"] * scale
100+
# Batch norm constant folding
101+
from torch_tensorrt.dynamo.conversion.impl.normalization.ops import (
102+
batch_norm_constant_folding,
103+
)
104+
105+
scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
101106
# Set scale to scale or shift to shift
102107
engine_weight_map[engine_weight_name] = eval(
103108
engine_weight_name.split(" ")[-1].lower()

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 78 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -51,77 +51,14 @@ def batch_norm(
5151
# We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
5252
# Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
5353
# In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
54-
if all(
54+
if not all(
5555
[
5656
isinstance(weight, torch.Tensor),
5757
isinstance(bias, torch.Tensor),
5858
isinstance(running_mean, torch.Tensor),
5959
isinstance(running_var, torch.Tensor),
6060
]
6161
):
62-
if weight is None:
63-
weight = 1.0
64-
65-
if bias is None:
66-
bias = 0.0
67-
68-
if running_mean is None:
69-
running_mean = 0.0
70-
71-
if running_var is None:
72-
running_var = 1.0
73-
adjusted_scale = weight / torch.sqrt(running_var + eps)
74-
adjusted_bias = bias - running_mean * adjusted_scale
75-
power = torch.ones_like(adjusted_scale)
76-
adjusted_scale = to_trt_weights(
77-
ctx,
78-
adjusted_scale,
79-
name,
80-
layer_type_name="SCALE",
81-
weight_type_name="SCALE",
82-
target=target,
83-
source_ir=source_ir,
84-
)
85-
adjusted_bias = to_trt_weights(
86-
ctx,
87-
adjusted_bias,
88-
name,
89-
layer_type_name="SCALE",
90-
weight_type_name="SHIFT",
91-
target=target,
92-
source_ir=source_ir,
93-
)
94-
95-
power = to_trt_weights(
96-
ctx,
97-
power,
98-
name,
99-
layer_type_name="SCALE",
100-
weight_type_name="POWER",
101-
target=target,
102-
source_ir=source_ir,
103-
)
104-
105-
output_shape = input.shape
106-
if len(input.shape) < 4:
107-
108-
new_shape = (
109-
(input.shape[0], input.shape[1], 1, 1)
110-
if len(input.shape) == 2
111-
else (input.shape[0], input.shape[1], input.shape[2], 1)
112-
)
113-
input = impl.shuffle.reshape(
114-
ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
115-
)
116-
117-
layer = ctx.net.add_scale_nd(
118-
input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
119-
)
120-
set_layer_name(layer, target, name, source_ir)
121-
output = layer.get_output(0)
122-
123-
else:
124-
12562
# We name the weight here according to the state_dict name
12663
weight = (
12764
get_trt_tensor(ctx, 1.0, f"{name}_weight")
@@ -206,6 +143,70 @@ def batch_norm(
206143
bias_adjusted_reshape,
207144
)
208145

146+
else:
147+
if weight is None:
148+
weight = 1.0
149+
150+
if bias is None:
151+
bias = 0.0
152+
153+
if running_mean is None:
154+
running_mean = 0.0
155+
156+
if running_var is None:
157+
running_var = 1.0
158+
adjusted_scale, adjusted_bias = batch_norm_constant_folding(
159+
weight, bias, running_mean, running_var, eps
160+
)
161+
power = torch.ones_like(adjusted_scale)
162+
163+
adjusted_scale = to_trt_weights(
164+
ctx,
165+
adjusted_scale,
166+
name,
167+
layer_type_name="SCALE",
168+
weight_type_name="SCALE",
169+
target=target,
170+
source_ir=source_ir,
171+
)
172+
adjusted_bias = to_trt_weights(
173+
ctx,
174+
adjusted_bias,
175+
name,
176+
layer_type_name="SCALE",
177+
weight_type_name="SHIFT",
178+
target=target,
179+
source_ir=source_ir,
180+
)
181+
182+
power = to_trt_weights(
183+
ctx,
184+
power,
185+
name,
186+
layer_type_name="SCALE",
187+
weight_type_name="POWER",
188+
target=target,
189+
source_ir=source_ir,
190+
)
191+
192+
output_shape = input.shape
193+
if len(input.shape) < 4:
194+
195+
new_shape = (
196+
(input.shape[0], input.shape[1], 1, 1)
197+
if len(input.shape) == 2
198+
else (input.shape[0], input.shape[1], input.shape[2], 1)
199+
)
200+
input = impl.shuffle.reshape(
201+
ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
202+
)
203+
204+
layer = ctx.net.add_scale_nd(
205+
input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
206+
)
207+
set_layer_name(layer, target, name, source_ir)
208+
output = layer.get_output(0)
209+
209210
# For BatchNorm1d, reshape output back to original shape if necessary
210211
if len(output_shape) < 4:
211212
output = impl.shuffle.reshape(
@@ -224,6 +225,18 @@ def batch_norm(
224225
return output
225226

226227

228+
def batch_norm_constant_folding(
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fold batch-norm parameters into a single scale and shift.

    Inference-mode batch norm computes
    ``(x - running_mean) / sqrt(running_var + eps) * weight + bias``,
    which is algebraically ``x * scale + shift`` with::

        scale = weight / sqrt(running_var + eps)
        shift = bias - running_mean * scale

    Args:
        weight: Batch-norm affine weight (gamma).
        bias: Batch-norm affine bias (beta).
        running_mean: Running mean statistic.
        running_var: Running variance statistic.
        eps: Value added to the variance before the square root.

    Returns:
        A ``(scale, shift)`` tuple of tensors.

    Note:
        Plain Python numbers are also accepted for any of the four
        parameters (for example the ``1.0`` / ``0.0`` placeholders a caller
        may substitute for missing parameters); the results are still
        tensors.
    """
    # torch.sqrt only accepts tensors. as_tensor hands back the same tensor
    # object when given one, so tensor inputs behave exactly as before while
    # scalar inputs no longer raise a TypeError.
    std = torch.sqrt(torch.as_tensor(running_var) + eps)
    adjusted_scale = weight / std
    adjusted_bias = bias - running_mean * adjusted_scale
    return adjusted_scale, adjusted_bias
238+
239+
227240
def native_layer_norm(
228241
ctx: ConversionContext,
229242
target: Target,
@@ -303,7 +316,7 @@ def native_group_norm(
303316
ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
304317
)
305318

306-
axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])
319+
axes = get_axes_for_reduce_op(list(range(1 if group == 1 else 2, rank)))
307320

308321
# INormalizationLayer scales the normalized output per-group, but PyTorch scales the normalized output per-channel,
309322
# hence causing divergent results. Let TensorRT do a no-op for scaling here, and do the scaling ourselves later

0 commit comments

Comments
 (0)