Skip to content

[OpenMP][MLIR] Extend record member map support for omp dialect to LLVM-IR #82852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 191 additions & 66 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
#include "llvm/Transforms/Utils/ModuleUtils.h"

#include <any>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <utility>

Expand Down Expand Up @@ -2037,7 +2039,7 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
bounds.getDefiningOp())) {
// The below calculation for the size to be mapped calculated from the
// map_info's bounds is: (elemCount * [UB - LB] + 1), later we
// map.info's bounds is: (elemCount * [UB - LB] + 1), later we
// multiply by the underlying element types byte size to get the full
// size to be offloaded based on the bounds
elementCount = builder.CreateMul(
Expand Down Expand Up @@ -2089,9 +2091,9 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,

mapData.BaseType.push_back(
moduleTranslation.convertType(mapOp.getVarType()));
mapData.Sizes.push_back(getSizeInBytes(
dl, mapOp.getVarType(), mapOp, mapData.BasePointers.back(),
mapData.BaseType.back(), builder, moduleTranslation));
mapData.Sizes.push_back(
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
mapData.BaseType.back(), builder, moduleTranslation));
mapData.MapClause.push_back(mapOp.getOperation());
mapData.Types.push_back(
llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
Expand Down Expand Up @@ -2122,6 +2124,67 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
}
}

static int getMapDataMemberIdx(MapInfoData &mapData,
mlir::omp::MapInfoOp memberOp) {
auto *res = llvm::find(mapData.MapClause, memberOp);
assert(res != mapData.MapClause.end() &&
"MapInfoOp for member not found in MapData, cannot return index");
return std::distance(mapData.MapClause.begin(), res);
}

static mlir::omp::MapInfoOp
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();

// Only 1 member has been mapped, we can return it.
if (indexAttr.size() == 1)
if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
mapInfo.getMembers()[0].getDefiningOp()))
return mapOp;

llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
llvm::SmallVector<size_t> indices(shape[0]);
std::iota(indices.begin(), indices.end(), 0);

llvm::sort(
indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
auto indexValues = indexAttr.getValues<int32_t>();
for (int i = 0;
i < shape[1];
++i) {
int aIndex = indexValues[a * shape[1] + i];
int bIndex = indexValues[b * shape[1] + i];

if (aIndex != -1 && bIndex == -1)
return false;

if (aIndex == -1 && bIndex != -1)
return true;

if (aIndex == -1)
return first;

if (bIndex == -1)
return !first;

// A is earlier in the record type layout than B
if (aIndex < bIndex)
return first;

if (bIndex < aIndex)
return !first;
}

// iterated the entire list and couldn't make a decision, all elements
// were likely the same, return true for now similar to reaching the end
// of both and finding invalid indices.
return true;
});

return llvm::cast<mlir::omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
}

/// This function calculates the array/pointer offset for map data provided
/// with bounds operations, e.g. when provided something like the following:
///
Expand Down Expand Up @@ -2227,6 +2290,9 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
// which is utilised in subsequent member mappings (by modifying there map type
// with it) to indicate that a member is part of this parent and should be
// treated by the runtime as such. Important to achieve the correct mapping.
//
// This function borrows a lot from Clang's emitCombinedEntry function
// inside of CGOpenMPRuntime.cpp
static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
Expand All @@ -2242,7 +2308,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);

// Calculate size of the parent object being mapped based on the
// addresses at runtime, highAddr - lowAddr = size. This of course
Expand All @@ -2251,42 +2316,68 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// Fortran pointers and allocatables, the mapping of the pointed to
// data by the descriptor (which itself, is a structure containing
// runtime information on the dynamically allocated data).
llvm::Value *lowAddr = builder.CreatePointerCast(
mapData.Pointers[mapDataIndex], builder.getPtrTy());
llvm::Value *highAddr = builder.CreatePointerCast(
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
mapData.Pointers[mapDataIndex], 1),
builder.getPtrTy());
auto parentClause =
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

llvm::Value *lowAddr, *highAddr;
if (!parentClause.getPartialMap()) {
lowAddr = builder.CreatePointerCast(mapData.Pointers[mapDataIndex],
builder.getPtrTy());
highAddr = builder.CreatePointerCast(
builder.CreateConstGEP1_32(mapData.BaseType[mapDataIndex],
mapData.Pointers[mapDataIndex], 1),
builder.getPtrTy());
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
} else {
auto mapOp =
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
int firstMemberIdx = getMapDataMemberIdx(
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
builder.getPtrTy());
int lastMemberIdx = getMapDataMemberIdx(
mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
highAddr = builder.CreatePointerCast(
builder.CreateGEP(mapData.BaseType[lastMemberIdx],
mapData.Pointers[lastMemberIdx], builder.getInt64(1)),
builder.getPtrTy());
combinedInfo.Pointers.emplace_back(mapData.Pointers[firstMemberIdx]);
}

llvm::Value *size = builder.CreateIntCast(
builder.CreatePtrDiff(builder.getInt8Ty(), highAddr, lowAddr),
builder.getInt64Ty(),
/*isSigned=*/false);
combinedInfo.Sizes.push_back(size);

// This creates the initial MEMBER_OF mapping that consists of
// the parent/top level container (same as above effectively, except
// with a fixed initial compile time size and seperate maptype which
// indicates the true mape type (tofrom etc.) and that it is a part
// of a larger mapping and indicating the link between it and it's
// members that are also explicitly mapped).
// TODO: This will need to be expanded to include the whole host of logic for
// the map flags that Clang currently supports (e.g. it should take the map
// flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
// further case specific flag modifications). For the moment, it handles what
// we support as expected.
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
if (isTargetParams)
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);

combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);

// This creates the initial MEMBER_OF mapping that consists of
// the parent/top level container (same as above effectively, except
// with a fixed initial compile time size and seperate maptype which
// indicates the true mape type (tofrom etc.). This parent mapping is
// only relevant if the structure in its totality is being mapped,
// otherwise the above suffices.
if (!parentClause.getPartialMap()) {
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.DevicePointers.emplace_back(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
combinedInfo.Names.emplace_back(LLVM::createMappingInformation(
mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder));
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIndex]);
}
return memberOfFlag;
}

Expand Down Expand Up @@ -2319,21 +2410,17 @@ static void processMapMembersWithParent(
uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {

auto parentClause =
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

for (auto mappedMembers : parentClause.getMembers()) {
auto memberClause =
mlir::dyn_cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
int memberDataIdx = -1;
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
if (mapData.MapClause[i] == memberClause)
memberDataIdx = i;
}
llvm::cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);

assert(memberDataIdx >= 0 && "could not find mapped member of structure");

// Same MemberOfFlag to indicate its link with parent and other members
// of, and we flag that it's part of a pointer and object coupling.
// of.
auto mapFlag =
llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType().value());
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
Expand All @@ -2347,18 +2434,81 @@ static void processMapMembersWithParent(
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
combinedInfo.Names.emplace_back(
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));

combinedInfo.BasePointers.emplace_back(mapData.BasePointers[memberDataIdx]);
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
}
}

static void
processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo,
bool isTargetParams, int mapDataParentIdx = -1) {
// Declare Target Mappings are excluded from being marked as
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
// marked with OMP_MAP_PTR_AND_OBJ instead.
auto mapFlag = mapData.Types[mapDataIdx];
auto mapInfoOp =
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);

bool isPtrTy = checkIfPointerMap(mapInfoOp);
if (isPtrTy)
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

if (mapInfoOp.getMapCaptureType().value() ==
mlir::omp::VariableCaptureKind::ByCopy &&
!isPtrTy)
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

// if we're provided a mapDataParentIdx, then the data being mapped is
// part of a larger object (in a parent <-> member mapping) and in this
// case our BasePointer should be the parent.
if (mapDataParentIdx >= 0)
combinedInfo.BasePointers.emplace_back(
mapData.BasePointers[mapDataParentIdx]);
else
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIdx]);

combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]);
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]);
combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]);
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]);
}

static void processMapWithMembersOf(
LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl,
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
uint64_t mapDataIndex, bool isTargetParams) {
auto parentClause =
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);

// If we have a partial map (no parent referenced in the map clauses of the
// directive, only members) and only a single member, we do not need to bind
// the map of the member to the parent, we can pass the member seperately.
if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
parentClause.getMembers()[0].getDefiningOp());
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
// Note: Clang treats arrays with explicit bounds that fall into this
// category as a parent with map case, however, it seems this isn't a
// requirement, and processing them as an individual map is fine. So,
// we will handle them as individual maps for the moment, as it's
// difficult for us to check this as we always require bounds to be
// specified currently and it's also marginally more optimal (single
// map rather than two). The difference may come from the fact that
// Clang maps array without bounds as pointers (which we do not
// currently do), whereas we treat them as arrays in all cases
// currently.
processIndividualMap(mapData, memberDataIdx, combinedInfo, isTargetParams,
mapDataIndex);
return;
}

llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
combinedInfo, mapData, mapDataIndex, isTargetParams);
Expand Down Expand Up @@ -2477,12 +2627,8 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
// utilise the size from any component of MapInfoData, if we can't
// something is missing from the initial MapInfoData construction.
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
// NOTE/TODO: We currently do not handle member mapping seperately from it's
// parent or explicit mapping of a parent and member in the same operation,
// this will need to change in the near future, for now we primarily handle
// descriptor mapping from fortran, generalised as mapping record types
// with implicit member maps. This lowering needs further generalisation to
// fully support fortran derived types, and C/C++ structures and classes.
// NOTE/TODO: We currently do not support arbitrary depth record
// type mapping.
if (mapData.IsAMember[i])
continue;

Expand All @@ -2493,28 +2639,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
continue;
}

auto mapFlag = mapData.Types[i];
bool isPtrTy = checkIfPointerMap(mapInfoOp);
if (isPtrTy)
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;

// Declare Target Mappings are excluded from being marked as
// OMP_MAP_TARGET_PARAM as they are not passed as parameters.
if (isTargetParams && !mapData.IsDeclareTarget[i])
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;

if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]))
if (mapInfoOp.getMapCaptureType().value() ==
mlir::omp::VariableCaptureKind::ByCopy &&
!isPtrTy)
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;

combinedInfo.BasePointers.emplace_back(mapData.BasePointers[i]);
combinedInfo.Pointers.emplace_back(mapData.Pointers[i]);
combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[i]);
combinedInfo.Names.emplace_back(mapData.Names[i]);
combinedInfo.Types.emplace_back(mapFlag);
combinedInfo.Sizes.emplace_back(mapData.Sizes[i]);
processIndividualMap(mapData, i, combinedInfo, isTargetParams);
}

auto findMapInfo = [&combinedInfo](llvm::Value *val, unsigned &index) {
Expand Down
Loading