Skip to content

[WIP][OpenMP][MLIR] Lowering task_reduction clause for pass-by-value vars to LLVMIR #125218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,11 @@ class OpenMP_TaskReductionClauseSkip<
unsigned numTaskReductionBlockArgs() {
return getTaskReductionVars().size();
}

/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return getReductionVars().size(); }

auto getReductionSyms() { return getTaskReductionSyms(); }
}];

let description = [{
Expand Down
233 changes: 223 additions & 10 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
result = todo("reduction with modifier");
};
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
op.getTaskReductionSyms())
result = todo("task_reduction");
};
auto checkUntied = [&todo](auto op, LogicalResult &result) {
if (op.getUntied())
result = todo("untied");
Expand All @@ -276,10 +271,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkInReduction(op, result);
})
.Case([&](omp::TaskgroupOp op) {
checkAllocate(op, result);
checkTaskReduction(op, result);
})
.Case([&](omp::TaskgroupOp op) { checkAllocate(op, result); })
.Case([&](omp::TaskwaitOp op) {
checkDepend(op, result);
checkNowait(op, result);
Expand Down Expand Up @@ -1817,6 +1809,212 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return success();
}

template <typename OP>
llvm::Value *createTaskReductionFunction(
llvm::IRBuilderBase &builder, const std::string &name, llvm::Type *redTy,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, Region &region,
OP &op, unsigned Cnt, llvm::ArrayRef<bool> &isByRef,
SmallVectorImpl<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap) {

llvm::LLVMContext &Context = builder.getContext();
// TODO: by-ref reduction variables are yet to be handled.
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0);
if (region.empty() && name == "red_fini")
// Finalization is optional for reductions.
return llvm::Constant::getNullValue(OpaquePtrTy);
llvm::FunctionType *funcType =
llvm::FunctionType::get(OpaquePtrTy, {OpaquePtrTy, OpaquePtrTy}, false);
llvm::Function *function =
llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, name,
builder.GetInsertBlock()->getModule());
function->setDoesNotRecurse();
llvm::BasicBlock *entry =
llvm::BasicBlock::Create(Context, "entry", function);
llvm::IRBuilder<> bbBuilder(entry);

llvm::Value *arg0 = function->getArg(0);
llvm::Value *arg1 = function->getArg(1);

if (name == "red_init") {
function->addParamAttr(0, llvm::Attribute::NoAlias);
function->addParamAttr(1, llvm::Attribute::NoAlias);
mapInitializationArgs(op, moduleTranslation, reductionDecls,
reductionVariableMap, Cnt);
} else if (name == "red_comb") {
llvm::Value *arg0L = bbBuilder.CreateLoad(redTy, arg0);
llvm::Value *arg1L = bbBuilder.CreateLoad(redTy, arg1);
moduleTranslation.mapValue(region.front().getArgument(0), arg0L);
moduleTranslation.mapValue(region.front().getArgument(1), arg1L);
}
if (region.empty() || isByRef[Cnt]) {
// Emit en empty function body in case of empty region or pass-by-reference
// vars
// TODO: Add support for translating pass-by-reference vars.
bbBuilder.CreateRet(arg0); // Return from the function
return function;
}

SmallVector<llvm::Value *, 1> phis;
if (failed(inlineConvertOmpRegions(region, "", bbBuilder, moduleTranslation,
&phis)))
return nullptr;
assert(
phis.size() == 1 &&
"expected one value to be yielded from the reduction declaration region");
bbBuilder.CreateStore(phis[0], arg0);
bbBuilder.CreateRet(arg0); // Return from the function
return function;
}

void emitTaskRedInitCall(
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
const llvm::OpenMPIRBuilder::LocationDescription &ompLoc, int arraySize,
llvm::Value *ArrayAlloca) {

llvm::LLVMContext &Context = builder.getContext();
uint32_t SrcLocStrSize;
llvm::Constant *SrcLocStr =
moduleTranslation.getOpenMPBuilder()->getOrCreateSrcLocStr(ompLoc,
SrcLocStrSize);
llvm::Value *Ident = moduleTranslation.getOpenMPBuilder()->getOrCreateIdent(
SrcLocStr, SrcLocStrSize);
llvm::Value *ThreadID =
moduleTranslation.getOpenMPBuilder()->getOrCreateThreadID(Ident);
llvm::Constant *ConstInt =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), arraySize);

