From 943119436397e5554eadf64688ad5a01205d0567 Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Thu, 30 May 2024 15:48:04 -0700 Subject: [PATCH 1/3] [flang] Add reduction semantics to fir.do_loop --- .../flang/Optimizer/Dialect/FIRAttr.td | 30 ++++++++ .../include/flang/Optimizer/Dialect/FIROps.td | 64 ++++++++++++++-- flang/lib/Optimizer/Dialect/FIRAttr.cpp | 4 +- flang/lib/Optimizer/Dialect/FIROps.cpp | 73 +++++++++++++++++-- 4 files changed, 154 insertions(+), 17 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td index 0c34b640a5c9c..aedb6769186e9 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td @@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr< let cppNamespace = "fir"; } +def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum", + "intrinsic operations and functions supported by DO CONCURRENT REDUCE", + [ + I32BitEnumAttrCaseBit<"Add", 0, "add">, + I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">, + I32BitEnumAttrCaseBit<"AND", 2, "and">, + I32BitEnumAttrCaseBit<"OR", 3, "or">, + I32BitEnumAttrCaseBit<"EQV", 4, "eqv">, + I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">, + I32BitEnumAttrCaseBit<"MAX", 6, "max">, + I32BitEnumAttrCaseBit<"MIN", 7, "min">, + I32BitEnumAttrCaseBit<"IAND", 8, "iand">, + I32BitEnumAttrCaseBit<"IOR", 9, "ior">, + I32BitEnumAttrCaseBit<"EIOR", 10, "eior"> + ]> { + let separator = ", "; + let cppNamespace = "::fir"; + let printBitEnumPrimaryGroups = 1; +} + +def fir_ReduceAttr : fir_Attr<"Reduce"> { + let mnemonic = "reduce_attr"; + + let parameters = (ins + "ReduceOperationEnum":$reduce_operation + ); + + let assemblyFormat = "`<` $reduce_operation `>`"; +} + // mlir::SideEffects::Resource for modelling operations which add debugging information def DebuggingResource : Resource<"::fir::DebuggingResource">; diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 3afc97475db11..d79f2da916d05 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2107,8 +2107,37 @@ class region_Op traits = []> : let hasVerifier = 1; } -def fir_DoLoopOp : region_Op<"do_loop", - [DeclareOpInterfaceMethods { + let summary = "Represent reduction semantics for the reduce clause"; + + let description = [{ + Given the address of a variable, creates reduction information for the + reduce clause. + + ``` + %17 = fir.reduce %8 {name = "sum"} : (!fir.ref) -> !fir.ref + fir.do_loop ... unordered reduce(#fir.reduce_attr -> %17 : !fir.ref) ... + ``` + + This operation is typically used for DO CONCURRENT REDUCE clause. The memref + operand may have a unique name while the `name` attribute preserves the + original name of a reduction variable. + }]; + + let arguments = (ins + AnyRefOrBoxLike:$memref, + Builtin_StringAttr:$name + ); + + let results = (outs AnyRefOrBox); + + let assemblyFormat = [{ + operands attr-dict `:` functional-type(operands, results) + }]; +} + +def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { let summary = "generalized loop operation"; let description = [{ @@ -2138,9 +2167,11 @@ def fir_DoLoopOp : region_Op<"do_loop", Index:$lowerBound, Index:$upperBound, Index:$step, + Variadic:$reduceOperands, Variadic:$initArgs, OptionalAttr:$unordered, - OptionalAttr:$finalValue + OptionalAttr:$finalValue, + OptionalAttr:$reduceAttrs ); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -2151,6 +2182,8 @@ def fir_DoLoopOp : region_Op<"do_loop", "mlir::Value":$step, CArg<"bool", "false">:$unordered, CArg<"bool", "false">:$finalCountValue, CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs, + CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands, + CArg<"llvm::ArrayRef", "{}">:$reduceAttrs, CArg<"llvm::ArrayRef", "{}">:$attributes)> ]; @@ -2163,11 +2196,12 @@ def fir_DoLoopOp : region_Op<"do_loop", return getBody()->getArguments().drop_front(); } mlir::Operation::operand_range getIterOperands() { - return getOperands().drop_front(getNumControlOperands()); + return getOperands() + .drop_front(getNumControlOperands() + getNumReduceOperands()); } llvm::MutableArrayRef getInitsMutable() { - return - getOperation()->getOpOperands().drop_front(getNumControlOperands()); + return getOperation()->getOpOperands() + .drop_front(getNumControlOperands() + getNumReduceOperands()); } void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); } @@ -2182,11 +2216,25 @@ def fir_DoLoopOp : region_Op<"do_loop", unsigned getNumControlOperands() { return 3; } /// Does the operation hold operands for loop-carried values bool hasIterOperands() { - return (*this)->getNumOperands() > getNumControlOperands(); + return getNumIterOperands() > 0; + } + /// Does the operation hold operands for reduction variables + bool hasReduceOperands() { + return getNumReduceOperands() > 0; + } + /// Get Number of variadic operands + unsigned getNumOperands(unsigned idx) { + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); + return static_cast(segments[idx]); + } + // Get Number of reduction operands + unsigned getNumReduceOperands() { + return getNumOperands(3); } /// Get Number of loop-carried values unsigned getNumIterOperands() { - return (*this)->getNumOperands() - getNumControlOperands(); + return getNumOperands(4); } /// Get the body of the loop diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp index 2faba63dfba07..a0202a0159228 100644 --- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp +++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp @@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr, void FIROpsDialect::registerAttributes() { addAttributes(); + LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr, + SubclassAttr, UpperBoundAttr>(); } diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index b541b7cdc7a5b..807459c8ec3c7 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2079,9 +2079,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value lb, mlir::Value ub, mlir::Value step, bool unordered, bool finalCountValue, mlir::ValueRange iterArgs, + mlir::ValueRange reduceOperands, + llvm::ArrayRef reduceAttrs, llvm::ArrayRef attributes) { result.addOperands({lb, ub, step}); + result.addOperands(reduceOperands); result.addOperands(iterArgs); + result.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, static_cast(reduceOperands.size()), + static_cast(iterArgs.size())})); if (finalCountValue) { result.addTypes(builder.getIndexType()); result.addAttribute(getFinalValueAttrName(result.name), @@ -2100,6 +2107,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder, if (unordered) result.addAttribute(getUnorderedAttrName(result.name), builder.getUnitAttr()); + if (!reduceAttrs.empty()) + result.addAttribute(getReduceAttrsAttrName(result.name), + builder.getArrayAttr(reduceAttrs)); result.addAttributes(attributes); } @@ -2125,24 +2135,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) result.addAttribute("unordered", builder.getUnitAttr()); + // Parse the reduction arguments. + llvm::SmallVector reduceOperands; + llvm::SmallVector reduceArgTypes; + if (succeeded(parser.parseOptionalKeyword("reduce"))) { + // Parse reduction attributes and variables. + llvm::SmallVector attributes; + if (failed(parser.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::Paren, [&]() { + if (parser.parseAttribute(attributes.emplace_back()) || + parser.parseArrow() || + parser.parseOperand(reduceOperands.emplace_back()) || + parser.parseColonType(reduceArgTypes.emplace_back())) + return mlir::failure(); + return mlir::success(); + }))) + return mlir::failure(); + // Resolve input operands. + for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) + if (parser.resolveOperand(std::get<0>(operand_type), + std::get<1>(operand_type), result.operands)) + return mlir::failure(); + llvm::SmallVector arrayAttr(attributes.begin(), + attributes.end()); + result.addAttribute(getReduceAttrsAttrName(result.name), + builder.getArrayAttr(arrayAttr)); + } + // Parse the optional initial iteration arguments. llvm::SmallVector regionArgs; - llvm::SmallVector operands; + llvm::SmallVector iterOperands; llvm::SmallVector argTypes; bool prependCount = false; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { // Parse assignment list and results type list. - if (parser.parseAssignmentList(regionArgs, operands) || + if (parser.parseAssignmentList(regionArgs, iterOperands) || parser.parseArrowTypeList(result.types)) return mlir::failure(); - if (result.types.size() == operands.size() + 1) + if (result.types.size() == iterOperands.size() + 1) prependCount = true; // Resolve input operands. llvm::ArrayRef resTypes = result.types; - for (auto operand_type : - llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes)) + for (auto operand_type : llvm::zip( + iterOperands, prependCount ? resTypes.drop_front() : resTypes)) if (parser.resolveOperand(std::get<0>(operand_type), std::get<1>(operand_type), result.operands)) return mlir::failure(); @@ -2153,6 +2190,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, prependCount = true; } + // Set the operandSegmentSizes attribute + result.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {1, 1, 1, static_cast(reduceOperands.size()), + static_cast(iterOperands.size())})); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return mlir::failure(); @@ -2229,6 +2272,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() { i++; } + auto reduceAttrs = getReduceAttrsAttr(); + if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) + return emitOpError( + "mismatch in number of reduction variables and reduction attributes"); return mlir::success(); } @@ -2238,6 +2285,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { << getUpperBound() << " step " << getStep(); if (getUnordered()) p << " unordered"; + if (hasReduceOperands()) { + p << " reduce("; + auto attrs = getReduceAttrsAttr(); + auto operands = getReduceOperands(); + llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { + p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " + << std::get<1>(it).getType(); + }); + p << ')'; + printBlockTerminators = true; + } if (hasIterOperands()) { p << " iter_args("; auto regionArgs = getRegionIterArgs(); @@ -2251,8 +2309,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { p << " -> " << getResultTypes(); printBlockTerminators = true; } - p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), - {"unordered", "finalValue"}); + p.printOptionalAttrDictWithKeyword( + (*this)->getAttrs(), + {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"}); p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, printBlockTerminators); From 74c06ae6f302813ef9a128b05ddbf70912d7e0b8 Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Mon, 3 Jun 2024 08:36:42 -0700 Subject: [PATCH 2/3] [flang] Add test/Fir/loop03.fir to test the reduction semantics of fir.do_loop --- flang/test/Fir/loop03.fir | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 flang/test/Fir/loop03.fir diff --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir new file mode 100644 index 0000000000000..916ccaeaa2aef --- /dev/null +++ b/flang/test/Fir/loop03.fir @@ -0,0 +1,19 @@ +// Test the reduction semantics of fir.do_loop +// RUN: fir-opt %s | FileCheck %s + +func.func @reduction() { + %bound = arith.constant 10 : index + %step = arith.constant 1 : index + %sum = fir.alloca i32 + %red = fir.reduce %sum {name = "sum"} : (!fir.ref) -> !fir.ref +// CHECK: %[[VAL_0:.*]] = fir.alloca i32 +// CHECK: %[[VAL_1:.*]] = fir.reduce %[[VAL_0]] {name = "sum"} : (!fir.ref) -> !fir.ref +// CHECK: fir.do_loop %[[VAL_2:.*]] = %[[VAL_3:.*]] to %[[VAL_4:.*]] step %[[VAL_5:.*]] unordered reduce(#fir.reduce_attr -> %[[VAL_1]] : !fir.ref) { + fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr -> %red : !fir.ref) { + %index = fir.convert %iv : (index) -> i32 + %1 = fir.load %sum : !fir.ref + %2 = arith.addi %index, %1 : i32 + fir.store %2 to %sum : !fir.ref + } + return +} From 0b655d1cd476efb065e83ea15ce6821a4b49132a Mon Sep 17 00:00:00 2001 From: Kazuaki Matsumura Date: Mon, 3 Jun 2024 09:10:32 -0700 Subject: [PATCH 3/3] [flang] Remove fir.reduce --- .../include/flang/Optimizer/Dialect/FIROps.td | 29 ------------------- flang/test/Fir/loop03.fir | 6 ++-- 2 files changed, 2 insertions(+), 33 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index d79f2da916d05..0a7bd4178517a 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2107,35 +2107,6 @@ class region_Op traits = []> : let hasVerifier = 1; } -def fir_ReduceOp : fir_SimpleOp<"reduce", [NoMemoryEffect]> { - let summary = "Represent reduction semantics for the reduce clause"; - - let description = [{ - Given the address of a variable, creates reduction information for the - reduce clause. - - ``` - %17 = fir.reduce %8 {name = "sum"} : (!fir.ref) -> !fir.ref - fir.do_loop ... unordered reduce(#fir.reduce_attr -> %17 : !fir.ref) ... - ``` - - This operation is typically used for DO CONCURRENT REDUCE clause. The memref - operand may have a unique name while the `name` attribute preserves the - original name of a reduction variable. - }]; - - let arguments = (ins - AnyRefOrBoxLike:$memref, - Builtin_StringAttr:$name - ); - - let results = (outs AnyRefOrBox); - - let assemblyFormat = [{ - operands attr-dict `:` functional-type(operands, results) - }]; -} - def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { diff --git a/flang/test/Fir/loop03.fir b/flang/test/Fir/loop03.fir index 916ccaeaa2aef..b88dcaf8639be 100644 --- a/flang/test/Fir/loop03.fir +++ b/flang/test/Fir/loop03.fir @@ -5,11 +5,9 @@ func.func @reduction() { %bound = arith.constant 10 : index %step = arith.constant 1 : index %sum = fir.alloca i32 - %red = fir.reduce %sum {name = "sum"} : (!fir.ref) -> !fir.ref // CHECK: %[[VAL_0:.*]] = fir.alloca i32 -// CHECK: %[[VAL_1:.*]] = fir.reduce %[[VAL_0]] {name = "sum"} : (!fir.ref) -> !fir.ref -// CHECK: fir.do_loop %[[VAL_2:.*]] = %[[VAL_3:.*]] to %[[VAL_4:.*]] step %[[VAL_5:.*]] unordered reduce(#fir.reduce_attr -> %[[VAL_1]] : !fir.ref) { - fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr -> %red : !fir.ref) { +// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr -> %[[VAL_0]] : !fir.ref) { + fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr -> %sum : !fir.ref) { %index = fir.convert %iv : (index) -> i32 %1 = fir.load %sum : !fir.ref %2 = arith.addi %index, %1 : i32