diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index fad234a9dcae9..abb79278eddd4 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1336,7 +1336,7 @@ structured_op: !LinalgStructuredOpConfig name: C kind: output_tensor type_var: U - shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)> + shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: cast kind: type_fn_attr diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 43410aaa6af1b..59b3ba914eaab 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -429,8 +429,8 @@ def quantized_matmul( @linalg_structured_op def matmul_transpose_a( - A=TensorDef(T1, S.K, S.N), - B=TensorDef(T2, S.K, S.M), + A=TensorDef(T1, S.K, S.M), + B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), cast=TypeFnAttrDef(default=TypeFn.cast_signed), ):