llvm::Function *TaskRedInitFn =
moduleTranslation.getOpenMPBuilder()->getOrCreateRuntimeFunctionPtr(
llvm::omp::OMPRTL___kmpc_taskred_init);
builder.CreateCall(TaskRedInitFn, {ThreadID, ConstInt, ArrayAlloca});
}

template <typename OP>
static LogicalResult allocAndInitializeTaskReductionVars(
OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
SmallVectorImpl<llvm::Value *> &privateReductionVariables,
DenseMap<Value, llvm::Value *> &reductionVariableMap,
llvm::ArrayRef<bool> isByRef) {

if (op.getNumReductionVars() == 0)
return success();

llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::LLVMContext &Context = builder.getContext();
SmallVector<DeferredStore> deferredStores;

// Save the current insertion point
auto oldIP = builder.saveIP();

// Set insertion point after the allocations
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());

// Define the kmp_taskred_input_t structure
llvm::StructType *kmp_taskred_input_t =
llvm::StructType::create(Context, "kmp_taskred_input_t");
llvm::Type *OpaquePtrTy = llvm::PointerType::get(Context, 0); // void*
llvm::Type *SizeTy = builder.getInt64Ty(); // size_t (assumed to be i64)
llvm::Type *FlagsTy = llvm::Type::getInt32Ty(Context); // flags (i32)

// Structure members
std::vector<llvm::Type *> structMembers = {
OpaquePtrTy, // reduce_shar (void*)
OpaquePtrTy, // reduce_orig (void*)
SizeTy, // reduce_size (size_t)
OpaquePtrTy, // reduce_init (void*)
OpaquePtrTy, // reduce_fini (void*)
OpaquePtrTy, // reduce_comb (void*)
FlagsTy // flags (i32)
};

kmp_taskred_input_t->setBody(structMembers);
int arraySize = op.getNumReductionVars();
llvm::ArrayType *ArrayTy =
llvm::ArrayType::get(kmp_taskred_input_t, arraySize);

// Allocate the array for kmp_taskred_input_t
llvm::AllocaInst *ArrayAlloca =
builder.CreateAlloca(ArrayTy, nullptr, "kmp_taskred_array");

// Restore the insertion point
builder.restoreIP(oldIP);
llvm::DataLayout DL = builder.GetInsertBlock()->getModule()->getDataLayout();

for (int Cnt = 0; Cnt < arraySize; ++Cnt) {
llvm::Value *shared =
moduleTranslation.lookupValue(op.getReductionVars()[Cnt]);
// Create a GEP to access the reduction element
llvm::Value *StructPtr = builder.CreateGEP(
ArrayTy, ArrayAlloca, {builder.getInt32(0), builder.getInt32(Cnt)},
"red_element");

llvm::Value *FieldPtrReduceShar = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 0, "reduce_shar");
builder.CreateStore(shared, FieldPtrReduceShar);

llvm::Value *FieldPtrReduceOrig = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 1, "reduce_orig");
builder.CreateStore(shared, FieldPtrReduceOrig);

// Store size of the reduction variable
llvm::Value *FieldPtrReduceSize = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 2, "reduce_size");
llvm::Type *redTy;
if (auto *alloca = dyn_cast<llvm::AllocaInst>(shared)) {
redTy = alloca->getAllocatedType();
uint64_t sizeInBytes = DL.getTypeAllocSize(redTy);

llvm::ConstantInt *sizeConst =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), sizeInBytes);
builder.CreateStore(sizeConst, FieldPtrReduceSize);
} else {
llvm_unreachable("Non alloca instruction found.");
}

