Commit b1c7285

fix: add an arg in matmul (#2279)

1 parent 1d115a1

File tree: 3 files changed, +69 −26 lines

- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
- py/torch_tensorrt/dynamo/conversion/impl/matmul.py
- tests/py/dynamo/conversion/test_matmul_aten.py

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (7 additions, 1 deletion)

@@ -171,6 +171,7 @@ def aten_ops_gelu(
 
 @dynamo_tensorrt_converter(torch.ops.aten.matmul)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default)  # type: ignore[misc]
+@dynamo_tensorrt_converter(torch.ops.aten.mv.default)  # type: ignore[misc]
 def aten_ops_matmul(
     network: TRTNetwork,
     target: Target,
@@ -179,7 +180,12 @@ def aten_ops_matmul(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.matmul.matrix_multiply(
-        network, target, SourceIR.ATEN, name, args[0], args[1]
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
     )
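For context: `torch.matmul` with a 2-D matrix and a 1-D vector is a matrix-vector product, and in the traced ATen graph it surfaces as `torch.ops.aten.mv.default` (the updated tests below assert exactly this), which the new third registration routes to the same converter. A minimal illustration of that equivalence, runnable with plain PyTorch (shapes are arbitrary):

import torch

x = torch.randn(3, 4)  # 2-D matrix
v = torch.randn(4)     # 1-D vector

# matmul with a 1-D right operand is a matrix-vector product; in the
# traced ATen graph this appears as torch.ops.aten.mv.default.
out = torch.matmul(x, v)
print(out.shape)  # torch.Size([3])

# Calling the ATen overload directly produces the same result.
assert torch.allclose(out, torch.ops.aten.mv.default(x, v))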

py/torch_tensorrt/dynamo/conversion/impl/matmul.py (4 additions, 4 deletions)

@@ -1,5 +1,6 @@
 from typing import Optional
 
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.fx.converters.converter_utils import (
@@ -10,8 +11,6 @@
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-import tensorrt as trt
-
 
 def matrix_multiply(
     network: TRTNetwork,
@@ -20,6 +19,8 @@ def matrix_multiply(
     name: str,
     input: TRTTensor,
     other: TRTTensor,
+    input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
+    other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE,
 ) -> TRTTensor:
     if not isinstance(input, trt.tensorrt.ITensor):
         input = get_trt_tensor(network, input, f"{name}_input")
@@ -31,7 +32,6 @@
         dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
     )
 
-    input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE
     preset_diff = 0
 
     if len(input.shape) == 1:
@@ -46,5 +46,5 @@
         network, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
     layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
-    set_layer_name(layer, target, name)
+    set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
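This is the arg change the commit title refers to: the hard-coded `input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE` becomes a pair of defaulted keyword arguments, so callers can ask TensorRT's matrix-multiply layer to transform an operand itself. A hypothetical caller, not part of this commit (`matmul_transposed_other` is an illustrative name), that would lower `input @ other.T` without emitting a separate transpose layer:

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.matmul import matrix_multiply
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

def matmul_transposed_other(
    network: TRTNetwork,
    target: Target,
    name: str,
    input: TRTTensor,
    other: TRTTensor,
) -> TRTTensor:
    # Ask the IMatrixMultiplyLayer to transpose `other` in place of an
    # explicit shuffle/transpose layer before the multiply.
    return matrix_multiply(
        network,
        target,
        SourceIR.ATEN,
        name,
        input,
        other,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )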

tests/py/dynamo/conversion/test_matmul_aten.py (58 additions, 21 deletions)

@@ -9,18 +9,44 @@
 class TestMatMulConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2_2", (2, 3), (3, 2)),
-            ("2_2", (2, 3), (3, 1)),
-            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
-            # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
+            (
+                "2_2",
+                (2, 3),
+                (3, 2),
+            ),
+            (
+                "4_6",
+                (4, 5),
+                (5, 6),
+            ),
+            (
+                "2_1",
+                (2, 3),
+                (3, 1),
+            ),
+            (
+                "4_1",
+                (4, 1),
+                (1, 1),
+            ),
+            (
+                "1_2",
+                (1, 3),
+                (3, 2),
+            ),
+            (
+                "1_3",
+                (1, 2),
+                (2, 3),
+            ),
+            # Following cases use torch.ops.aten.bmm.default
             # ("4_3", (3,1,3,2), (2,2,3)),
             # ("3_4", (3,1,3,2), (2,2,3)),
             # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
             # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
-    def test_matmul_other_constant(self, _, input_shape, other_shape):
+    def test_matmul_mm(self, _, input_shape, other_shape):
         class MatMul(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -39,32 +65,43 @@ def forward(self, input):
 
     @parameterized.expand(
         [
-            ("2_2", (2, 3), (3, 2)),
-            ("1_2", (1, 3), (3, 2)),
-            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
-            # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
-            # ("4_3", (3,1,3,2), (2,2,3)),
-            # ("3_4", (3,1,3,2), (2,2,3)),
-            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            # ("4_2", (1, 2, 2, 3), (3, 2)),
+            (
+                "1_1",
+                (1, 1),
+                (1,),
+            ),
+            (
+                "1_1",
+                (1, 2),
+                (2,),
+            ),
+            (
+                "2_1",
+                (2, 1),
+                (1,),
+            ),
+            (
+                "3_1",
+                (3, 4),
+                (4,),
+            ),
         ]
     )
-    def test_matmul_input_constant(self, _, input_shape, other_shape):
+    def test_matmul_mv(self, _, input_shape, other_shape):
         class MatMul(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.input = nn.Parameter(torch.randn(*input_shape))
+                self.other = nn.Parameter(torch.randn(*other_shape))
 
-            def forward(self, other):
-                return torch.matmul(self.input, other)
+            def forward(self, input):
+                return torch.matmul(input, self.other)
 
-        inputs = [torch.randn(*other_shape)]
+        inputs = [torch.randn(*input_shape)]
 
         self.run_test(
             MatMul(),
             inputs,
-            expected_ops={torch.ops.aten.mm.default},
+            expected_ops={torch.ops.aten.mv.default},
         )
 
     @parameterized.expand(
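The renamed tests split coverage by ATen overload: `test_matmul_mm` keeps both operands 2-D so tracing produces `aten.mm.default`, while `test_matmul_mv` holds the 1-D vector as an `nn.Parameter` and feeds the matrix as the runtime input, so the graph contains `aten.mv.default`. A hypothetical end-to-end smoke test of the new path, assuming a CUDA machine and a Torch-TensorRT build with the dynamo frontend enabled:

import torch
import torch_tensorrt

class MatVec(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 1-D parameter so torch.matmul lowers to aten.mv.default
        self.weight = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        return torch.matmul(x, self.weight)

model = MatVec().eval().cuda()
x = torch.randn(3, 4, device="cuda")

# Compile through the dynamo path, which routes aten.mv.default to the
# converter registered in this commit.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
print(trt_model(x).shape)  # torch.Size([3])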
