diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 20a6acb7ff..ec67a7a358 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -11,3 +11,4 @@
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
+ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 4be44cd779..6f17ad768b 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -4,6 +4,7 @@
 import torch
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -19,6 +20,28 @@
 
 @dataclass
 class CompilationSettings:
+    """Compilation settings for Torch-TensorRT Dynamo paths
+
+    Args:
+        precision (torch.dtype): Model layer precision
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT-Engine Block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization level 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
+        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+    """
+
     precision: torch.dtype = PRECISION
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
@@ -31,3 +53,4 @@ class CompilationSettings:
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
     use_fast_partitioner: bool = USE_FAST_PARTITIONER
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 6efbf89e34..2b761970a1 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -48,7 +48,7 @@ def aot_torch_tensorrt_aten_backend(
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
+        decompositions=get_decompositions(settings.enable_experimental_decompositions),
     )
diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py
index 86e7dd6688..eb051e93e9 100644
--- a/py/torch_tensorrt/dynamo/compile.py
+++ b/py/torch_tensorrt/dynamo/compile.py
@@ -14,6 +14,7 @@
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -63,6 +64,7 @@ def compile(
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
@@ -72,9 +74,10 @@
     logger.warning(
         "The Dynamo backend is an experimental feature, for which only the "
-        + "following arguments are supported: "
-        + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
+        "following arguments are supported: "
+        "{enabled_precisions, debug, workspace_size, min_block_size, "
+        "torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
+        "enable_experimental_decompositions}"
     )
 
     if not isinstance(inputs, collections.abc.Sequence):
@@ -115,6 +118,7 @@
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
         "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
     }
 
     settings = CompilationSettings(**compilation_options)
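For reference, a minimal sketch of the user-facing surface this adds. It assumes a Torch-TensorRT build containing this patch, that the function above is exposed as `torch_tensorrt.dynamo.compile`, and a CUDA-capable environment; the toy model and input shapes are placeholders:

```python
import torch
import torch_tensorrt

# Placeholder module; any module supported by the Dynamo path works the same way
model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.GELU()).eval().cuda()
inputs = [torch.randn(8, 64).cuda()]

# Default: lower with the curated subset of decompositions
# (torch_enabled_decompositions) plus the Torch-TRT custom ones
trt_module = torch_tensorrt.dynamo.compile(model, inputs)

# Opt-in: lower with all core aten decompositions, minus any ops listed in
# torch_disabled_decompositions
trt_module_exp = torch_tensorrt.dynamo.compile(
    model, inputs, enable_experimental_decompositions=True
)
```

In the `torch.compile` backend path, the same flag flows from `CompilationSettings` into `get_decompositions` via `aot_torch_tensorrt_aten_backend`, as shown in the `backends.py` hunk above.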
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
new file mode 100644
index 0000000000..60fef93e08
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -0,0 +1,200 @@
+from typing import Any, Callable, Dict, Set
+
+import torch
+from torch._decomp import core_aten_decompositions
+from torch._decomp import get_decompositions as get_torch_decompositions
+from torch._ops import OpOverload
+
+aten = torch.ops.aten
+
+_core_aten_decompositions: Dict[
+    OpOverload, Callable[[Any], Any]
+] = core_aten_decompositions()
+torch_enabled_decompositions: Set[OpOverload] = {
+    aten._adaptive_avg_pool2d_backward,
+    aten.addcdiv,
+    aten.addcdiv_,
+    aten.addcmul,
+    aten.addcmul_,
+    aten.addr,
+    aten.aminmax,
+    aten.arange.default,
+    aten.arange.start,
+    aten.avg_pool2d_backward,
+    aten.binary_cross_entropy,
+    aten.binary_cross_entropy_backward,
+    aten.binary_cross_entropy_with_logits,
+    aten.celu,
+    aten.col2im,
+    aten.count_nonzero,
+    aten.cudnn_batch_norm,
+    aten.cudnn_batch_norm_backward,
+    aten.deg2rad,
+    aten.detach,
+    aten.diag_embed,
+    aten.diagonal_backward,
+    aten.dot,
+    aten.elu,
+    aten.elu_backward,
+    aten._embedding_bag,
+    aten.embedding_dense_backward,
+    aten._euclidean_dist.default,
+    aten.expand_as,
+    aten.eye,
+    aten.fill,
+    aten.frac,
+    aten._fused_moving_avg_obs_fq_helper,
+    aten.gelu,
+    aten.gelu_backward,
+    aten.glu_backward,
+    aten.grid_sampler_2d,
+    aten.hardshrink,
+    aten.hardshrink_backward,
+    aten.hardsigmoid,
+    aten.hardsigmoid_backward,
+    aten.hardswish,
+    aten.hardswish_,
+    aten.hardswish_backward,
+    aten.hardtanh,
+    aten.hardtanh_,
+    aten.hardtanh_backward,
+    aten.heaviside,
+    aten.huber_loss,
+    aten.huber_loss_backward,
+    aten.im2col,
+    aten.index_add,
+    aten.index_add_,
+    aten.index_copy,
+    aten.index_copy_,
+    aten.index_fill,
+    aten.index_fill_,
+    aten.index_select,
+    aten.isneginf,
+    aten.isposinf,
+    aten.l1_loss,
+    aten.leaky_relu,
+    aten.leaky_relu_,
+    aten.leaky_relu_backward,
+    aten.lerp,
+    aten.linspace,
+    aten.logaddexp,
+    aten.logaddexp2,
+    aten.logit,
+    aten.logit_backward,
+    aten.log_sigmoid_backward,
+    aten.log_sigmoid_forward,
+    aten._log_softmax,
+    aten._log_softmax_backward_data,
+    aten.logspace,
+    aten.logsumexp.default,
+    aten.masked_fill,
+    aten.masked_fill_,
+    aten.max_pool2d_with_indices_backward,
+    aten.mish,
+    aten.mse_loss,
+    aten.mse_loss_backward,
+    aten.mv,
+    aten.mvlgamma,
+    aten.nansum,
+    aten.nan_to_num,
+    aten.narrow,
+    # TODO: Disable the below operators once freezing is done
+    aten.native_batch_norm,
+    aten.native_batch_norm_backward,
+    aten._native_batch_norm_legit,
+    aten._native_batch_norm_legit_functional,
+    aten._native_batch_norm_legit_no_training,
+    aten.native_dropout_backward,
+    aten.native_group_norm,
+    aten.native_group_norm_backward,
+    aten.native_layer_norm,
+    aten.native_layer_norm_backward,
+    aten.new_empty,
+    aten.new_full,
+    aten.new_ones,
+    aten.new_zeros,
+    aten.nll_loss_backward,
+    aten.nll_loss_forward,
+    aten.norm,
+    aten.ones,
+    aten.ones_like,
+    aten._prelu_kernel,
+    aten._prelu_kernel_backward,
+    aten._reshape_alias,
+    aten.rad2deg,
+    aten.renorm,
+    aten.renorm_,
+    aten.rot90,
+    aten.rsub.Scalar,
+    aten.rsub.Tensor,
+    aten.select_backward,
+    aten.select_scatter,
+    aten.sgn,
+    aten.sigmoid_backward,
+    aten.silu,
+    aten.silu_,
+    aten.silu_backward,
+    aten.sinc,
+    aten.slice_backward,
+    aten.smooth_l1_loss,
+    aten.smooth_l1_loss_backward,
+    aten.soft_margin_loss,
+    aten.soft_margin_loss_backward,
+    aten._softmax,
+    aten._softmax_backward_data,
+    aten.softplus,
+    aten.softplus_backward,
+    aten.softshrink,
+    aten.softshrink_backward,
+    aten.special_entr,
+    aten.special_log_ndtr,
+    aten.special_xlog1py,
+    aten.stack,
+    aten.t,
+    aten.tanh_backward,
+    aten.threshold,
+    aten.threshold_backward,
+    aten.trace,
+    aten.transpose.int,
+    aten.tril.default,
+    aten.triu.default,
+    aten.unfold,
+    aten.unfold_backward,
+    aten.unfold_copy,
+    aten.upsample_bilinear2d,
+    aten.upsample_bilinear2d.vec,
+    aten.upsample_nearest2d_backward,
+    aten.xlogy,
+    aten.zero,
+    aten.zero_,
+    aten.zeros,
+    aten.zeros_like,
+    # Non-default convenience decompositions
+    aten.clamp_min,
+    aten.clamp_max,
+    aten.linalg_vector_norm,
+    aten.full,
+    aten.repeat,
+}
+torch_disabled_decompositions: Set[OpOverload] = set()
+
+
+ENABLED_TORCH_DECOMPOSITIONS: Dict[
+    OpOverload, Callable[[Any], Any]
+] = get_torch_decompositions(torch_enabled_decompositions)
+TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
+
+
+def check_decomp_set_invariants() -> None:
+    """Validates no overlap between enabled and disabled decomposition sets"""
+    overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)
+
+    if overlap:
+        raise AssertionError(
+            f"Detected {overlap} registered in both torch_enabled_decompositions "
+            "and torch_disabled_decompositions. Ensure all operator(s) are in "
+            "at most one of the two sets."
+        )
+
+
+check_decomp_set_invariants()
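To make the set semantics concrete, a small sketch of the invariant the new module enforces; `aten.gelu` is chosen only because it appears in `torch_enabled_decompositions` above, and mutating the set is done purely for demonstration:

```python
import torch
from torch_tensorrt.dynamo.lowering._decomposition_groups import (
    check_decomp_set_invariants,
    torch_disabled_decompositions,
    torch_enabled_decompositions,
)

aten = torch.ops.aten

assert aten.gelu in torch_enabled_decompositions

# Placing the same op in both sets violates the invariant and raises
torch_disabled_decompositions.add(aten.gelu)
try:
    check_decomp_set_invariants()
except AssertionError as err:
    print(err)  # reports the overlapping operator(s)
finally:
    torch_disabled_decompositions.discard(aten.gelu)  # restore the original state
```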
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 666d04e779..57e1954575 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -1,11 +1,57 @@
-from typing import Any, Callable, Dict
+import logging
+from typing import Any, Callable, Dict, Optional
 
 import torch
-from torch._decomp import OpOverload, core_aten_decompositions, register_decomposition
-
-DECOMPOSITIONS: Dict[OpOverload, Callable[..., Any]] = {**core_aten_decompositions()}
-
-aten = torch.ops.aten
+from torch._decomp import register_decomposition
+from torch._ops import OpOverload
+
+from ._decomposition_groups import (
+    ENABLED_TORCH_DECOMPOSITIONS,
+    TORCH_TRT_DECOMPOSITIONS,
+    _core_aten_decompositions,
+    aten,
+    torch_disabled_decompositions,
+    torch_enabled_decompositions,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def register_torch_trt_decomposition(
+    aten_op: OpOverload, registry: Optional[Any] = None
+) -> Callable[[Any], Any]:
+    """Checks if the decomposition already exists in one of the sets
+    and registers the decomposition via the Torch utility
+
+    Alerts the user if the decomposition already exists before registering,
+    and logs a message if the requested decomposition is present in the set
+    of explicitly disabled decompositions
+    """
+    if aten_op in torch_enabled_decompositions:
+        logger.warning(
+            f"Detected custom decomposition for {aten_op}, which conflicts "
+            "with an existing Torch decomposition in torch_enabled_decompositions. "
+            "The custom implementation will take precedence."
+        )
+    elif aten_op in torch_disabled_decompositions:
+        logger.info(
+            f"Detected custom decomposition for {aten_op}, which is present "
+            "in torch_disabled_decompositions."
+        )
+
+    # Conflicts with _core_aten_decompositions will only occur if
+    # enable_experimental_decompositions is True in get_decompositions
+    if aten_op in _core_aten_decompositions:
+        logger.debug(
+            f"Detected custom decomposition for {aten_op}, which conflicts "
+            "with an existing Torch decomposition in core_aten_decompositions. "
+            "The custom implementation will take precedence."
+        )
+
+    def register(fn: Callable[[Any], Any]) -> Any:
+        return register_decomposition(aten_op=aten_op, registry=registry)(fn)
+
+    return register
 
 
 def replace_inplace_op(aten_op: OpOverload, outplace_op: OpOverload) -> Any:
@@ -14,8 +60,8 @@
     https://github.com/pytorch/pytorch/blob/3344d79e3f732dadd5c85b99a7aa1a022f187929/torch/_decomp/decompositions.py#L3355-L3361
     """
 
-    @register_decomposition(aten_op, registry=DECOMPOSITIONS)  # type: ignore[misc]
-    def inplace_op(*args: Any, **kwargs: Any) -> Any:
+    @register_torch_trt_decomposition(aten_op, registry=TORCH_TRT_DECOMPOSITIONS)
+    def inplace_op(*args, **kwargs):  # type: ignore
         out = outplace_op(*args, **kwargs)
         return args[0].copy_(out)
@@ -37,32 +83,36 @@ def inplace_op(*args: Any, **kwargs: Any) -> Any:
 replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce)
 
 
-@register_decomposition(aten.std, registry=DECOMPOSITIONS)  # type: ignore[misc]
-def std_replacement(*args: Any, **kwargs: Any) -> torch.Tensor:
+@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS)
+def std_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.sqrt(torch.var(*args, **kwargs))
 
 
-@register_decomposition(aten.rsqrt, registry=DECOMPOSITIONS)  # type: ignore[misc]
-def rsqrt_replacement(*args: Any, **kwargs: Any) -> torch.Tensor:
+@register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS)
+def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
 
 
-@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS)  # type: ignore[misc]
-def unsafe_view_replacement(x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+@register_torch_trt_decomposition(aten._unsafe_view, registry=TORCH_TRT_DECOMPOSITIONS)
+def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # type: ignore
     return torch.reshape(x, *args, **kwargs)
 
 
-@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_torch_trt_decomposition(
+    torch.ops.aten.lift_fresh_copy, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(aten.alias, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_torch_trt_decomposition(aten.alias, registry=TORCH_TRT_DECOMPOSITIONS)
 def alias_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-@register_decomposition(torch.ops.aten.addmm, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_torch_trt_decomposition(
+    torch.ops.aten.addmm, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def addmm_replacement(
     input_: torch.Tensor,
     mat1: torch.Tensor,
@@ -76,12 +126,24 @@
     )
 
 
-@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS)  # type: ignore[misc]
+@register_torch_trt_decomposition(
+    torch.ops.aten.reciprocal.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
 def reciprocal_replacement(
     input_: torch.Tensor,
 ) -> torch.Tensor:
     return torch.div(1, input_)
 
 
-def get_decompositions() -> Dict[OpOverload, Callable[..., Any]]:
-    return DECOMPOSITIONS
+def get_decompositions(
+    enable_experimental_decompositions: bool = False,
+) -> Dict[OpOverload, Callable[[Any], Any]]:
+    if enable_experimental_decompositions:
+        CORE_ATEN_DECOMPOSITIONS_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
+            decomp: _core_aten_decompositions[decomp]
+            for decomp in _core_aten_decompositions
+            if decomp not in torch_disabled_decompositions
+        }
+        return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
+    else:
+        return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}
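Finally, a hedged sketch of how the pieces compose: registering a custom decomposition through the new decorator and building both decomposition tables. The `aten.log1p` replacement is illustrative only and not part of this patch:

```python
import torch
from torch_tensorrt.dynamo.lowering._decomposition_groups import (
    TORCH_TRT_DECOMPOSITIONS,
)
from torch_tensorrt.dynamo.lowering._decompositions import (
    get_decompositions,
    register_torch_trt_decomposition,
)

aten = torch.ops.aten

# Illustrative only: decompose log1p(x) into log(1 + x). Since log1p is in
# neither torch_enabled_decompositions nor torch_disabled_decompositions,
# registration proceeds without any conflict warnings.
@register_torch_trt_decomposition(aten.log1p, registry=TORCH_TRT_DECOMPOSITIONS)
def log1p_replacement(x: torch.Tensor) -> torch.Tensor:
    return torch.log(1 + x)

# Curated table: pre-resolved Torch decompositions + Torch-TRT custom ones
default_table = get_decompositions()

# Experimental table: all core aten decompositions (minus disabled ops)
# + Torch-TRT custom ones; expected to be much larger in practice, since the
# curated set is a small subset of the core aten table
experimental_table = get_decompositions(enable_experimental_decompositions=True)
print(len(default_table), len(experimental_table))
```

Because `TORCH_TRT_DECOMPOSITIONS` is merged last in `get_decompositions`, the custom implementation wins over any Torch-provided decomposition for the same overload in either mode.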