Skip to content

[flang] add FIR to FIR pass to lower assumed-rank operations #93344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::SymbolTable *symbolTable = nullptr)
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
symbolTable{symbolTable} {}
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap,
mlir::SymbolTable *symbolTable = nullptr)
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)},
symbolTable{symbolTable} {
setListener(this);
}
explicit FirOpBuilder(mlir::OpBuilder &builder, mlir::ModuleOp mod)
Expand Down
6 changes: 6 additions & 0 deletions flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ constexpr TypeBuilderFunc getModel<signed char>() {
};
}
template <>
constexpr TypeBuilderFunc getModel<unsigned char>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(context, 8 * sizeof(unsigned char));
};
}
template <>
constexpr TypeBuilderFunc getModel<void *>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::LLVMPointerType::get(context,
Expand Down
31 changes: 31 additions & 0 deletions flang/include/flang/Optimizer/Builder/Runtime/Support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===-- Support.h - generate support runtime API calls ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H
#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H

namespace mlir {
class Value;
class Location;
} // namespace mlir

namespace fir {
class FirOpBuilder;
}

namespace fir::runtime {

/// Generate call to `CopyAndUpdateDescriptor` runtime routine.
void genCopyAndUpdateDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value to, mlir::Value from,
mlir::Value newDynamicType,
mlir::Value newAttribute,
mlir::Value newLowerBounds);

} // namespace fir::runtime
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class BaseBoxType : public mlir::Type {
/// Return the same type, except for the shape, that is taken the shape
/// of shapeMold.
BaseBoxType getBoxTypeWithNewShape(mlir::Type shapeMold) const;
BaseBoxType getBoxTypeWithNewShape(int rank) const;

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(mlir::Type type);
Expand Down
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace fir {
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
#define GEN_PASS_DECL_ARRAYVALUECOPY
#define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
#define GEN_PASS_DECL_CHARACTERCONVERSION
#define GEN_PASS_DECL_CFGCONVERSION
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
Expand Down
12 changes: 12 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -402,4 +402,16 @@ def FunctionAttr : Pass<"function-attr", "mlir::func::FuncOp"> {
let constructor = "::fir::createFunctionAttrPass()";
}

def AssumedRankOpConversion : Pass<"fir-assumed-rank-op", "mlir::ModuleOp"> {
let summary =
"Simplify operations on assumed-rank types";
let description = [{
This pass breaks up the lowering of operations on assumed-rank types by
introducing an intermediate FIR level that simplifies code generation.
}];
let dependentDialects = [
"fir::FIROpsDialect", "mlir::func::FuncDialect"
];
}

#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ inline void createDefaultFIROptimizerPassPipeline(

// Polymorphic types
pm.addPass(fir::createPolymorphicOpConversion());
pm.addPass(fir::createAssumedRankOpConversion());

if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags)
pm.addPass(fir::createAddAliasTags());
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Builder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_flang_library(FIRBuilder
Runtime/Ragged.cpp
Runtime/Reduction.cpp
Runtime/Stop.cpp
Runtime/Support.cpp
Runtime/TemporaryStack.cpp
Runtime/Transformational.cpp
TemporaryStorage.cpp
Expand Down
46 changes: 46 additions & 0 deletions flang/lib/Optimizer/Builder/Runtime/Support.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
//===-- Support.cpp - generate support runtime API calls --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/Runtime/Support.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Runtime/support.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace Fortran::runtime;

template <>
constexpr fir::runtime::TypeBuilderFunc
fir::runtime::getModel<Fortran::runtime::LowerBoundModifier>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(
context, sizeof(Fortran::runtime::LowerBoundModifier) * 8);
};
}

void fir::runtime::genCopyAndUpdateDescriptor(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value to, mlir::Value from,
mlir::Value newDynamicType,
mlir::Value newAttribute,
mlir::Value newLowerBounds) {
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CopyAndUpdateDescriptor)>(loc,
builder);
auto fTy = func.getFunctionType();
auto args =
fir::runtime::createArguments(builder, loc, fTy, to, from, newDynamicType,
newAttribute, newLowerBounds);
llvm::StringRef noCapture = mlir::LLVM::LLVMDialect::getNoCaptureAttrName();
if (!func.getArgAttr(0, noCapture)) {
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(func.getContext());
func.setArgAttr(0, noCapture, unitAttr);
func.setArgAttr(1, noCapture, unitAttr);
}
builder.create<fir::CallOp>(loc, func, args);
}
11 changes: 11 additions & 0 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,17 @@ fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const {
return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
}

fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewShape(int rank) const {
std::optional<fir::SequenceType::ShapeRef> newShape;
fir::SequenceType::Shape shapeVector;
if (rank > 0) {
shapeVector =
fir::SequenceType::Shape(rank, fir::SequenceType::getUnknownExtent());
newShape = shapeVector;
}
return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
}

bool fir::BaseBoxType::isAssumedRank() const {
if (auto seqTy =
mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy())))
Expand Down
131 changes: 131 additions & 0 deletions flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
//===-- AssumedRankOpConversion.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Common/Fortran.h"
#include "flang/Lower/BuiltinModules.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/Support.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Support/Utils.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/support.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace fir;
using namespace mlir;

