Skip to content

Commit fb370bd

Browse files
committed
Allow ArrayAttrs options to contain param-operands
1 parent c2a8659 commit fb370bd

File tree

3 files changed

+148
-57
lines changed

3 files changed

+148
-57
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
418418
with options = { "top-down" = false,
419419
"max-iterations" = %max_iter,
420420
"test-convergence" = true,
421-
"max-num-rewrites" = %max_rewrites }
421+
"max-num-rewrites" = %max_rewrites }
422422
to %module
423423
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
424424
```

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 111 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -788,42 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
788788
// Obtain a single options-string to pass to the pass(-pipeline) from options
789789
// passed in as a dictionary of keys mapping to values which are either
790790
// attributes or param-operands pointing to attributes.
791+
OperandRange dynamicOptions = getDynamicOptions();
791792

792793
std::string options;
793794
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
794-
std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
795-
if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr))
796-
llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
797-
else if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
798-
optionsStream << strAttr.getValue().str();
799-
else
800-
valueAttr.print(optionsStream, /*elideType=*/true);
801-
};
802795

803-
OperandRange dynamicOptions = getDynamicOptions();
804-
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
805-
if (idx > 0)
806-
optionsStream << " "; // Interleave options separator.
807-
optionsStream << namedAttribute.getName().str(); // Append the key.
808-
optionsStream << "="; // And the key-value separator.
809-
810-
if (auto paramOperandIndex =
811-
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
812-
// The corresponding value attribute is passed in via a param.
796+
// A helper to convert an option's attribute value into a corresponding
797+
// string representation, with the ability to obtain the attr(s) from a param.
798+
std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
799+
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
800+
// The corresponding value attribute(s) is/are passed in via a param.
813801
// Obtain the param-operand via its specified index.
814-
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
802+
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
815803
assert(dynamicOptionIdx < dynamicOptions.size() &&
816-
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
804+
"the number of ParamOperandAttrs in the options DictionaryAttr"
817805
"should be the same as the number of options passed as params");
818-
ArrayRef<Attribute> dynamicOption =
806+
ArrayRef<Attribute> attrsAssociatedToParam =
819807
state.getParams(dynamicOptions[dynamicOptionIdx]);
820-
// Append all attributes associated to the param, separated by commas.
821-
llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ",");
808+
// Recursive so as to append all attrs associated to the param.
809+
llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
810+
",");
811+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
812+
// Recursive so as to append all nested attrs of the array.
813+
llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
814+
} else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
815+
// Convert to unquoted string.
816+
optionsStream << strAttr.getValue().str();
822817
} else {
823-
// Value is a static attribute.
824-
appendValueAttr(namedAttribute.getValue());
818+
// For all other attributes, ask the attr to print itself (without type).
819+
valueAttr.print(optionsStream, /*elideType=*/true);
825820
}
826-
}
821+
};
822+
823+
// Convert the options DictionaryAttr into a single string.
824+
llvm::interleave(
825+
getOptions(), optionsStream,
826+
[&](auto namedAttribute) {
827+
optionsStream << namedAttribute.getName().str(); // Append the key.
828+
optionsStream << "="; // And the key-value separator.
829+
appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
830+
},
831+
" ");
827832
optionsStream.flush();
828833

829834
// Get pass or pass pipeline from registry.
@@ -874,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
874879
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
875880
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
876881
SmallVector<NamedAttribute> keyValuePairs;
877-
878882
size_t dynamicOptionsIdx = 0;
879-
auto parseKeyValuePair = [&]() -> ParseResult {
880-
// Parse items of the form `key = value` where `key` is a bare identifier or
881-
// a string and `value` is either an attribute or an operand.
882883

883-
std::string key;
884-
Attribute valueAttr;
885-
if (parser.parseOptionalKeywordOrString(&key))
886-
return parser.emitError(parser.getCurrentLocation())
887-
<< "expected key to either be an identifier or a string";
888-
if (key.empty())
889-
return failure();
884+
// Helper for allowing parsing of option values which can be of the form:
885+
// - a normal attribute
886+
// - an operand (which would be converted to an attr referring to the operand)
887+
// - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
888+
std::function<ParseResult(Attribute &)> parseValue =
889+
[&](Attribute &valueAttr) -> ParseResult {
890+
// Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
891+
if (succeeded(parser.parseOptionalLSquare())) {
892+
SmallVector<Attribute> attrs;
890893

891-
if (parser.parseEqual())
892-
return parser.emitError(parser.getCurrentLocation())
893-
<< "expected '=' after key in key-value pair";
894+
// Recursively parse the array's elements, which might be operands.
895+
if (parser.parseCommaSeparatedList(
896+
AsmParser::Delimiter::None,
897+
[&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
898+
" in options dictionary") ||
899+
parser.parseRSquare())
900+
return failure(); // NB: Attempted parse should've output error message.
901+
902+
valueAttr = ArrayAttr::get(parser.getContext(), attrs);
903+
904+
return success();
905+
}
894906

895907
// Parse the value, which can be either an attribute or an operand.
896908
OptionalParseResult parsedValueAttr =
@@ -899,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
899911
OpAsmParser::UnresolvedOperand operand;
900912
ParseResult parsedOperand = parser.parseOperand(operand);
901913
if (failed(parsedOperand))
902-
return parser.emitError(parser.getCurrentLocation())
903-
<< "expected a valid attribute or operand as value associated "
904-
<< "to key '" << key << "'";
914+
return failure(); // NB: Attempted parse should've output error message.
905915
// To make use of the operand, we need to store it in the options dict.
906916
// As SSA-values cannot occur in attributes, what we do instead is store
907917
// an attribute in its place that contains the index of the param-operand,
@@ -920,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
920930
<< "in the generic print format";
921931
}
922932

933+
return success();
934+
};
935+
936+
// Helper for `key = value`-pair parsing where `key` is a bare identifier or a
937+
// string and `value` looks like either an attribute or an operand-in-an-attr.
938+
std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
939+
std::string key;
940+
Attribute valueAttr;
941+
942+
if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
943+
return parser.emitError(parser.getCurrentLocation())
944+
<< "expected key to either be an identifier or a string";
945+
946+
if (failed(parser.parseEqual()))
947+
return parser.emitError(parser.getCurrentLocation())
948+
<< "expected '=' after key in key-value pair";
949+
950+
if (failed(parseValue(valueAttr)))
951+
return parser.emitError(parser.getCurrentLocation())
952+
<< "expected a valid attribute or operand as value associated "
953+
<< "to key '" << key << "'";
954+
923955
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
956+
924957
return success();
925958
};
926959

@@ -947,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
947980
if (options.empty())
948981
return;
949982

950-
printer << "{";
951-
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
952-
printer << namedAttribute.getName() << " = ";
953-
Attribute value = namedAttribute.getValue();
954-
if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
983+
std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
984+
if (auto paramOperandAttr =
985+
dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
955986
// Resolve index of param-operand to its actual SSA-value and print that.
956-
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
987+
printer.printOperand(
988+
dynamicOptions[paramOperandAttr.getIndex().getInt()]);
989+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990+
// This case is so that ArrayAttr-contained operands are pretty-printed.
991+
printer << "[";
992+
llvm::interleaveComma(arrayAttr, printer, printOptionValue);
993+
printer << "]";
957994
} else {
958-
printer.printAttribute(value);
995+
printer.printAttribute(valueAttr);
959996
}
997+
};
998+
999+
printer << "{";
1000+
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
1001+
printer << namedAttribute.getName();
1002+
printer << " = ";
1003+
printOptionValue(namedAttribute.getValue());
9601004
});
9611005
printer << "}";
9621006
}
@@ -966,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
9661010
// and references to dynamic options in the options dictionary.
9671011

9681012
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
969-
for (NamedAttribute namedAttr : getOptions())
970-
if (auto paramOperand =
971-
dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
1013+
1014+
// Helper for option values to mark seen operands as having been seen (once).
1015+
std::function<LogicalResult(Attribute)> checkOptionValue =
1016+
[&](Attribute valueAttr) -> LogicalResult {
1017+
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
9721018
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
9731019
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
9741020
return emitOpError()
@@ -979,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
9791025
return emitOpError() << "dynamic option index " << dynamicOptionIdx
9801026
<< " is already used in options";
9811027
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1028+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1029+
// Recurse into ArrayAttrs as they may contain references to operands.
1030+
for (auto eltAttr : arrayAttr)
1031+
if (failed(checkOptionValue(eltAttr)))
1032+
return failure();
9821033
}
1034+
return success();
1035+
};
1036+
1037+
for (NamedAttribute namedAttr : getOptions())
1038+
if (failed(checkOptionValue(namedAttr.getValue())))
1039+
return failure();
9831040

1041+
// All dynamicOptions-params seen in the dict will have been set to null.
9841042
for (Value dynamicOption : dynamicOptions)
9851043
if (dynamicOption)
9861044
return emitOpError() << "a param operand does not have a corresponding "

mlir/test/Dialect/Transform/test-pass-application.mlir

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} {
164164

165165
// -----
166166

167-
// CHECK-LABEL: func private @valid_multiple_params_as_list_option()
167+
// CHECK-LABEL: func private @valid_multiple_values_as_list_option_single_param()
168168
module {
169-
func.func @valid_multiple_params_as_list_option() {
169+
func.func @valid_multiple_values_as_list_option_single_param() {
170170
return
171171
}
172172

@@ -253,6 +253,38 @@ module attributes {transform.with_named_sequence} {
253253
}
254254
}
255255

256+
// -----
257+
258+
// CHECK-LABEL: func private @valid_multiple_params_as_single_list_option()
259+
module {
260+
func.func @valid_multiple_params_as_single_list_option() {
261+
return
262+
}
263+
264+
// CHECK: func @a()
265+
func.func @a() {
266+
return
267+
}
268+
// CHECK: func @b()
269+
func.func @b() {
270+
return
271+
}
272+
}
273+
274+
module attributes {transform.with_named_sequence} {
275+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
276+
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
277+
%2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
278+
%symbol_a = transform.param.constant "a" -> !transform.any_param
279+
%symbol_b = transform.param.constant "b" -> !transform.any_param
280+
transform.apply_registered_pass "symbol-privatize"
281+
with options = { exclude = [%symbol_a, %symbol_b] } to %2
282+
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
283+
transform.yield
284+
}
285+
}
286+
287+
256288
// -----
257289

258290
func.func @invalid_options_as_str() {
@@ -294,7 +326,8 @@ func.func @invalid_options_due_to_reserved_attr() {
294326
module attributes {transform.with_named_sequence} {
295327
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
296328
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
297-
// expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
329+
// expected-error @+3 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
330+
// expected-error @+2 {{expected a valid attribute or operand as value associated to key 'top-down'}}
298331
%2 = transform.apply_registered_pass "canonicalize"
299332
with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
300333
transform.yield

0 commit comments

Comments
 (0)