// Initialize reduction variable
llvm::Value *FieldPtrReduceInit = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 3, "reduce_init");
llvm::Value *initFunction = createTaskReductionFunction(
builder, "red_init", redTy, moduleTranslation, reductionDecls,
reductionDecls[Cnt].getInitializerRegion(), op, Cnt, isByRef,
privateReductionVariables, reductionVariableMap);
builder.CreateStore(initFunction, FieldPtrReduceInit);

// Create finish and combine functions
llvm::Value *FieldPtrReduceFini = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 4, "reduce_fini");
llvm::Value *finiFunction = createTaskReductionFunction(
builder, "red_fini", redTy, moduleTranslation, reductionDecls,
reductionDecls[Cnt].getCleanupRegion(), op, Cnt, isByRef,
privateReductionVariables, reductionVariableMap);
builder.CreateStore(finiFunction, FieldPtrReduceFini);

llvm::Value *FieldPtrReduceComb = builder.CreateStructGEP(
kmp_taskred_input_t, StructPtr, 5, "reduce_comb");
llvm::Value *combFunction = createTaskReductionFunction(
builder, "red_comb", redTy, moduleTranslation, reductionDecls,
reductionDecls[Cnt].getReductionRegion(), op, Cnt, isByRef,
privateReductionVariables, reductionVariableMap);
builder.CreateStore(combFunction, FieldPtrReduceComb);

llvm::Value *FieldPtrFlags =
builder.CreateStructGEP(kmp_taskred_input_t, StructPtr, 6, "flags");
llvm::ConstantInt *flagVal =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(Context), 0);
builder.CreateStore(flagVal, FieldPtrFlags);
}

// Emit the runtime call
emitTaskRedInitCall(builder, moduleTranslation, ompLoc, arraySize,
ArrayAlloca);
return success();
}

/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
Expand All @@ -1825,8 +2023,23 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*tgOp)))
return failure();

