Skip to content

Commit e008538

Browse files
authored
[MLIR][Transform] apply_registered_pass: support ListOptions (#144026)
Interpret an option value with multiple values, either in the form of an `ArrayAttr` (either static or passed through a param) or as the multiple attrs associated to a param, as a comma-separated list, i.e. as a ListOption on a pass.
1 parent 299a55a commit e008538

File tree

5 files changed

+301
-114
lines changed

5 files changed

+301
-114
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,14 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
418418
with options = { "top-down" = false,
419419
"max-iterations" = %max_iter,
420420
"test-convergence" = true,
421-
"max-num-rewrites" = %max_rewrites }
421+
"max-num-rewrites" = %max_rewrites }
422422
to %module
423423
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
424424
```
425425

426+
Options' values which are `ArrayAttr`s are converted to comma-separated
427+
lists of options. Likewise for params which associate multiple values.
428+
426429
This op first looks for a pass pipeline with the specified name. If no such
427430
pipeline exists, it looks for a pass with the specified name. If no such
428431
pass exists either, this op fails definitely.

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 110 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -788,46 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
788788
// Obtain a single options-string to pass to the pass(-pipeline) from options
789789
// passed in as a dictionary of keys mapping to values which are either
790790
// attributes or param-operands pointing to attributes.
791+
OperandRange dynamicOptions = getDynamicOptions();
791792

792793
std::string options;
793794
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
794795

795-
OperandRange dynamicOptions = getDynamicOptions();
796-
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
797-
if (idx > 0)
798-
optionsStream << " "; // Interleave options separator.
799-
optionsStream << namedAttribute.getName().str(); // Append the key.
800-
optionsStream << "="; // And the key-value separator.
801-
802-
Attribute valueAttrToAppend;
803-
if (auto paramOperandIndex =
804-
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
805-
// The corresponding value attribute is passed in via a param.
796+
// A helper to convert an option's attribute value into a corresponding
797+
// string representation, with the ability to obtain the attr(s) from a param.
798+
std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
799+
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
800+
// The corresponding value attribute(s) is/are passed in via a param.
806801
// Obtain the param-operand via its specified index.
807-
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
802+
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
808803
assert(dynamicOptionIdx < dynamicOptions.size() &&
809-
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
804+
"the number of ParamOperandAttrs in the options DictionaryAttr"
810805
"should be the same as the number of options passed as params");
811-
ArrayRef<Attribute> dynamicOption =
806+
ArrayRef<Attribute> attrsAssociatedToParam =
812807
state.getParams(dynamicOptions[dynamicOptionIdx]);
813-
if (dynamicOption.size() != 1)
814-
return emitSilenceableError()
815-
<< "options passed as a param must have "
816-
"a single value associated, param "
817-
<< dynamicOptionIdx << " associates " << dynamicOption.size();
818-
valueAttrToAppend = dynamicOption[0];
819-
} else {
820-
// Value is a static attribute.
821-
valueAttrToAppend = namedAttribute.getValue();
822-
}
823-
824-
// Append string representation of value attribute.
825-
if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
808+
// Recursive so as to append all attrs associated to the param.
809+
llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
810+
",");
811+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
812+
// Recursive so as to append all nested attrs of the array.
813+
llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
814+
} else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
815+
// Convert to unquoted string.
826816
optionsStream << strAttr.getValue().str();
827817
} else {
828-
valueAttrToAppend.print(optionsStream, /*elideType=*/true);
818+
// For all other attributes, ask the attr to print itself (without type).
819+
valueAttr.print(optionsStream, /*elideType=*/true);
829820
}
830-
}
821+
};
822+
823+
// Convert the options DictionaryAttr into a single string.
824+
llvm::interleave(
825+
getOptions(), optionsStream,
826+
[&](auto namedAttribute) {
827+
optionsStream << namedAttribute.getName().str(); // Append the key.
828+
optionsStream << "="; // And the key-value separator.
829+
appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
830+
},
831+
" ");
831832
optionsStream.flush();
832833

833834
// Get pass or pass pipeline from registry.
@@ -878,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
878879
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
879880
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
880881
SmallVector<NamedAttribute> keyValuePairs;
881-
882882
size_t dynamicOptionsIdx = 0;
883-
auto parseKeyValuePair = [&]() -> ParseResult {
884-
// Parse items of the form `key = value` where `key` is a bare identifier or
885-
// a string and `value` is either an attribute or an operand.
886883

887-
std::string key;
888-
Attribute valueAttr;
889-
if (parser.parseOptionalKeywordOrString(&key))
890-
return parser.emitError(parser.getCurrentLocation())
891-
<< "expected key to either be an identifier or a string";
892-
if (key.empty())
893-
return failure();
884+
// Helper for allowing parsing of option values which can be of the form:
885+
// - a normal attribute
886+
// - an operand (which would be converted to an attr referring to the operand)
887+
// - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
888+
std::function<ParseResult(Attribute &)> parseValue =
889+
[&](Attribute &valueAttr) -> ParseResult {
890+
// Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
891+
if (succeeded(parser.parseOptionalLSquare())) {
892+
SmallVector<Attribute> attrs;
894893

895-
if (parser.parseEqual())
896-
return parser.emitError(parser.getCurrentLocation())
897-
<< "expected '=' after key in key-value pair";
894+
// Recursively parse the array's elements, which might be operands.
895+
if (parser.parseCommaSeparatedList(
896+
AsmParser::Delimiter::None,
897+
[&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
898+
" in options dictionary") ||
899+
parser.parseRSquare())
900+
return failure(); // NB: Attempted parse should've output error message.
901+
902+
valueAttr = ArrayAttr::get(parser.getContext(), attrs);
903+
904+
return success();
905+
}
898906

899907
// Parse the value, which can be either an attribute or an operand.
900908
OptionalParseResult parsedValueAttr =
@@ -903,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
903911
OpAsmParser::UnresolvedOperand operand;
904912
ParseResult parsedOperand = parser.parseOperand(operand);
905913
if (failed(parsedOperand))
906-
return parser.emitError(parser.getCurrentLocation())
907-
<< "expected a valid attribute or operand as value associated "
908-
<< "to key '" << key << "'";
914+
return failure(); // NB: Attempted parse should've output error message.
909915
// To make use of the operand, we need to store it in the options dict.
910916
// As SSA-values cannot occur in attributes, what we do instead is store
911917
// an attribute in its place that contains the index of the param-operand,
@@ -924,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
924930
<< "in the generic print format";
925931
}
926932

933+
return success();
934+
};
935+
936+
// Helper for `key = value`-pair parsing where `key` is a bare identifier or a
937+
// string and `value` looks like either an attribute or an operand-in-an-attr.
938+
std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
939+
std::string key;
940+
Attribute valueAttr;
941+
942+
if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
943+
return parser.emitError(parser.getCurrentLocation())
944+
<< "expected key to either be an identifier or a string";
945+
946+
if (failed(parser.parseEqual()))
947+
return parser.emitError(parser.getCurrentLocation())
948+
<< "expected '=' after key in key-value pair";
949+
950+
if (failed(parseValue(valueAttr)))
951+
return parser.emitError(parser.getCurrentLocation())
952+
<< "expected a valid attribute or operand as value associated "
953+
<< "to key '" << key << "'";
954+
927955
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
956+
928957
return success();
929958
};
930959

@@ -951,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
951980
if (options.empty())
952981
return;
953982

954-
printer << "{";
955-
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
956-
printer << namedAttribute.getName() << " = ";
957-
Attribute value = namedAttribute.getValue();
958-
if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
983+
std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
984+
if (auto paramOperandAttr =
985+
dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
959986
// Resolve index of param-operand to its actual SSA-value and print that.
960-
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
987+
printer.printOperand(
988+
dynamicOptions[paramOperandAttr.getIndex().getInt()]);
989+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990+
// This case is so that ArrayAttr-contained operands are pretty-printed.
991+
printer << "[";
992+
llvm::interleaveComma(arrayAttr, printer, printOptionValue);
993+
printer << "]";
961994
} else {
962-
printer.printAttribute(value);
995+
printer.printAttribute(valueAttr);
963996
}
997+
};
998+
999+
printer << "{";
1000+
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
1001+
printer << namedAttribute.getName();
1002+
printer << " = ";
1003+
printOptionValue(namedAttribute.getValue());
9641004
});
9651005
printer << "}";
9661006
}
@@ -970,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
9701010
// and references to dynamic options in the options dictionary.
9711011

9721012
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
973-
for (NamedAttribute namedAttr : getOptions())
974-
if (auto paramOperand =
975-
dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
1013+
1014+
// Helper for option values to mark seen operands as having been seen (once).
1015+
std::function<LogicalResult(Attribute)> checkOptionValue =
1016+
[&](Attribute valueAttr) -> LogicalResult {
1017+
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
9761018
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
9771019
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
9781020
return emitOpError()
@@ -983,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
9831025
return emitOpError() << "dynamic option index " << dynamicOptionIdx
9841026
<< " is already used in options";
9851027
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1028+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1029+
// Recurse into ArrayAttrs as they may contain references to operands.
1030+
for (auto eltAttr : arrayAttr)
1031+
if (failed(checkOptionValue(eltAttr)))
1032+
return failure();
9861033
}
1034+
return success();
1035+
};
1036+
1037+
for (NamedAttribute namedAttr : getOptions())
1038+
if (failed(checkOptionValue(namedAttr.getValue())))
1039+
return failure();
9871040

1041+
// All dynamicOptions-params seen in the dict will have been set to null.
9881042
for (Value dynamicOption : dynamicOptions)
9891043
if (dynamicOption)
9901044
return emitOpError() << "a param operand does not have a corresponding "

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def __init__(
219219
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
220220

221221

222+
OptionValueTypes = Union[
223+
Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
224+
]
225+
226+
222227
@_ods_cext.register_operation(_Dialect, replace=True)
223228
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
224229
def __init__(
@@ -227,12 +232,7 @@ def __init__(
227232
target: Union[Operation, Value, OpView],
228233
pass_name: Union[str, StringAttr],
229234
*,
230-
options: Optional[
231-
Dict[
232-
Union[str, StringAttr],
233-
Union[Attribute, Value, Operation, OpView, str, int, bool],
234-
]
235-
] = None,
235+
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
236236
loc=None,
237237
ip=None,
238238
):
@@ -243,26 +243,32 @@ def __init__(
243243
context = (loc and loc.context) or Context.current
244244

245245
cur_param_operand_idx = 0
246-
for key, value in options.items() if options is not None else {}:
247-
if isinstance(key, StringAttr):
248-
key = key.value
249246

247+
def option_value_to_attr(value):
248+
nonlocal cur_param_operand_idx
250249
if isinstance(value, (Value, Operation, OpView)):
251250
dynamic_options.append(_get_op_result_or_value(value))
252-
options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
253251
cur_param_operand_idx += 1
252+
return ParamOperandAttr(cur_param_operand_idx - 1, context)
254253
elif isinstance(value, Attribute):
255-
options_dict[key] = value
254+
return value
256255
# The following cases auto-convert Python values to attributes.
257256
elif isinstance(value, bool):
258-
options_dict[key] = BoolAttr.get(value)
257+
return BoolAttr.get(value)
259258
elif isinstance(value, int):
260259
default_int_type = IntegerType.get_signless(64, context)
261-
options_dict[key] = IntegerAttr.get(default_int_type, value)
260+
return IntegerAttr.get(default_int_type, value)
262261
elif isinstance(value, str):
263-
options_dict[key] = StringAttr.get(value)
262+
return StringAttr.get(value)
263+
elif isinstance(value, Sequence):
264+
return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
264265
else:
265266
raise TypeError(f"Unsupported option type: {type(value)}")
267+
268+
for key, value in options.items() if options is not None else {}:
269+
if isinstance(key, StringAttr):
270+
key = key.value
271+
options_dict[key] = option_value_to_attr(value)
266272
super().__init__(
267273
result,
268274
_get_op_result_or_value(target),
@@ -279,12 +285,7 @@ def apply_registered_pass(
279285
target: Union[Operation, Value, OpView],
280286
pass_name: Union[str, StringAttr],
281287
*,
282-
options: Optional[
283-
Dict[
284-
Union[str, StringAttr],
285-
Union[Attribute, Value, Operation, OpView, str, int, bool],
286-
]
287-
] = None,
288+
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
288289
loc=None,
289290
ip=None,
290291
) -> Value:

0 commit comments

Comments
 (0)