Skip to content

Commit 04fb2b6

Browse files
committed
[Mlir] Implement printer, parser, verifier and builder for shape.reduce.
Differential Revision: https://reviews.llvm.org/D81186
1 parent 39e3683 commit 04fb2b6

File tree

4 files changed

+147
-16
lines changed

4 files changed

+147
-16
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
290290
let hasFolder = 1;
291291
}
292292

293-
def Shape_ReduceOp : Shape_Op<"reduce", []> {
293+
def Shape_ReduceOp : Shape_Op<"reduce",
294+
[SingleBlockImplicitTerminator<"YieldOp">]> {
294295
let summary = "Returns an expression reduced over a shape";
295296
let description = [{
296297
An operation that takes as input a shape, number of initial values and has a
@@ -310,25 +311,32 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
310311
number of elements
311312

312313
```mlir
313-
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
314-
%0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
315-
%1 = "shape.reduce"(%shape, %0) ( {
316-
^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
317-
%acc = "shape.mul"(%lci, %dim) :
314+
func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
315+
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
316+
^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
317+
%updated_acc = "shape.mul"(%acc, %dim) :
318318
(!shape.size, !shape.size) -> !shape.size
319-
shape.yield %acc : !shape.size
320-
}) : (!shape.shape, !shape.size) -> (!shape.size)
321-
return %1 : !shape.size
319+
shape.yield %updated_acc : !shape.size
320+
}
321+
return %num_elements : !shape.size
322322
}
323323
```
324324

325325
If the shape is unranked, then the results of the op is also unranked.
326326
}];
327327

328-
let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$args);
328+
let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
329329
let results = (outs Variadic<AnyType>:$result);
330-
331330
let regions = (region SizedRegion<1>:$body);
331+
332+
let builders = [
333+
OpBuilder<"OpBuilder &builder, OperationState &result, "
334+
"Value shape, ValueRange initVals">,
335+
];
336+
337+
let verifier = [{ return ::verify(*this); }];
338+
let printer = [{ return ::print(p, *this); }];
339+
let parser = [{ return ::parse$cppClass(parser, result); }];
332340
}
333341

334342
def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,89 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
481481
return DenseIntElementsAttr::get(type, shape);
482482
}
483483

484+
//===----------------------------------------------------------------------===//
485+
// ReduceOp
486+
//===----------------------------------------------------------------------===//
487+
488+
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
489+
ValueRange initVals) {
490+
result.addOperands(shape);
491+
result.addOperands(initVals);
492+
493+
Region *bodyRegion = result.addRegion();
494+
bodyRegion->push_back(new Block);
495+
Block &bodyBlock = bodyRegion->front();
496+
bodyBlock.addArgument(builder.getIndexType());
497+
bodyBlock.addArgument(SizeType::get(builder.getContext()));
498+
499+
for (Type initValType : initVals.getTypes()) {
500+
bodyBlock.addArgument(initValType);
501+
result.addTypes(initValType);
502+
}
503+
}
504+
505+
static LogicalResult verify(ReduceOp op) {
506+
// Verify block arg types.
507+
Block &block = op.body().front();
508+
509+
auto blockArgsCount = op.initVals().size() + 2;
510+
if (block.getNumArguments() != blockArgsCount)
511+
return op.emitOpError() << "ReduceOp body is expected to have "
512+
<< blockArgsCount << " arguments";
513+
514+
if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
515+
return op.emitOpError(
516+
"argument 0 of ReduceOp body is expected to be of IndexType");
517+
518+
if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
519+
return op.emitOpError(
520+
"argument 1 of ReduceOp body is expected to be of SizeType");
521+
522+
for (auto type : llvm::enumerate(op.initVals()))
523+
if (block.getArgument(type.index() + 2).getType() != type.value().getType())
524+
return op.emitOpError()
525+
<< "type mismatch between argument " << type.index() + 2
526+
<< " of ReduceOp body and initial value " << type.index();
527+
return success();
528+
}
529+
530+
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
531+
auto *ctx = parser.getBuilder().getContext();
532+
// Parse operands.
533+
SmallVector<OpAsmParser::OperandType, 3> operands;
534+
if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
535+
OpAsmParser::Delimiter::Paren) ||
536+
parser.parseOptionalArrowTypeList(result.types))
537+
return failure();
538+
539+
// Resolve operands.
540+
auto initVals = llvm::makeArrayRef(operands).drop_front();
541+
if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
542+
result.operands) ||
543+
parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
544+
result.operands))
545+
return failure();
546+
547+
// Parse the body.
548+
Region *body = result.addRegion();
549+
if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
550+
return failure();
551+
552+
// Parse attributes.
553+
if (parser.parseOptionalAttrDict(result.attributes))
554+
return failure();
555+
556+
return success();
557+
}
558+
559+
static void print(OpAsmPrinter &p, ReduceOp op) {
560+
p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
561+
<< ") ";
562+
p.printOptionalArrowTypeList(op.getResultTypes());
563+
p.printRegion(op.body());
564+
p.printOptionalAttrDict(op.getAttrs());
565+
}
566+
484567
namespace mlir {
485568
namespace shape {
486569

mlir/test/Dialect/Shape/invalid.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
4+
// expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
5+
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
6+
^bb0(%index: index, %dim: !shape.size):
7+
"shape.yield"(%dim) : (!shape.size) -> ()
8+
}
9+
}
10+
11+
// -----
12+
13+
func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
14+
// expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
15+
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
16+
^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size):
17+
%acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
18+
"shape.yield"(%acc) : (!shape.size) -> ()
19+
}
20+
}
21+
22+
// -----
23+
24+
func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
25+
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
26+
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
27+
^bb0(%index: index, %dim: f32, %lci: !shape.size):
28+
"shape.yield"() : () -> ()
29+
}
30+
}
31+
32+
// -----
33+
34+
func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
35+
// expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
36+
%num_elements = shape.reduce(%shape, %init) -> f32 {
37+
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
38+
"shape.yield"() : () -> ()
39+
}
40+
}

mlir/test/Dialect/Shape/ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
// CHECK-LABEL: shape_num_elements
88
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
9-
%0 = shape.const_size 0
10-
%1 = "shape.reduce"(%shape, %0) ( {
11-
^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
9+
%init = shape.const_size 0
10+
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
11+
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
1212
%acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
1313
"shape.yield"(%acc) : (!shape.size) -> ()
14-
}) : (!shape.shape, !shape.size) -> (!shape.size)
15-
return %1 : !shape.size
14+
}
15+
return %num_elements : !shape.size
1616
}
1717

1818
func @test_shape_num_elements_unknown() {

0 commit comments

Comments
 (0)