# Register on the concrete overload (torch.ops.aten.linear.default), not the
# OpOverloadPacket: dynamo graphs carry resolved overloads, so matching on the
# packet would miss the node.
@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
def aten_ops_linear(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert an ``aten.linear`` (fully connected layer) node to TensorRT.

    Args:
        network: TensorRT network currently being built.
        target: FX node target being converted.
        args: Node positional args: ``(input, weight[, bias])``.
        kwargs: Node keyword args (unused by this converter).
        name: Unique name prefix for the TRT layers created.

    Returns:
        The TRT tensor produced by the linear implementation.
    """
    return impl.linear.linear(
        network,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        weight=args[1],
        # bias is optional in aten.linear(input, weight, bias=None)
        bias=args_bounds_check(args, 2, None),
    )
def linear(
    network: TRTNetwork,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    weight: Union[TRTTensor, torch.Tensor, np.ndarray],
    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
) -> TRTTensor:
    """Build a TensorRT subgraph computing ``input @ weight.T + bias``.

    Mirrors ``torch.nn.functional.linear``: the weight is stored as
    ``(out_features, in_features)``, so the matmul transposes it.

    Args:
        network: TensorRT network to append layers to.
        target: FX target (or name) of the op being converted.
        source_ir: IR the op originated from (e.g. ``SourceIR.ATEN``).
        name: Unique prefix for the TRT layers created here.
        input: Activation tensor already present in the TRT graph.
        weight: Weight matrix; host tensors are frozen into TRT constants.
        bias: Optional bias; host tensors are frozen into TRT constants.

    Returns:
        The TRT tensor holding the linear layer's output.

    Raises:
        RuntimeError: If ``weight`` or ``bias`` has an unsupported type.
    """
    # Freeze the weight into the network as a constant when given as a
    # host-side tensor/array; TRTTensor inputs are used as-is.
    if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
        raise RuntimeError(
            f"Linear layer {name} has weight of type {type(weight)}; "
            f"expected TRTTensor, torch.Tensor, or np.ndarray."
        )
    elif isinstance(weight, (torch.Tensor, np.ndarray)):
        weight = get_trt_tensor(network, weight, f"{name}_weight")

    # Same treatment for the optional bias term.
    if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
        raise RuntimeError(
            f"Linear layer {name} has bias of type {type(bias)}; "
            f"expected TRTTensor, torch.Tensor, or np.ndarray."
        )
    elif isinstance(bias, (torch.Tensor, np.ndarray)):
        bias = get_trt_tensor(network, bias, f"{name}_bias")

    # input @ weight^T via IMatrixMultiplyLayer; the TRANSPOSE on the second
    # operand accounts for the (out_features, in_features) weight layout.
    out = impl.matmul.matrix_multiply(
        network,
        target,
        source_ir,
        name,
        input,
        weight,
        input_matrix_op=trt.MatrixOperation.NONE,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )

    # Broadcast-add the bias, if any.
    if bias is not None:
        out = impl.elementwise.add(network, target, source_ir, name, out, bias)

    return out
a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1291,7 +1291,7 @@ def aten_ops_convolution( ) -@dynamo_tensorrt_converter(torch.ops.aten.linear) +@dynamo_tensorrt_converter(torch.ops.aten.linear.default) def aten_ops_linear( network: TRTNetwork, target: Target,