llvm::ArrayRef<bool> isByRef = getIsByRef(tgOp.getTaskReductionByref());
assert(isByRef.size() == tgOp.getNumReductionVars());
SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(tgOp, reductionDecls);
SmallVector<llvm::Value *> privateReductionVariables(
tgOp.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;
MutableArrayRef<BlockArgument> reductionArgs =
tgOp.getRegion().getArguments();
LogicalResult bodyGenStatus = success();
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
builder.restoreIP(codegenIP);
if (failed(allocAndInitializeTaskReductionVars(
tgOp, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
bodyGenStatus = failure();
return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
builder, moduleTranslation)
.takeError();
Expand All @@ -1842,7 +2055,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
return failure();

builder.restoreIP(*afterIP);
return success();
return bodyGenStatus;
}

static LogicalResult
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-task-reduction.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

omp.declare_reduction @add_reduction_i32 : i32 init {
^bb0(%arg0: i32):
%0 = llvm.mlir.constant(0 : i32) : i32
omp.yield(%0 : i32)
} combiner {
^bb0(%arg0: i32, %arg1: i32):
%0 = llvm.add %arg0, %arg1 : i32
omp.yield(%0 : i32)
}
llvm.func @_QPtest_task_reduciton() {
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
omp.taskgroup task_reduction(@add_reduction_i32 %1 -> %arg0 : !llvm.ptr) {
%2 = llvm.load %1 : !llvm.ptr -> i32
%3 = llvm.mlir.constant(1 : i32) : i32
%4 = llvm.add %2, %3 : i32
llvm.store %4, %1 : i32, !llvm.ptr
omp.terminator
}
llvm.return
}

//CHECK-LABEL: define void @_QPtest_task_reduciton() {
//CHECK: %[[VAL1:.*]] = alloca i32, i64 1, align 4
//CHECK: %[[RED_ARRY:.*]] = alloca [1 x %kmp_taskred_input_t], align 8
//CHECK: br label %entry

//CHECK: entry:
//CHECK: %[[TID:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
//CHECK: call void @__kmpc_taskgroup(ptr @1, i32 %[[TID]])
//CHECK: %[[RED_ELEMENT:.*]] = getelementptr [1 x %kmp_taskred_input_t], ptr %[[RED_ARRY]], i32 0, i32 0
//CHECK: %[[RED_SHARED:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 0
//CHECK: store ptr %[[VAL1]], ptr %[[RED_SHARED]], align 8
//CHECK: %[[RED_ORIG:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 1
//CHECK: store ptr %[[VAL1]], ptr %[[RED_ORIG]], align 8
//CHECK: %[[RED_SIZE:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 2
//CHECK: store i64 4, ptr %[[RED_SIZE]], align 4
//CHECK: %[[RED_INIT:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 3
//CHECK: store ptr @red_init, ptr %[[RED_INIT]], align 8
//CHECK: %[[RED_FINI:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 4
//CHECK: store ptr null, ptr %[[RED_FINI]], align 8
//CHECK: %[[RED_COMB:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 5
//CHECK: store ptr @red_comb, ptr %[[RED_COMB]], align 8
//CHECK: %[[FLAGS:.*]] = getelementptr inbounds nuw %kmp_taskred_input_t, ptr %[[RED_ELEMENT]], i32 0, i32 6
//CHECK: store i64 0, ptr %[[FLAGS]], align 4
//CHECK: %[[TID1:.*]] = call i32 @__kmpc_global_thread_num(ptr @{{.*}})
//CHECK: %2 = call ptr @__kmpc_taskred_init(i32 %[[TID1]], i32 1, ptr %[[RED_ARRY]])
//CHECK: br label %omp.taskgroup.region

//CHECK: omp.taskgroup.region:
//CHECK: %[[VAL3:.*]] = load i32, ptr %[[VAL1]], align 4
//CHECK: %4 = add i32 %[[VAL3]], 1
//CHECK: store i32 %4, ptr %[[VAL1]], align 4
//CHECK: br label %omp.region.cont

//CHECK: omp.region.cont:
//CHECK: br label %taskgroup.exit

//CHECK: taskgroup.exit:
//CHECK: call void @__kmpc_end_taskgroup(ptr @{{.+}}, i32 %[[TID]])
//CHECK: ret void
//CHECK: }

//CHECK-LABEL: define ptr @red_init(ptr noalias %0, ptr noalias %1) #2 {
//CHECK: entry:
//CHECK: store i32 0, ptr %0, align 4
//CHECK: ret ptr %0
//CHECK: }

//CHECK-LABEL: define ptr @red_comb(ptr %0, ptr %1) #2 {
//CHECK: entry:
//CHECK: %[[LD0:.*]] = load i32, ptr %0, align 4
//CHECK: %[[LD1:.*]] = load i32, ptr %1, align 4
//CHECK: %[[RES:.*]] = add i32 %[[LD0]], %[[LD1]]
//CHECK: store i32 %[[RES]], ptr %0, align 4
//CHECK: ret ptr %0
//CHECK: }
28 changes: 0 additions & 28 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -442,34 +442,6 @@ llvm.func @taskgroup_allocate(%x : !llvm.ptr) {

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
%0 = llvm.mlir.constant(0.0 : f32) : f32
omp.yield (%0 : f32)
}
combiner {
^bb1(%arg0: f32, %arg1: f32):
%1 = llvm.fadd %arg0, %arg1 : f32
omp.yield (%1 : f32)
}
atomic {
^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
%2 = llvm.load %arg3 : !llvm.ptr -> f32
llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
omp.yield
}
llvm.func @taskgroup_task_reduction(%x : !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause task_reduction in omp.taskgroup operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
omp.taskgroup task_reduction(@add_f32 %x -> %prv : !llvm.ptr) {
omp.terminator
}
llvm.return
}

// -----

llvm.func @taskloop(%lb : i32, %ub : i32, %step : i32) {
// expected-error@below {{not yet implemented: omp.taskloop}}
// expected-error@below {{LLVM Translation failed for operation: omp.taskloop}}
Expand Down