Commit 50734d3

feat: support linear (fully connected layer) dynamo converter
1 parent 91fcea4 commit 50734d3

File tree

3 files changed: +105 -0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 19 additions & 0 deletions
@@ -420,3 +420,22 @@ def aten_ops_clone(
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.linear)
def aten_ops_linear(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.linear.linear(
        network,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        weight=args[1],
        bias=args[2],
    )
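
For context (not part of the commit): a minimal sketch of how the new converter gets exercised end to end, assuming the standard torch_tensorrt.compile entry point with the dynamo IR. The exact "ir" string and compile options are assumptions and vary between Torch-TensorRT releases.

import torch
import torch_tensorrt

# Tiny module whose lowered aten graph contains torch.ops.aten.linear,
# which the converter registered above now handles.
class TinyMLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

model = TinyMLP().eval().cuda()
x = torch.randn(8, 64, device="cuda")

# Compile through the dynamo frontend (the ir value here is an assumption).
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
print(trt_model(x).shape)  # torch.Size([8, 32])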

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
    condition,
    elementwise,
    embedding,
    linear,
    matmul,
    normalization,
    permutation,
py/torch_tensorrt/dynamo/conversion/impl/linear.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
from typing import Optional, Union

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import (
    SourceIR,
    get_trt_tensor,
    set_layer_name,
    to_numpy,
)
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor


def linear(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    weight: Union[TRTTensor, torch.Tensor],
    bias: Optional[Union[TRTTensor, torch.Tensor]],
) -> TRTTensor:
    """
    TensorRT's fully connected layer implicitly flattens the last three
    dimensions of its input at the start and implicitly reshapes the result
    to (K, 1, 1) at the end.

    e.g. if the input is (N, C, H, W), it is first flattened to (N, C*H*W).
    After the fully connected operation it becomes (N, K). Before being
    returned, it is reshaped to (N, K, 1, 1), which is the final output.

    TODO: We can optimize this to get rid of the unnecessary transformations.
    """

    if not isinstance(input, trt.tensorrt.ITensor):
        raise RuntimeError(
            f"Linear received input {input} that is not part of the TensorRT region!"
        )

    # Reshape the input to (*, X, 1, 1)
    pre_shuffle_layer = network.add_shuffle(input)
    pre_shuffle_layer.reshape_dims = tuple(input.shape) + (1, 1)
    set_layer_name(pre_shuffle_layer, target, f"{name}_pre_shuffle", source_ir)

    # Process the bias term
    if isinstance(bias, torch.Tensor):
        # Transform the bias constant into a NumPy array
        bias = to_numpy(bias)

    elif isinstance(bias, TRTTensor):
        bias = get_trt_tensor(network, bias, f"{name}_bias")

    elif bias is not None:
        raise RuntimeError(
            f"Linear layer {name} has bias of type {type(bias)}; expected a Torch Tensor or a TRT Tensor"
        )

    # Process the weight term
    if network.has_explicit_precision or isinstance(weight, TRTTensor):
        weight = get_trt_tensor(network, weight, f"{name}_weight")

    elif isinstance(weight, torch.Tensor):
        # Transform the weight constant into a NumPy array
        weight = to_numpy(weight)

    else:
        raise RuntimeError(
            f"Linear layer {name} has weight of type {type(weight)}; expected Optional[Tensor]"
        )

    # Add the fully connected layer
    fully_connected_layer = network.add_fully_connected(
        input=pre_shuffle_layer.get_output(0),
        num_outputs=weight.shape[0],
        kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
        bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
    )
    set_layer_name(fully_connected_layer, target, f"{name}_linear", source_ir)

    # Reshape the output from (*, K, 1, 1) to (*, K)
    post_shuffle_layer = network.add_shuffle(fully_connected_layer.get_output(0))
    post_shuffle_layer.reshape_dims = tuple(input.shape[:-1]) + (weight.shape[0],)
    set_layer_name(post_shuffle_layer, target, f"{name}_post_shuffle", source_ir)

    return post_shuffle_layer.get_output(0)
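
As a shape sanity check (not part of the commit), the shuffle -> fully connected -> shuffle sequence above can be mirrored in plain PyTorch. The helper below is hypothetical and only illustrates why the pre/post reshapes recover the usual (*, in_features) -> (*, out_features) linear semantics.

import torch
import torch.nn.functional as F

def linear_via_fc_reference(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    K = weight.shape[0]
    # Pre-shuffle: (*, X) -> (*, X, 1, 1)
    pre = x.reshape(*x.shape, 1, 1)
    # Fully connected: flatten the last three dims (X, 1, 1) -> X,
    # apply x @ W^T + b, and emit (*, K, 1, 1), matching TensorRT's implicit behavior
    flat = pre.reshape(*x.shape[:-1], -1)
    fc = (flat @ weight.t() + bias).reshape(*x.shape[:-1], K, 1, 1)
    # Post-shuffle: (*, K, 1, 1) -> (*, K)
    return fc.reshape(*x.shape[:-1], K)

x = torch.randn(8, 64)
w = torch.randn(32, 64)
b = torch.randn(32)
assert torch.allclose(linear_via_fc_reference(x, w, b), F.linear(x, w, b), atol=1e-5)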
