diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 13eb97a910bd4..2f7beed549108 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -273,7 +273,7 @@ def RedundantSelectFalse : Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)), (SelectOp $pred, $a, $c)>; -// select(pred, false, true) => not(pred) +// select(pred, false, true) => not(pred) def SelectI1ToNot : Pat<(SelectOp $pred, (ConstantLikeMatcher ConstantAttr), @@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount : CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == " "*getIntOrSplatIntValue($2)">]>>; +def ValueWidthMatchesShiftAmount : + Constraint, + CPred<"getScalarOrElementWidth($0) == " + "*getIntOrSplatIntValue($1)">]>>; + // trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated def TruncIExtSIToExtSI : Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)), @@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended : (Arith_MulSIExtendedOp:$res__1 $x, $y), [(ValuesWithSameType $tr, $x, $y), (ValueWiderThan $mul, $x), - (TruncationMatchesShiftAmount $mul, $x, $c0)]>; + (TruncationMatchesShiftAmount $mul, $x, $c0), + (ValueWidthMatchesShiftAmount $x, $c0)]>; // trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y) def TruncIShrUIMulIToMulUIExtended : @@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended : (Arith_MulUIExtendedOp:$res__1 $x, $y), [(ValuesWithSameType $tr, $x, $y), (ValueWiderThan $mul, $x), - (TruncationMatchesShiftAmount $mul, $x, $c0)]>; + (TruncationMatchesShiftAmount $mul, $x, $c0), + (ValueWidthMatchesShiftAmount $x, $c0)]>; //===----------------------------------------------------------------------===// // TruncIOp diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index b6188c81ff912..542603722ab8a 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index { // CHECK-LABEL: @foldSubXX_tensor -// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> +// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> // CHECK: %[[sub:.+]] = arith.subi // CHECK: return %[[c0]], %[[sub]] func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor) -> (tensor<10xi32>, tensor) { @@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 { return %hi : i32 } +// Verify that the signed extended multiplication pattern does not match +// if the right shift does not match the bitwidth of the multipliers. + +// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift +// CHECK-NOT: arith.mulsi_extended +func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 { + %x = arith.extsi %a: i32 to i33 + %y = arith.extsi %b: i32 to i33 + %m = arith.muli %x, %y: i33 + %c1 = arith.constant 1: i33 + %sh = arith.shrui %m, %c1 : i33 + %hi = arith.trunci %sh: i33 to i32 + return %hi : i32 +} + // CHECK-LABEL: @wideMulToMulSIExtendedVector // CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) // CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32> @@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 { return %hi : i32 } +// Verify that the unsigned extended multiplication pattern does not match +// if the right shift does not match the bitwidth of the multipliers. + +// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift +// CHECK-NOT: arith.mului_extended +func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 { + %x = arith.extui %a: i32 to i33 + %y = arith.extui %b: i32 to i33 + %m = arith.muli %x, %y: i33 + %c1 = arith.constant 1: i33 + %sh = arith.shrui %m, %c1 : i33 + %hi = arith.trunci %sh: i33 to i32 + return %hi : i32 +} + // CHECK-LABEL: @wideMulToMulUIExtendedVector // CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) // CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>