namespace {

static int getCFIAttribute(mlir::Type boxType) {
if (fir::isAllocatableType(boxType))
return CFI_attribute_allocatable;
if (fir::isPointerType(boxType))
return CFI_attribute_pointer;
return CFI_attribute_other;
}

static Fortran::runtime::LowerBoundModifier
getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier) {
switch (modifier) {
case fir::LowerBoundModifierAttribute::Preserve:
return Fortran::runtime::LowerBoundModifier::Preserve;
case fir::LowerBoundModifierAttribute::SetToOnes:
return Fortran::runtime::LowerBoundModifier::SetToOnes;
case fir::LowerBoundModifierAttribute::SetToZeroes:
return Fortran::runtime::LowerBoundModifier::SetToZeroes;
}
llvm_unreachable("bad modifier code");
}

class ReboxAssumedRankConv
: public mlir::OpRewritePattern<fir::ReboxAssumedRankOp> {
public:
using OpRewritePattern::OpRewritePattern;

ReboxAssumedRankConv(mlir::MLIRContext *context,
mlir::SymbolTable *symbolTable, fir::KindMapping kindMap)
: mlir::OpRewritePattern<fir::ReboxAssumedRankOp>(context),
symbolTable{symbolTable}, kindMap{kindMap} {};

mlir::LogicalResult
matchAndRewrite(fir::ReboxAssumedRankOp rebox,
mlir::PatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
mlir::Location loc = rebox.getLoc();
auto newBoxType = mlir::cast<fir::BaseBoxType>(rebox.getType());
mlir::Type newMaxRankBoxType =
newBoxType.getBoxTypeWithNewShape(Fortran::common::maxRank);
// CopyAndUpdateDescriptor FIR interface requires loading
// !fir.ref<fir.box> input which is expensive with assumed-rank. It could
// be best to add an entry point that takes a non "const" from to cover
// this case, but it would be good to indicate to LLVM that from does not
// get modified.
if (fir::isBoxAddress(rebox.getBox().getType()))
TODO(loc, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input");
mlir::Value tempDesc = builder.createTemporary(loc, newMaxRankBoxType);
mlir::Value newDtype;
mlir::Type newEleType = newBoxType.unwrapInnerType();
auto oldBoxType = mlir::cast<fir::BaseBoxType>(
fir::unwrapRefType(rebox.getBox().getType()));
auto newDerivedType = mlir::dyn_cast<fir::RecordType>(newEleType);
if (newDerivedType && (newEleType != oldBoxType.unwrapInnerType()) &&
!fir::isPolymorphicType(newBoxType)) {
newDtype = builder.create<fir::TypeDescOp>(
loc, mlir::TypeAttr::get(newDerivedType));
} else {
newDtype = builder.createNullConstant(loc);
}
mlir::Value newAttribute = builder.createIntegerConstant(
loc, builder.getIntegerType(8), getCFIAttribute(newBoxType));
int lbsModifierCode =
static_cast<int>(getLowerBoundModifier(rebox.getLbsModifier()));
mlir::Value lowerBoundModifier = builder.createIntegerConstant(
loc, builder.getIntegerType(32), lbsModifierCode);
fir::runtime::genCopyAndUpdateDescriptor(builder, loc, tempDesc,
rebox.getBox(), newDtype,
newAttribute, lowerBoundModifier);

mlir::Value descValue = builder.create<fir::LoadOp>(loc, tempDesc);
mlir::Value castDesc = builder.createConvert(loc, newBoxType, descValue);
rewriter.replaceOp(rebox, castDesc);
return mlir::success();
}

private:
mlir::SymbolTable *symbolTable = nullptr;
fir::KindMapping kindMap;
};

/// Convert FIR structured control flow ops to CFG ops.
class AssumedRankOpConversion
: public fir::impl::AssumedRankOpConversionBase<AssumedRankOpConversion> {
public:
void runOnOperation() override {
auto *context = &getContext();
mlir::ModuleOp mod = getOperation();
mlir::SymbolTable symbolTable(mod);
fir::KindMapping kindMap = fir::getKindMapping(mod);
mlir::RewritePatternSet patterns(context);
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = false;
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
}
};
} // namespace
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_flang_library(FIRTransforms
AffinePromotion.cpp
AffineDemotion.cpp
AnnotateConstant.cpp
AssumedRankOpConversion.cpp
CharacterConversion.cpp
ControlFlowConverter.cpp
ArrayValueCopy.cpp
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/bbc-mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

! CHECK-NEXT: PolymorphicOpConversion
! CHECK-NEXT: AssumedRankOpConversion

! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! CHECK-NEXT: 'fir.global' Pipeline
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/mlir-debug-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

! ALL-NEXT: PolymorphicOpConversion
! ALL-NEXT: AssumedRankOpConversion

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! ALL-NEXT: 'fir.global' Pipeline
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

! ALL-NEXT: PolymorphicOpConversion
! ALL-NEXT: AssumedRankOpConversion
! O2-NEXT: AddAliasTags

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
Expand Down
1 change: 1 addition & 0 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

// PASSES-NEXT: PolymorphicOpConversion
// PASSES-NEXT: AssumedRankOpConversion
// PASSES-NEXT: AddAliasTags

// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
Expand Down
Loading
Loading