feat: Prototype Module-Acceleration in Dynamo #1921

Changes from 2 commits
lowering/__init__.py
@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    MODULE_SUBSTITUTION_REGISTRY,
+    module_substitution,
 )
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .module_substitutions import *
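A quick sketch of the imports this re-export now enables for downstream code (paths as used elsewhere in this diff, not a new API):

    from torch_tensorrt.dynamo.backend.lowering import (
        MODULE_SUBSTITUTION_REGISTRY,
        module_substitution,
        DEFAULT_SINGLE_NODE_PARTITIONS,
    )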
lowering/_partition.py
@@ -1,9 +1,10 @@
 import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Set

 import torch

 from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name

@@ -14,6 +15,11 @@

 logger = logging.getLogger(__name__)

+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    "torch.ops." + str(module.new_operator)
+    for module in MODULE_SUBSTITUTION_REGISTRY.values()
+)
+

 class TRTPartitioner(CapabilityBasedPartitioner):
     """Partitioner to split an FX graph into subgraphs based on operator support

Review thread on DEFAULT_SINGLE_NODE_PARTITIONS:
- Do you know if there's a better type than string for this registry? Like, is there an op type?
- I will look into this more - there is the …
- After looking more into this, …

@@ -35,7 +41,9 @@ def __init__(
         operator_support: OperatorSupport,
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size=MIN_BLOCK_SIZE,
     ) -> None:
         super().__init__(
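To make the effect of this default concrete, here is a rough illustration, assuming only the MaxPool1d substitution from later in this PR is registered: each registered replacement op is pre-seeded as an allowed single-node partition target.

    # Illustrative only; the exact contents depend on which
    # module_substitutions files have been imported.
    from torch_tensorrt.dynamo.backend.lowering import (
        MODULE_SUBSTITUTION_REGISTRY,
        DEFAULT_SINGLE_NODE_PARTITIONS,
    )

    print(list(MODULE_SUBSTITUTION_REGISTRY.keys()))
    # e.g. [<class 'torch.nn.modules.pooling.MaxPool1d'>]
    print(DEFAULT_SINGLE_NODE_PARTITIONS)
    # e.g. {'torch.ops.tensorrt.maxpool1d'}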
lowering/_pre_aot_lowering.py (new file)
@@ -0,0 +1,119 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict
import torch
import logging


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ModuleReplacement:
    """Class to store key functionality for module replacement"""

    # torch.ops.___ name for replacement function for module
    new_operator: torch._ops.OpOverload

    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
    # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
    subgraph_insertion_fn: Callable[
        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
    ]


# Dictionary mapping module to ModuleReplacement instance
MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()


def module_substitution(
    module_to_replace: torch.nn.Module,
    new_operator: torch._ops.OpOverload,
    enabled: bool = True,
) -> Callable[[Any], Any]:
    """Decorator to register subgraph insertion functions

    Args:
        module_to_replace: nn.Module to replace
        new_operator: Custom torch operator to replace with
        enabled: Whether the substitution is enabled or disabled
    Returns:
        A decorator which registers the subgraph insertion function for the module
    """

    def register_substitution(subgraph_insertion_fn):
        """Function for use if substitution is enabled"""
        module_replacement = ModuleReplacement(
            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
        )
        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
        return subgraph_insertion_fn

    def disable_substitution(subgraph_insertion_fn):
        """Function for use if substitution is disabled"""
        return subgraph_insertion_fn

    return register_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
    """Perform module-level graph replacement prior to AOT tracing

    Args:
        gm: FX GraphModule to perform module replacement on
    Returns:
        torch.fx.GraphModule
    """
    # Ensure all parameters are in inference mode
    for param in gm.parameters():
        param.requires_grad = False

    # Iterate over graph nodes, extracting module calls, to check for interceptions
    for n in gm.graph.nodes:
        if n.op == "call_module":
            # Extract submodule from graph
            submodule = gm.get_submodule(n.target)

            # If submodule is a member of the substitution registry, replace it
            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:
                try:
                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
                    op, insertion_fn = (
                        replacement.new_operator,
                        replacement.subgraph_insertion_fn,
                    )
                    logger.debug(
                        f"Replacing module of type {type(submodule)} with {op}"
                    )

                    # Insert new node prior to older node
                    with gm.graph.inserting_before(n):
                        new_node = insertion_fn(gm, submodule, n)

                    # If submodule is not a native torch.nn module, it must be manually excluded
                    # from Dynamo tracing
                    if not type(submodule).__module__.startswith("torch.nn"):
                        torch._dynamo.allowed_functions._allowed_function_ids.add(
                            id(type(submodule))
                        )

                    # Replace all original node uses and clean up graph
                    n.replace_all_uses_with(new_node)
                    gm.graph.eliminate_dead_code()
                    gm.recompile()

                # A module replacement can fail in the event that the specific instance
                # of the submodule cannot be replaced
                except Exception:
                    logger.debug(
                        f"Encountered error while replacing {type(submodule)}",
                        exc_info=True,
                    )
                    continue

    # Perform cleanup and recompilation before returning module
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm

Review comment on module_substitution: This is sick 😄
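A minimal usage sketch of the pass on its own, assuming the MaxPool1d substitution defined later in this PR has been imported (which happens automatically via the lowering package __init__):

    import torch
    from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
        pre_aot_module_replacement,
    )

    model = torch.nn.Sequential(torch.nn.MaxPool1d(2))
    gm = torch.fx.symbolic_trace(model)

    # The call_module node targeting nn.MaxPool1d is rewritten into a
    # call_function node targeting torch.ops.tensorrt.maxpool1d.
    gm = pre_aot_module_replacement(gm)
    print([n.target for n in gm.graph.nodes if n.op == "call_function"])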
lowering/module_substitutions/__init__.py (new file)
@@ -0,0 +1 @@
from .maxpool1d import *
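New substitution files are picked up at import time through this __init__, since the star import runs the module and triggers its registration decorators; a hypothetical second entry might look like (the layer_norm module name is invented for illustration):

    from .maxpool1d import *
    from .layer_norm import *  # hypothetical additional substitution file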
lowering/module_substitutions/maxpool1d.py (new file)

Review thread on this file:
- It would be cool to have a sphinx tutorial on how to do this from an external user perspective. Could be as easy as removing maxpool1d from the registry and then walking through all the parts.
- Agreed. I will add this, in addition to documentation on how to ensure all the relevant code is registered ( …

@@ -0,0 +1,75 @@
from typing import Dict, Tuple

import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


@custom_op(
    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
    ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
    # Defines operator schema, name, namespace, and function header
    ...


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines converter replacing the default operator for this function
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": False if len(args) < 6 else args[5],
    }

    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, kwargs_new, name
    )


@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node

Review thread on aten_ops_maxpool1d (converter placement):
- For module substitution "in library", do we want to put the converter here? Or do we want to put the converter in the registry with the rest?
- For external users they'd probably put it here.
- I thought it could be cleaner to have the converter implementation here, so all of the code relating to that module and its replacement is centralized. The requirement, however, is that for every new module replacement file, the user will have to add …
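A rough end-to-end sketch of how this substitution gets exercised, mirroring the test below and assuming a CUDA-capable environment with TensorRT available:

    import torch
    from torch_tensorrt.dynamo import compile

    model = torch.nn.Sequential(torch.nn.MaxPool1d(kernel_size=2))
    inputs = [torch.rand(9, 16, 2).cuda()]

    # nn.MaxPool1d is swapped for torch.ops.tensorrt.maxpool1d before AOT tracing,
    # then lowered through the converter registered above.
    fx_graph = torch.fx.symbolic_trace(model)
    trt_model = compile(fx_graph, inputs, min_block_size=1)
    print(trt_model(*inputs).shape)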
New test file (TestMaxPool1D)
@@ -0,0 +1,55 @@
import torch
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo import compile


class TestMaxPool1D(TestCase):
    def test_pre_aot_lowering_maxpool1d(self):
        class MaxPool1D(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.maxpool = torch.nn.MaxPool1d(2)

            def forward(self, x):
                return self.maxpool(x)

        # Operations expected to be included in the traced graph after decompositions
        expected_ops = {torch.ops.tensorrt.maxpool1d.default}

        inputs = [
            torch.rand(
                9,
                16,
                2,
            ).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(MaxPool1D())
        _, expected_ops_unseen = lower_graph_testing(
            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = compile(
            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
        self.assertAlmostEqual(
            max_diff,
            0,
            msg="Maxpool1d TRT outputs don't match with the original model.",
        )


if __name__ == "__main__":
    run_tests()
Testing backend utilities (fx_dynamo_testing_backend)
@@ -8,6 +8,9 @@
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
 )
+from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+    pre_aot_module_replacement,
+)

 from torch._dynamo.backends.common import fake_tensor_unsupported


@@ -31,6 +34,8 @@ def fx_dynamo_testing_backend(
         torch_executed_ops=torch_executed_ops,
     )

+    gm = pre_aot_module_replacement(gm)
+
     # Invoke AOTAutograd to translate operators to aten
     return aot_module_simplified(
         gm,

Review thread on the testing backend:
- Why is there a separate testing backend? Would we need to continue to make changes to this in step with the actual one?
- The main reason for the separate testing backend is the argument …
- As changes are made to the main backend, yes, those changes would need to be reflected here, and in the …