From 399f929cacc418acbbbf4ef67ed3170734c67f4e Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 20 Jul 2023 21:26:27 -0700
Subject: [PATCH 1/5] feat: Add preliminary support for freezing tensors in
 Dynamo

fix: Refactor tensor freezing in Dynamo

Key op fixes for failing tests
---
 py/torch_tensorrt/dynamo/backend/backends.py  | 37 +++++++++++--------
 .../dynamo/conversion/_TRTInterpreter.py      | 31 +++++++++++++++-
 .../conversion/impl/normalization/ops.py      | 13 +++++--
 py/torch_tensorrt/dynamo/lowering/__init__.py |  1 +
 .../partitioning/_global_partitioner.py       |  5 ++-
 .../fx/converters/acc_ops_converters.py       | 29 ++++++++-------
 .../fx/converters/converter_utils.py          |  8 ++--
 .../fx/converters/impl/convolution.py         | 20 ++++------
 8 files changed, 92 insertions(+), 52 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 2ba9f4d754..cf453dedf4 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
 import logging
-from functools import partial
+import unittest
 from typing import Any, Callable, Sequence
 
 import torch
 import torch._dynamo as td
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -33,8 +34,7 @@ def torch_tensorrt_backend(
 
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
 
-    compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
-    return compiled_mod
+    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
 
 
 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
@@ -43,21 +43,26 @@ def aot_torch_tensorrt_aten_backend(
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
 
-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
     # Perform Pre-AOT Lowering for Module-Level Replacement
     gm = pre_aot_substitutions(gm)
 
-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )
+    fake_mode = detect_fake_mode(sample_inputs)
+
+    # Place backend tracing within FakeTensor context allowing nonfake Tensors
+    with unittest.mock.patch.object(
+        fake_mode, "allow_non_fake_inputs", True
+    ), fake_mode:
+        # Invoke AOTAutograd to translate operators to aten
+        graph_module = aot_export_joint_simple(
+            gm,
+            sample_inputs,
+            trace_joint=False,
+            decompositions=get_decompositions(
+                settings.enable_experimental_decompositions
+            ),
+        )
+
+        return _pretraced_backend(graph_module, sample_inputs, settings)
 
 
 def _pretraced_backend(
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 29485a919b..35b092e263 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set
 
-import numpy
+import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
@@ -11,6 +11,7 @@
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
 from torch_tensorrt.fx.observer import Observer
@@ -169,7 +170,7 @@ def run(
 
         cache = None
         if timing_cache:
-            cache_file = numpy.array(timing_cache)
+            cache_file = np.array(timing_cache)
             cache = builder_config.create_timing_cache(cache_file.tobytes())
         else:
             cache = builder_config.create_timing_cache(b"")
@@ -323,6 +324,21 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)
 
+    def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
+        with _disable_current_modes():
+            from torch_tensorrt.fx.converters import to_numpy
+
+            frozen_attr = self.fetch_attr(target)
+
+            if isinstance(frozen_attr, torch.nn.Parameter):
+                constant_tensor = frozen_attr.data
+            else:
+                constant_tensor = frozen_attr
+
+            network_constant = to_numpy(constant_tensor)
+
+        return network_constant
+
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
         converter = CONVERTERS.get(self._cur_node)
@@ -344,6 +360,17 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         else:
             outputs = (args[0],)
 
+        for output_idx in range(len(outputs)):
+            from torch_tensorrt.fx.converters import get_trt_tensor
+
+            output = outputs[output_idx]
+
+            if not isinstance(output, trt.tensorrt.ITensor):
+                new_output = get_trt_tensor(self.network, output, target)
+                outputs = (
+                    outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
+                )
+
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 2ab74ef86b..7822b515f8 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -2,6 +2,7 @@
 from typing import Any, List, Optional, Sequence, Union, cast
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -19,8 +20,6 @@
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims
 
-import tensorrt as trt
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -101,9 +100,15 @@ def layer_norm(
             "of the TensorRT region!"
         )
 
-    gamma = weight.detach().cpu().float().numpy()
+    gamma = (
+        weight.detach().cpu().float().numpy()
+        if isinstance(weight, torch.Tensor)
+        else weight
+    )
     gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = bias.detach().cpu().float().numpy()
+    beta = (
+        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
+    )
     beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
     eps_field = trt.PluginField(
         "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 6eda61a6fd..4c515586b2 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,4 +1,5 @@
 from ._decompositions import get_decompositions  # noqa: F401
+from ._freeze_aot_graph import *  # noqa: F401
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index bdb15b3394..6520ff3723 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -153,7 +153,10 @@ def is_node_supported(
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
-        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+        if (
+            node.target in CONVERTERS.keys()
+            or (node.op == "get_attr" and "constant" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 if node_name not in self.supported_operators:
diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 1765077930..c1deb21303 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -3,30 +3,27 @@
 import math
 import operator
 import warnings
-from typing import cast, Dict, Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
-
-from ..converter_registry import tensorrt_converter
-
-from ..tracer.acc_tracer import acc_ops
-from ..types import *  # noqa: F403
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
-
-from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
-
-from .converter_utils import *  # noqa: F403
+from torch_tensorrt.fx.converters.impl import activation, convolution
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     trt_transposed_linear,
     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
+
+from ..converter_registry import tensorrt_converter
+from ..tracer.acc_tracer import acc_ops
+from ..types import *  # noqa: F403
+from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter
+from .converter_utils import *  # noqa: F403
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -2714,8 +2711,14 @@ def acc_ops_linear(
             "dim for linear and it can't be the last dim."
         )
 
-    if isinstance(kwargs["weight"], torch.Tensor):
-        weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight")
+    if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)):
+        weight = get_trt_tensor(
+            network,
+            kwargs["weight"].t()
+            if isinstance(kwargs["weight"], torch.Tensor)
+            else kwargs["weight"].T,
+            f"{name}_weight",
+        )
         if target not in (acc_ops.linear, torch.ops.aten.linear):
             weight_op = trt.MatrixOperation.TRANSPOSE
         else:
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index 49bf401f58..17b3f6785f 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -1,8 +1,8 @@
 import operator
 import warnings
+from enum import Enum, auto
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
-from enum import Enum, auto
 import numpy as np
 
 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -20,7 +20,7 @@
     TRTPluginFieldCollection,
     TRTTensor,
 )
-from ..utils import unified_dtype_converter, Frameworks
+from ..utils import Frameworks, unified_dtype_converter
 
 
 class SourceIR(Enum):
@@ -271,7 +271,7 @@ def create_constant(
     """
     constant = network.add_constant(
         (1,) if isinstance(value, (int, float)) else value.shape,
-        to_numpy(value, dtype),
+        to_numpy(value, dtype).copy(),
    )
     constant.name = name
     return constant.get_output(0)
@@ -311,7 +311,7 @@ def get_trt_tensor(
     elif isinstance(input_val, np.ndarray) and (
         input_val.dtype == np.bool_ or input_val.dtype == np.int64
     ):
-        input_val = input_val.to(np.int32)
+        input_val = input_val.astype(np.int32)
 
     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
         return create_constant(network, input_val, name, dtype)
diff --git a/py/torch_tensorrt/fx/converters/impl/convolution.py b/py/torch_tensorrt/fx/converters/impl/convolution.py
index 84071ed2d4..946eb74485 100644
--- a/py/torch_tensorrt/fx/converters/impl/convolution.py
+++ b/py/torch_tensorrt/fx/converters/impl/convolution.py
@@ -1,27 +1,23 @@
-import numpy as np
 from typing import Any, Optional, Sequence, Union
 
+import numpy as np
+
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
-
+from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.converters.converter_utils import (
     SourceIR,
     extend_attr_to_tuple,
     get_dyn_range,
+    get_trt_tensor,
+    has_dynamic_shape,
     mark_as_int8_layer,
     set_layer_name,
-    has_dynamic_shape,
     to_numpy,
-    get_trt_tensor,
-)
-from torch_tensorrt.fx.converters import acc_ops_converters
-
-from torch_tensorrt.fx.types import (
-    TRTNetwork,
-    TRTTensor,
 )
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 
 def convNd(
@@ -54,7 +50,7 @@ def convNd(
     )
 
     # Process bias terms
-    if isinstance(bias, torch.Tensor):
+    if isinstance(bias, (torch.Tensor, np.ndarray)):
         # Transform the bias constant into a Numpy array
         bias = to_numpy(bias)
 
@@ -79,7 +75,7 @@ def convNd(
             network, target, tuple(), kwargs, name + "_unsqueeze_weight"
         )
 
-    elif isinstance(weight, torch.Tensor):
+    elif isinstance(weight, (torch.Tensor, np.ndarray)):
         # Transform the weight constant into a Numpy array
         weight = to_numpy(weight)
 

From 4b44ff281fe59cdb2a94cd003de21e21cf10b8a1 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Fri, 11 Aug 2023 17:18:40 -0700
Subject: [PATCH 2/5] fix: Add constant folding utility to freezing

---
 py/torch_tensorrt/dynamo/backend/backends.py  | 25 +++++++++++++++++++
 py/torch_tensorrt/dynamo/lowering/__init__.py |  1 -
 .../dynamo/lowering/_pre_aot_lowering.py      |  4 ---
 .../partitioning/_adjacency_partitioner.py    |  4 ++-
 .../partitioning/_global_partitioner.py       |  3 +--
 5 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index cf453dedf4..b4c4f7a992 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -8,6 +8,7 @@
 import torch._dynamo as td
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -62,6 +63,8 @@ def aot_torch_tensorrt_aten_backend(
             ),
         )
 
+        constant_fold(graph_module)
+
         return _pretraced_backend(graph_module, sample_inputs, settings)
 
 
@@ -105,3 +108,25 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and len(node.users) == 0:
+            delattr(gm, node.target)
+            erased_params.append(node)
+
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 4c515586b2..6eda61a6fd 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,5 +1,4 @@
 from ._decompositions import get_decompositions  # noqa: F401
-from ._freeze_aot_graph import *  # noqa: F401
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
index e69b9987c7..663cbf0c00 100644
--- a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
+++ b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py
@@ -81,10 +81,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
 
-    # Ensure all parameters are in inference mode
-    for param in gm.parameters():
-        param.requires_grad = False
-
     # Iterate over graph nodes, extracting module calls, to check for interceptions
     for n in gm.graph.nodes:
         exists_in_registry = False
diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
index f25bd2df12..5399bc5d6f 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py
@@ -43,7 +43,9 @@ def is_node_supported(
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
-        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+        if (
+            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 if node_name not in self.supported_operators:
diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
index 6520ff3723..19fccfc73f 100644
--- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
+++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py
@@ -154,8 +154,7 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node.target in CONVERTERS.keys()
-            or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():

From a94a07590234574ad532b26b7838c6d5c7dedc9d Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 28 Aug 2023 17:04:20 -0700
Subject: [PATCH 3/5] fix: Move tracer code into try/except

---
 py/torch_tensorrt/dynamo/backend/backends.py | 63 ++++++++++----------
 1 file changed, 32 insertions(+), 31 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index b4c4f7a992..096412cb35 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -43,29 +43,7 @@ def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
-
-    # Perform Pre-AOT Lowering for Module-Level Replacement
-    gm = pre_aot_substitutions(gm)
-
-    fake_mode = detect_fake_mode(sample_inputs)
-
-    # Place backend tracing within FakeTensor context allowing nonfake Tensors
-    with unittest.mock.patch.object(
-        fake_mode, "allow_non_fake_inputs", True
-    ), fake_mode:
-        # Invoke AOTAutograd to translate operators to aten
-        graph_module = aot_export_joint_simple(
-            gm,
-            sample_inputs,
-            trace_joint=False,
-            decompositions=get_decompositions(
-                settings.enable_experimental_decompositions
-            ),
-        )
-
-        constant_fold(graph_module)
-
-        return _pretraced_backend(graph_module, sample_inputs, settings)
+    return _pretraced_backend(gm, sample_inputs, settings)
 
 
 def _pretraced_backend(
@@ -83,15 +61,38 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
-        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+        logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
+
+        # Perform Pre-AOT Lowering for Module-Level Replacement
+        gm = pre_aot_substitutions(gm)
+
+        fake_mode = detect_fake_mode(sample_inputs)
+
+        # Place backend tracing within FakeTensor context allowing nonfake Tensors
+        with unittest.mock.patch.object(
+            fake_mode, "allow_non_fake_inputs", True
+        ), fake_mode:
+            # Invoke AOTAutograd to translate operators to aten
+            graph_module = aot_export_joint_simple(
+                gm,
+                sample_inputs,
+                trace_joint=False,
+                decompositions=get_decompositions(
+                    settings.enable_experimental_decompositions
+                ),
+            )
 
-        trt_compiled = compile_module(
-            gm,
-            sample_inputs,
-            settings=settings,
-        )
-        return trt_compiled
-    except AssertionError:
+            logger.debug("Post-AOT Autograd graph:\n" + str(graph_module.graph))
+
+            constant_fold(graph_module)
+
+            trt_compiled = compile_module(
+                graph_module,
+                sample_inputs,
+                settings=settings,
+            )
+            return trt_compiled
+    except (AssertionError, RuntimeError):
         if not settings.pass_through_build_failures:
             logger.warning(
                 "TRT conversion failed on the subgraph. See trace above. "

From 4e308f1d551216d813873f5844077855e0c694c6 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 28 Aug 2023 19:19:01 -0700
Subject: [PATCH 4/5] Custom implementation of AOT for compile

---
 py/torch_tensorrt/dynamo/backend/backends.py | 56 ++++++++++++++++++--
 1 file changed, 52 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 096412cb35..13aaefabc1 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -2,13 +2,15 @@
 
 import logging
 import unittest
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence
 
 import torch
 import torch._dynamo as td
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
-from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch._functorch.aot_autograd import _aot_export_function
 from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -73,10 +75,9 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             # Invoke AOTAutograd to translate operators to aten
-            graph_module = aot_export_joint_simple(
+            graph_module = aot_export_for_compile(
                 gm,
                 sample_inputs,
-                trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
                 ),
@@ -131,3 +132,50 @@ def constant_fold(gm: torch.fx.GraphModule) -> Any:
     gm.graph.eliminate_dead_code()
     gm.graph.lint()
     gm.recompile()
+
+
+def aot_export_for_compile(
+    func: torch.fx.GraphModule,
+    args: Sequence[torch.Tensor],
+    *,
+    decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
+) -> torch.fx.GraphModule:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
+
+    Removed check for input aliasing in resultant subgraph - TRT is functional-only
+    """
+    with torch.no_grad():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+
+    # No input mutations
+    if (
+        len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata])
+        != 0
+    ):
+        raise RuntimeError(
+            f"aot_export_for_compile does not support input mutations. {str(metadata)}"
+        )
+    # No pytrees
+    if type(in_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
+        )
+    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
+        )
+    if type(out_spec) == pytree.LeafSpec:
+        raise RuntimeError(
+            f"aot_export_for_compile requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
+        )
+    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(
+            f"aot_export_for_compile requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
+        )
+
+    return fx_g

From 95d3f98723277c5347a6060ed941904cf572a920 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Tue, 29 Aug 2023 21:00:21 -0700
Subject: [PATCH 5/5] Move fixes into Dynamo directory

---
 py/torch_tensorrt/dynamo/backend/backends.py  | 16 +++-
 .../dynamo/conversion/_TRTInterpreter.py      |  2 +-
 .../dynamo/conversion/aten_ops_converters.py  | 91 +++++++++----------
 .../dynamo/conversion/converter_utils.py      | 82 ++++++++++++++++-
 .../dynamo/conversion/impl/condition/ops.py   | 12 +--
 .../dynamo/conversion/impl/conv.py            | 33 ++++---
 .../conversion/impl/elementwise/base.py       |  6 +-
 .../dynamo/conversion/impl/elementwise/ops.py |  7 +-
 .../dynamo/conversion/impl/embedding.py       |  3 +-
 .../dynamo/conversion/impl/linear.py          |  2 +-
 .../dynamo/conversion/impl/matmul.py          |  7 +-
 .../dynamo/conversion/impl/unsqueeze.py       |  2 +-
 .../dynamo/lowering/_pre_aot_lowering.py      |  2 -
 .../fx/converters/acc_ops_converters.py       | 29 +++---
 .../fx/converters/converter_utils.py          |  8 +-
 .../fx/converters/impl/convolution.py         | 20 ++--
 .../dynamo/backend/test_specialized_models.py | 81 ++++++++++++++++-
 .../py/dynamo/lowering/test_decompositions.py |  1 +
 tests/py/dynamo/testing_utilities.py          | 33 +++++--
 19 files changed, 308 insertions(+), 129 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 13aaefabc1..6be97d42d9 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -9,7 +9,7 @@
 import torch.utils._pytree as pytree
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import _aot_export_function
-from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
+from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant
 from torch._ops import OpOverload
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
@@ -100,7 +100,7 @@ def _pretraced_backend(
                 + "Returning GraphModule forward instead.",
                 exc_info=True,
             )
-            return gm.forward
+            return gm
         else:
             logger.critical(
                 "Halting compilation on build failure since "
@@ -114,6 +114,13 @@ def _pretraced_backend(
 
 @torch.utils._python_dispatch._disable_current_modes()  # type: ignore
 def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    """Adapted from:
+    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
+
+    Folds constants in the graph module, not skipping constructors
+
+    Modifies the graph in-place and replaces node with constants
+    """
     cf = ConstantFolder(gm, skip_constructors=False)
     cf.run()
 
@@ -141,10 +148,13 @@ def aot_export_for_compile(
     decompositions: Optional[Dict[OpOverload, Callable[[Any], Any]]] = None,
 ) -> torch.fx.GraphModule:
     """Adapted from:
-    https://github.com/pytorch/pytorch/blob/054f3f1d8f9eb63ef8437991eba5b8f2aeee920f/torch/_functorch/aot_autograd.py#L4133-L4134
+    https://github.com/pytorch/pytorch/blob/1a5fdc2458b98697c75c32eb6f4b8b34d76429cf/torch/_functorch/aot_autograd.py#L4084-L4158
 
     Removed check for input aliasing in resultant subgraph - TRT is functional-only
+
+    Exports the function to ATen for torch compile
     """
+    # Trace function with input arguments and decompositions
     with torch.no_grad():
         fx_g, metadata, in_spec, out_spec = _aot_export_function(
             func,
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 35b092e263..9f3dc5deb9 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -361,7 +361,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
             outputs = (args[0],)
 
         for output_idx in range(len(outputs)):
-            from torch_tensorrt.fx.converters import get_trt_tensor
+            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
 
             output = outputs[output_idx]
 
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 9fcf959346..42d6165256 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -94,7 +94,7 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)  # type: ignore[misc]
 def aten_ops_relu(
     network: TRTNetwork,
     target: Target,
@@ -111,7 +111,7 @@ def aten_ops_relu(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sigmoid.default)  # type: ignore[misc]
 def aten_ops_sigmoid(
     network: TRTNetwork,
     target: Target,
@@ -128,7 +128,7 @@ def aten_ops_sigmoid(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)
+@dynamo_tensorrt_converter(torch.ops.aten.tanh.default)  # type: ignore[misc]
 def aten_ops_tanh(
     network: TRTNetwork,
     target: Target,
@@ -145,7 +145,7 @@ def aten_ops_tanh(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.leaky_relu.default)  # type: ignore[misc]
 def aten_ops_leaky_relu(
     network: TRTNetwork,
     target: Target,
@@ -163,7 +163,7 @@ def aten_ops_leaky_relu(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.elu.default)
+@dynamo_tensorrt_converter(torch.ops.aten.elu.default)  # type: ignore[misc]
 def aten_ops_elu(
     network: TRTNetwork,
     target: Target,
@@ -182,7 +182,7 @@ def aten_ops_elu(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)
+@dynamo_tensorrt_converter(torch.ops.aten.softplus.default)  # type: ignore[misc]
 def aten_ops_softplus(
     network: TRTNetwork,
     target: Target,
@@ -200,7 +200,7 @@ def aten_ops_softplus(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default)  # type: ignore[misc]
 def aten_ops_clip(
     network: TRTNetwork,
     target: Target,
@@ -219,7 +219,7 @@ def aten_ops_clip(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)
+@dynamo_tensorrt_converter(torch.ops.aten.hardsigmoid.default)  # type: ignore[misc]
 def aten_ops_hard_sigmoid(
     network: TRTNetwork,
     target: Target,
@@ -296,7 +296,7 @@ def aten_ops_rsqrt(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
+@dynamo_tensorrt_converter(torch.ops.aten.neg.default)  # type: ignore[misc]
 def aten_ops_neg(
     network: TRTNetwork,
     target: Target,
@@ -304,18 +304,12 @@ def aten_ops_neg(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = args[0]
-    if (isinstance(input_val, TRTTensor)) and (
-        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
-    ):
-        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
-
     return impl.unary.neg(
         network,
         target,
         SourceIR.ATEN,
         name,
-        input_val,
+        args[0],
     )
 
 
@@ -503,7 +497,7 @@ def aten_ops_clone(
-@dynamo_tensorrt_converter(torch.ops.aten.expand.default) +@dynamo_tensorrt_converter(torch.ops.aten.expand.default) # type: ignore[misc] def aten_ops_expand( network: TRTNetwork, target: Target, @@ -533,7 +527,7 @@ def amax_param_validator(amax_node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.amax.default, capability_validator=amax_param_validator -) +) # type: ignore[misc] def aten_ops_amax( network: TRTNetwork, target: Target, @@ -552,8 +546,8 @@ def aten_ops_amax( ) -@dynamo_tensorrt_converter(torch.ops.aten.sum.default) -@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList) +@dynamo_tensorrt_converter(torch.ops.aten.sum.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList) # type: ignore[misc] def aten_ops_sum( network: TRTNetwork, target: Target, @@ -946,8 +940,8 @@ def aten_ops_isinf( ) -@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar) # type: ignore[misc] def aten_ops_add( network: TRTNetwork, target: Target, @@ -978,8 +972,8 @@ def aten_ops_add( ) -@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar) # type: ignore[misc] def aten_ops_mul( network: TRTNetwork, target: Target, @@ -997,7 +991,7 @@ def aten_ops_mul( ) -@dynamo_tensorrt_converter(torch.ops.aten.maximum.default) +@dynamo_tensorrt_converter(torch.ops.aten.maximum.default) # type: ignore[misc] def aten_ops_max( network: TRTNetwork, target: Target, @@ -1015,7 +1009,7 @@ def aten_ops_max( ) -@dynamo_tensorrt_converter(torch.ops.aten.minimum.default) +@dynamo_tensorrt_converter(torch.ops.aten.minimum.default) # type: ignore[misc] def aten_ops_min( network: TRTNetwork, target: Target, @@ -1033,8 +1027,8 @@ def aten_ops_min( ) -@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar) # type: ignore[misc] def aten_ops_sub( network: TRTNetwork, target: Target, @@ -1065,10 +1059,10 @@ def aten_ops_sub( ) -@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) -@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) -@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode) # type: ignore[misc] def aten_ops_div( network: TRTNetwork, target: Target, @@ -1111,9 +1105,9 @@ def aten_ops_div( ) -@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar) -@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) # type: ignore[misc] def aten_ops_pow( network: TRTNetwork, target: Target, @@ -1131,8 +1125,8 @@ 
def aten_ops_pow( ) -@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default) -@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar) # type: ignore[misc] def aten_ops_floor_div( network: TRTNetwork, target: Target, @@ -1150,7 +1144,7 @@ def aten_ops_floor_div( ) -@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default) +@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default) # type: ignore[misc] def aten_ops_logical_and( network: TRTNetwork, target: Target, @@ -1168,7 +1162,7 @@ def aten_ops_logical_and( ) -@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default) +@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default) # type: ignore[misc] def aten_ops_logical_or( network: TRTNetwork, target: Target, @@ -1186,7 +1180,7 @@ def aten_ops_logical_or( ) -@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default) +@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default) # type: ignore[misc] def aten_ops_logical_xor( network: TRTNetwork, target: Target, @@ -1204,8 +1198,8 @@ def aten_ops_logical_xor( ) -@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar) # type: ignore[misc] def aten_ops_equal( network: TRTNetwork, target: Target, @@ -1223,8 +1217,8 @@ def aten_ops_equal( ) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar) # type: ignore[misc] def aten_ops_greater( network: TRTNetwork, target: Target, @@ -1242,8 +1236,8 @@ def aten_ops_greater( ) -@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) -@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar) # type: ignore[misc] def aten_ops_less( network: TRTNetwork, target: Target, @@ -1267,7 +1261,7 @@ def conv_param_validator(conv_node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.convolution.default, capability_validator=conv_param_validator -) +) # type: ignore[misc] def aten_ops_convolution( network: TRTNetwork, target: Target, @@ -1291,7 +1285,8 @@ def aten_ops_convolution( ) -@dynamo_tensorrt_converter(torch.ops.aten.linear.default) +@dynamo_tensorrt_converter(torch.ops.aten.linear.default) # type: ignore[misc] +@dynamo_tensorrt_converter(torch.ops.aten.linear) # type: ignore[misc] def aten_ops_linear( network: TRTNetwork, target: Target, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index c5df3f9752..1d8dfecf3b 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,14 +1,16 @@ import functools import logging import re -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union +import numpy as np import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, get_axes_for_reduce_op, + to_numpy, unified_dtype_converter, ) from torch_tensorrt.fx.types import TRTDataType, 
TRTNetwork, TRTTensor @@ -185,11 +187,85 @@ def extend_attr_to_tuple( if isinstance(val, list): val = tuple(val) - return val + + if isinstance(val, tuple): + return val + else: + raise AssertionError(f"Could not extend attribute {val}") -def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor): +def cast_int_or_float_to_bool( + network: TRTNetwork, name: str, tensor: TRTTensor +) -> TRTTensor: if tensor.dtype != trt.bool: return cast_trt_tensor(network, tensor, trt.bool, name) return tensor + + +def create_constant( + network: TRTNetwork, + value: Union[int, float, np.ndarray, torch.Tensor], + name: str, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]], +) -> TRTTensor: + """ + Add a TensorRT constant layer whose value is `value` to `network`. + Args: + network (TRTNetwork): A TensorRT network to which we want to add + a constant layer. + value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array, + or a PyTorch tensor that will be used as value of the added TensorRT Constant layer. + name (str): Name of the added TensorRT Constant layer. + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If a dtype is given, we will convert the type of the given `value` to this dtype. + Returns: + A TensorRT ITensor that represents the given value. + """ + constant = network.add_constant( + (1,) if isinstance(value, (int, float)) else value.shape, + to_numpy(value, dtype).copy(), + ) + constant.name = name + return constant.get_output(0) + + +def get_trt_tensor( + network: TRTNetwork, + input_val: Any, + name: str, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, +) -> TRTTensor: + """ + Given a value of random type, we try to convert it to a TensorRT ITensor. + An runtime error is raised if we're not able to do that. + Args: + network (TRTNetwork): A TensorRT network. If we want to + add a TensorRT Constant layer, we will add it to this network. + input_val (Any): An value that we want to convert to a TensorRT ITensor. + name (str): The name of the created TensorRT Constant layer if there's + one. + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If dtype is provided, the given value will be converted to this dtype. + Returns: + A TensorRT ITensor that represents the given value. + """ + # TRT can not add constant for bool type. 
We do a work around to 1) cast it to int and 2)cast to bool later + # This is useful for logical operations which require input to be bool type + if isinstance(input_val, bool): + input_val = int(input_val) + elif isinstance(input_val, torch.Tensor) and ( + input_val.dtype == torch.bool or input_val.dtype == torch.int64 + ): + input_val = input_val.to(torch.int32) + elif isinstance(input_val, np.ndarray) and ( + input_val.dtype == np.bool_ or input_val.dtype == np.int64 + ): + input_val = input_val.astype(np.int32) + + if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)): + return create_constant(network, input_val, name, dtype) + elif isinstance(input_val, TRTTensor): + return input_val + else: + raise AssertionError(f"Cannot convert {input_val} to TRT constant") diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py index b81418490c..9c225357b5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py @@ -1,19 +1,17 @@ from typing import Optional +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable -from torch_tensorrt.dynamo.conversion.impl.slice import expand -from torch_tensorrt.fx.converters.converter_utils import ( - broadcast, +from torch_tensorrt.dynamo.conversion.converter_utils import ( + broadcastable, get_trt_tensor, - set_layer_name, ) +from torch_tensorrt.dynamo.conversion.impl.slice import expand +from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor -import tensorrt as trt - def where( network: TRTNetwork, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/conv.py b/py/torch_tensorrt/dynamo/conversion/impl/conv.py index ff7deb0962..ebe4e37c9e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/conv.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/conv.py @@ -7,11 +7,13 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple +from torch_tensorrt.dynamo.conversion.converter_utils import ( + extend_attr_to_tuple, + get_trt_tensor, +) from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, get_dyn_range, - get_trt_tensor, has_dynamic_shape, mark_as_int8_layer, set_layer_name, @@ -27,8 +29,8 @@ def convNd( name: str, is_conv1d: bool, input: TRTTensor, - weight: Union[TRTTensor, torch.Tensor], - bias: Optional[Union[TRTTensor, torch.Tensor]], + weight: Union[TRTTensor, torch.Tensor, np.ndarray], + bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]], stride: Optional[Union[int, Sequence[int]]], padding: Optional[Union[int, Sequence[int]]], dilation: Optional[Union[int, Sequence[int]]], @@ -97,19 +99,28 @@ def convNd( if isinstance(bias, TRTTensor): conv_layer.set_input(2, bias) + # Cast certain fields to tuples, in accordance with TRT requirements + padding = (padding,) if isinstance(padding, int) else padding + stride = (stride,) if isinstance(stride, int) else stride + dilation = (dilation,) if isinstance(dilation, int) else dilation + # Expand parameters manually for Conv1D computations if is_conv1d: - padding = tuple(padding) + (0,) - stride = extend_attr_to_tuple(stride, 2) - dilation = extend_attr_to_tuple(dilation, 2) + padding = 
(tuple(padding) + (0,)) if padding is not None else padding + stride = extend_attr_to_tuple(stride, 2) if stride is not None else stride + dilation = ( + extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation + ) set_layer_name(conv_layer, target, name, source_ir) # Set relevant attributes of convolution layer - conv_layer.padding_nd = padding - conv_layer.stride_nd = stride - conv_layer.dilation_nd = dilation - + if padding is not None: + conv_layer.padding_nd = padding + if stride is not None: + conv_layer.stride_nd = stride + if dilation is not None: + conv_layer.dilation_nd = dilation if groups is not None: conv_layer.num_groups = groups diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 46380cbec7..95dcd88a75 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -7,10 +7,12 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + get_trt_tensor, +) from torch_tensorrt.fx.converters.converter_utils import ( broadcast, - get_trt_tensor, set_layer_name, squeeze_left, ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index f5d46efc17..75ff33f26f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -7,17 +7,14 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_int_int_div_trt_tensor, cast_int_or_float_to_bool, + get_trt_tensor, ) from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( convert_binary_elementwise, ) from torch_tensorrt.dynamo.conversion.impl.unary import sign from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary -from torch_tensorrt.fx.converters.converter_utils import ( - get_trt_tensor, - set_layer_name, - squeeze_left, -) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index 26064f621c..8ddfdf015f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -3,7 +3,8 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor, set_layer_name +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/linear.py b/py/torch_tensorrt/dynamo/conversion/impl/linear.py index 0a98087bce..cad97a5c9a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/linear.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/linear.py @@ -5,7 +5,7 @@ import torch from torch.fx.node import Target from torch_tensorrt.dynamo.conversion import impl -from torch_tensorrt.fx.converters.converter_utils import 
SourceIR, get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor from torch_tensorrt.fx.types import TRTNetwork, TRTTensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index 4b69b09d2a..a62d24121f 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -3,11 +3,8 @@ import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import ( - broadcast, - get_trt_tensor, - set_layer_name, -) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py index 9929e59d86..fae22888d8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -2,9 +2,9 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import ( get_positive_dim, - get_trt_tensor, set_layer_name, ) from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor diff --git a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py index 663cbf0c00..70cc5424af 100644 --- a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py @@ -124,7 +124,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # Replace all original node uses and clean up graph n.replace_all_uses_with(new_node) - gm.graph.eliminate_dead_code() gm.graph.lint() gm.recompile() @@ -138,7 +137,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: continue # Perform cleanup and recompilation before returning module - gm.graph.eliminate_dead_code() gm.graph.lint() gm.recompile() diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index c1deb21303..1765077930 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -3,27 +3,30 @@ import math import operator import warnings -from typing import Dict, Optional, Sequence, Tuple, Union, cast +from typing import cast, Dict, Optional, Sequence, Tuple, Union import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch + +from ..converter_registry import tensorrt_converter + +from ..tracer.acc_tracer import acc_ops +from ..types import * # noqa: F403 from torch.fx.immutable_collections import immutable_list from torch.fx.node import Argument, Target -from torch_tensorrt.fx.converters.impl import activation, convolution + +from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks + +from .converter_utils import * # noqa: F403 from torch_tensorrt.fx.passes.lower_basic_pass import ( trt_transposed_linear, trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous - -from 
..converter_registry import tensorrt_converter -from ..tracer.acc_tracer import acc_ops -from ..types import * # noqa: F403 -from ..utils import Frameworks, get_dynamic_dims, unified_dtype_converter -from .converter_utils import * # noqa: F403 +from torch_tensorrt.fx.converters.impl import activation, convolution _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -2711,14 +2714,8 @@ def acc_ops_linear( "dim for linear and it can't be the last dim." ) - if isinstance(kwargs["weight"], (torch.Tensor, np.ndarray)): - weight = get_trt_tensor( - network, - kwargs["weight"].t() - if isinstance(kwargs["weight"], torch.Tensor) - else kwargs["weight"].T, - f"{name}_weight", - ) + if isinstance(kwargs["weight"], torch.Tensor): + weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight") if target not in (acc_ops.linear, torch.ops.aten.linear): weight_op = trt.MatrixOperation.TRANSPOSE else: diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 17b3f6785f..49bf401f58 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -1,8 +1,8 @@ import operator import warnings -from enum import Enum, auto from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from enum import Enum, auto import numpy as np # @manual=//deeplearning/trt/python:py_tensorrt @@ -20,7 +20,7 @@ TRTPluginFieldCollection, TRTTensor, ) -from ..utils import Frameworks, unified_dtype_converter +from ..utils import unified_dtype_converter, Frameworks class SourceIR(Enum): @@ -271,7 +271,7 @@ def create_constant( """ constant = network.add_constant( (1,) if isinstance(value, (int, float)) else value.shape, - to_numpy(value, dtype).copy(), + to_numpy(value, dtype), ) constant.name = name return constant.get_output(0) @@ -311,7 +311,7 @@ def get_trt_tensor( elif isinstance(input_val, np.ndarray) and ( input_val.dtype == np.bool_ or input_val.dtype == np.int64 ): - input_val = input_val.astype(np.int32) + input_val = input_val.to(np.int32) if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)): return create_constant(network, input_val, name, dtype) diff --git a/py/torch_tensorrt/fx/converters/impl/convolution.py b/py/torch_tensorrt/fx/converters/impl/convolution.py index 946eb74485..84071ed2d4 100644 --- a/py/torch_tensorrt/fx/converters/impl/convolution.py +++ b/py/torch_tensorrt/fx/converters/impl/convolution.py @@ -1,23 +1,27 @@ -from typing import Any, Optional, Sequence, Union - import numpy as np +from typing import Any, Optional, Sequence, Union # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch from torch.fx.node import Target -from torch_tensorrt.fx.converters import acc_ops_converters + from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, extend_attr_to_tuple, get_dyn_range, - get_trt_tensor, - has_dynamic_shape, mark_as_int8_layer, set_layer_name, + has_dynamic_shape, to_numpy, + get_trt_tensor, +) +from torch_tensorrt.fx.converters import acc_ops_converters + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, ) -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor def convNd( @@ -50,7 +54,7 @@ def convNd( ) # Process bias terms - if isinstance(bias, (torch.Tensor, np.ndarray)): + if isinstance(bias, torch.Tensor): # Transform the bias constant into a Numpy array bias = to_numpy(bias) @@ -75,7 +79,7 @@ def convNd( network, target, tuple(), kwargs, name + "_unsqueeze_weight" ) - 
elif isinstance(weight, (torch.Tensor, np.ndarray)): + elif isinstance(weight, torch.Tensor): # Transform the weight constant into a Numpy array weight = to_numpy(weight) diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index 143aa9b241..1b9e5fb337 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -2,7 +2,7 @@ import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests -from ..testing_utilities import lower_graph_testing +from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing class TestFakeTensors(TestCase): @@ -157,5 +157,84 @@ def forward(self, x): torch._dynamo.reset() +class TestTensorFreezing(TestCase): + def test_tensor_freeze_attr(self): + class TensorFreeze(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.ones((8, 2), device="cuda") + + def forward(self, x): + return x @ self.const + + inputs = [ + torch.ones( + 7, + 8, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(TensorFreeze()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Frozen-Tensor TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + def test_constant_fold(self): + class Arange(torch.nn.Module): + def forward(self, x): + y = torch.arange(10, device="cuda") + return x + y + + inputs = [ + torch.rand( + 10, + 10, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(Arange()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"Constant Folded TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index 909ded2690..fd834394c1 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -12,6 +12,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def forward(self, x, y): + x += 1 x = torch.ops.aten.add_.Tensor(x, y) x = torch.ops.aten.relu_.default(x) return x diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index e2607d859b..f311f2db2b 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -1,18 +1,18 @@ +import unittest from copy import deepcopy from functools import partial from typing import Any, List, Sequence, Set import torch -from torch._dynamo.backends.common import fake_tensor_unsupported -from 
torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +from torch._dynamo.utils import detect_fake_mode from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.backend.backends import aot_export_for_compile, constant_fold from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions DECIMALS_OF_AGREEMENT = 4 -@fake_tensor_unsupported def fx_dynamo_testing_backend( gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], @@ -33,13 +33,26 @@ def fx_dynamo_testing_backend( gm = pre_aot_substitutions(gm) - # Invoke AOTAutograd to translate operators to aten - return aot_module_simplified( - gm, - sample_inputs, - fw_compiler=make_boxed_compiler(custom_backend), - decompositions=get_decompositions(), - ) + fake_mode = detect_fake_mode(sample_inputs) + + # Place backend tracing within FakeTensor context allowing nonfake Tensors + with unittest.mock.patch.object( + fake_mode, "allow_non_fake_inputs", True + ), fake_mode: + # Invoke AOTAutograd to translate operators to aten + graph_module = aot_export_for_compile( + gm, + sample_inputs, + decompositions=get_decompositions(), + ) + + constant_fold(graph_module) + + trt_compiled = custom_backend( + graph_module, + sample_inputs, + ) + return trt_compiled def compile_module_testing(