Skip to content

[MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass #98653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- MathToROCDL.h - Utils to convert from the complex dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"

/// Populate the given list with patterns that convert from Math to ROCDL calls.
void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,23 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//

def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pass doesn't need to operate on a ModuleOp specifically - you probably want this to work on anything that has a symbol table

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to have a DataLayoutOpInterface as well. Should this be checked dynamically, so that the pass can be used on any operation but should only be triggered on ops with a symbol table and data layout interface?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, yeah, good point

(I'm basically trying to make sure this works on gpu.module and builtin.module)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking about making it work for both as well, but opted to use the "add patterns" function in the GPUToROCDL conversion since the options and converter are more complicated, and I didn't want to put that in this simpler pass, so it seemed okay for it to operate on the ModuleOp. I can try to see if I can make it work with both if you need it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Good point, this pass is for testing anyway, let's keep it at ModuleOp

let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.
}];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
}

//===----------------------------------------------------------------------===//
// MathToSPIRV
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
MLIRArithToLLVM
MLIRArithTransforms
MLIRMathToLLVM
MLIRMathToROCDL
MLIRAMDGPUToROCDL
MLIRFuncToLLVM
MLIRGPUDialect
Expand Down
46 changes: 2 additions & 44 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
Expand Down Expand Up @@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(

patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
"__ocml_fabs_f64");
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
"__ocml_atan_f64");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
"__ocml_atan2_f64");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
"__ocml_cbrt_f64");
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
"__ocml_ceil_f64");
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
"__ocml_cos_f64");
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
"__ocml_exp_f64");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
"__ocml_log_f64");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
"__ocml_log1p_f64");
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
"__ocml_log2_f64");
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
"__ocml_pow_f64");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
"__ocml_rsqrt_f64");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
"__ocml_sqrt_f64");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
"__ocml_tan_f64");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
"__ocml_erf_f64");
populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
add_mlir_conversion_library(MLIRMathToROCDL
MathToROCDL.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
MLIRMathDialect
MLIRLLVMCommonConversion
MLIRPass
MLIRTransformUtils
MLIRVectorDialect
MLIRVectorUtils
)
146 changes: 146 additions & 0 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define DEBUG_TYPE "math-to-rocdl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
// Handled by mathToLLVM: math::CgPopOp
// Handled by mathToLLVM: math::FmaOp
// FIXME: math::IPowIOp
// FIXME: math::FPowIOp
// Handled by mathToLLVM: math::RoundEvenOp
// Handled by mathToLLVM: math::RoundOp
// Handled by mathToLLVM: math::TruncOp
populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
"__ocml_fabs_f64");
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
"__ocml_acos_f64");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
"__ocml_acosh_f64");
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
"__ocml_asin_f64");
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
"__ocml_asinh_f64");
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
"__ocml_atan_f64");
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
"__ocml_atanh_f64");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
"__ocml_atan2_f64");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
"__ocml_cbrt_f64");
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
"__ocml_ceil_f64");
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
"__ocml_cos_f64");
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
"__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
"__ocml_sinh_f64");
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
"__ocml_exp_f64");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're here (tm), the OCML math libraries now have native f16 implementations of these functions - could you add that case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(IIRC "these" is all of them, I just commented here arbitrarily)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... ah, we're using the GPU pattern, might not be worth the bother

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably add those in a separate patch since this is only a refactoring change.

"__ocml_log_f64");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
"__ocml_log1p_f64");
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
"__ocml_log2_f64");
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
"__ocml_pow_f64");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
"__ocml_rsqrt_f64");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
"__ocml_sqrt_f64");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
"__ocml_tan_f64");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
"__ocml_erf_f64");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64");
}

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();

RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateMathToROCDLConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
}
Loading
Loading