Skip to content

Commit 5b702be

Browse files
authored
[mlir][math] Convert math.fpowi to math.powf when the power is not constant (#87472)
Convert math.fpowi to math.powf when the power operand is not a constant, by casting the integer power operand to the floating-point type of the base (arith.sitofp).
1 parent 17642c7 commit 5b702be

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,20 +216,30 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
216216
// Convert `math.fpowi` to a series of `arith.mulf` operations.
217217
// If the power is negative, we divide one by the result.
218218
// If both the base and power are zero, the result is 1.
219-
static LogicalResult convertFPowICstOp(math::FPowIOp op,
220-
PatternRewriter &rewriter) {
219+
// If the power cannot be matched as a constant integer, we convert the operation to `math.powf`.
220+
static LogicalResult convertFPowIOp(math::FPowIOp op,
221+
PatternRewriter &rewriter) {
221222
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
222223
Value base = op.getOperand(0);
223224
Value power = op.getOperand(1);
224225
Type baseType = base.getType();
225226

227+
auto convertFPowItoPowf = [&]() -> LogicalResult {
228+
Value castPowerToFp =
229+
rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
230+
Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
231+
castPowerToFp);
232+
rewriter.replaceOp(op, res);
233+
return success();
234+
};
235+
226236
Attribute cstAttr;
227237
if (!matchPattern(power, m_Constant(&cstAttr)))
228-
return failure();
238+
return convertFPowItoPowf();
229239

230240
APInt value;
231241
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
232-
return failure();
242+
return convertFPowItoPowf();
233243

234244
int64_t powerInt = value.getSExtValue();
235245
bool isNegative = powerInt < 0;
@@ -591,7 +601,7 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
591601
}
592602

593603
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
594-
patterns.add(convertFPowICstOp);
604+
patterns.add(convertFPowIOp);
595605
}
596606

597607
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,51 @@ func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 {
610610
// CHECK: return %[[RET]] : f32
611611

612612
// -----
613+
614+
// CHECK-LABEL: func.func @math_fpowi_to_powf_tensor
615+
func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> tensor<8xf32> {
616+
%2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi32>
617+
return %2 : tensor<8xf32>
618+
}
619+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> {
620+
// CHECK: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32>
621+
// CHECK: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32>
622+
// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
623+
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32>
624+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32>
625+
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32>
626+
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32>
627+
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32>
628+
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32>
629+
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32>
630+
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32>
631+
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32>
632+
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32>
633+
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1>
634+
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32>
635+
// CHECK: return %[[SEL]] : tensor<8xf32>
636+
637+
// -----
638+
639+
// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar
640+
func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 {
641+
%2 = math.fpowi %0, %1 : f32, i64
642+
return %2 : f32
643+
}
644+
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 {
645+
// CHECK: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32
646+
// CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
647+
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
648+
// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32
649+
// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32
650+
// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32
651+
// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32
652+
// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32
653+
// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32
654+
// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32
655+
// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32
656+
// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32
657+
// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32
658+
// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1
659+
// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32
660+
// CHECK: return %[[SEL]] : f32

0 commit comments

Comments
 (0)