Skip to content

[MLIR][Transform] apply_registered_pass: support ListOptions #144026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,14 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
with options = { "top-down" = false,
"max-iterations" = %max_iter,
"test-convergence" = true,
"max-num-rewrites" = %max_rewrites }
"max-num-rewrites" = %max_rewrites }
to %module
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
```

Options' values which are `ArrayAttr`s are converted to comma-separated
lists of options. Likewise for params which associate multiple values.

This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
pass exists either, this op fails definitely.
Expand Down
166 changes: 110 additions & 56 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,46 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
// Obtain a single options-string to pass to the pass(-pipeline) from options
// passed in as a dictionary of keys mapping to values which are either
// attributes or param-operands pointing to attributes.
OperandRange dynamicOptions = getDynamicOptions();

std::string options;
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.

OperandRange dynamicOptions = getDynamicOptions();
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
if (idx > 0)
optionsStream << " "; // Interleave options separator.
optionsStream << namedAttribute.getName().str(); // Append the key.
optionsStream << "="; // And the key-value separator.

Attribute valueAttrToAppend;
if (auto paramOperandIndex =
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
// The corresponding value attribute is passed in via a param.
// A helper to convert an option's attribute value into a corresponding
// string representation, with the ability to obtain the attr(s) from a param.
std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
// The corresponding value attribute(s) is/are passed in via a param.
// Obtain the param-operand via its specified index.
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
assert(dynamicOptionIdx < dynamicOptions.size() &&
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
"the number of ParamOperandAttrs in the options DictionaryAttr"
"should be the same as the number of options passed as params");
ArrayRef<Attribute> dynamicOption =
ArrayRef<Attribute> attrsAssociatedToParam =
state.getParams(dynamicOptions[dynamicOptionIdx]);
if (dynamicOption.size() != 1)
return emitSilenceableError()
<< "options passed as a param must have "
"a single value associated, param "
<< dynamicOptionIdx << " associates " << dynamicOption.size();
valueAttrToAppend = dynamicOption[0];
} else {
// Value is a static attribute.
valueAttrToAppend = namedAttribute.getValue();
}

// Append string representation of value attribute.
if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
// Recursive so as to append all attrs associated to the param.
llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
",");
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
// Recursive so as to append all nested attrs of the array.
llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
} else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
// Convert to unquoted string.
optionsStream << strAttr.getValue().str();
} else {
valueAttrToAppend.print(optionsStream, /*elideType=*/true);
// For all other attributes, ask the attr to print itself (without type).
valueAttr.print(optionsStream, /*elideType=*/true);
}
}
};

// Convert the options DictionaryAttr into a single string.
llvm::interleave(
getOptions(), optionsStream,
[&](auto namedAttribute) {
optionsStream << namedAttribute.getName().str(); // Append the key.
optionsStream << "="; // And the key-value separator.
appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
},
" ");
optionsStream.flush();

// Get pass or pass pipeline from registry.
Expand Down Expand Up @@ -878,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
SmallVector<NamedAttribute> keyValuePairs;

size_t dynamicOptionsIdx = 0;
auto parseKeyValuePair = [&]() -> ParseResult {
// Parse items of the form `key = value` where `key` is a bare identifier or
// a string and `value` is either an attribute or an operand.

std::string key;
Attribute valueAttr;
if (parser.parseOptionalKeywordOrString(&key))
return parser.emitError(parser.getCurrentLocation())
<< "expected key to either be an identifier or a string";
if (key.empty())
return failure();
// Helper for allowing parsing of option values which can be of the form:
// - a normal attribute
// - an operand (which would be converted to an attr referring to the operand)
// - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
std::function<ParseResult(Attribute &)> parseValue =
[&](Attribute &valueAttr) -> ParseResult {
// Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
if (succeeded(parser.parseOptionalLSquare())) {
SmallVector<Attribute> attrs;

if (parser.parseEqual())
return parser.emitError(parser.getCurrentLocation())
<< "expected '=' after key in key-value pair";
// Recursively parse the array's elements, which might be operands.
if (parser.parseCommaSeparatedList(
AsmParser::Delimiter::None,
[&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
" in options dictionary") ||
parser.parseRSquare())
return failure(); // NB: Attempted parse should've output error message.

valueAttr = ArrayAttr::get(parser.getContext(), attrs);

return success();
}

// Parse the value, which can be either an attribute or an operand.
OptionalParseResult parsedValueAttr =
Expand All @@ -903,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser::UnresolvedOperand operand;
ParseResult parsedOperand = parser.parseOperand(operand);
if (failed(parsedOperand))
return parser.emitError(parser.getCurrentLocation())
<< "expected a valid attribute or operand as value associated "
<< "to key '" << key << "'";
return failure(); // NB: Attempted parse should've output error message.
// To make use of the operand, we need to store it in the options dict.
// As SSA-values cannot occur in attributes, what we do instead is store
// an attribute in its place that contains the index of the param-operand,
Expand All @@ -924,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
<< "in the generic print format";
}

return success();
};

// Helper for `key = value`-pair parsing where `key` is a bare identifier or a
// string and `value` looks like either an attribute or an operand-in-an-attr.
std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
std::string key;
Attribute valueAttr;

if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
return parser.emitError(parser.getCurrentLocation())
<< "expected key to either be an identifier or a string";

if (failed(parser.parseEqual()))
return parser.emitError(parser.getCurrentLocation())
<< "expected '=' after key in key-value pair";

if (failed(parseValue(valueAttr)))
return parser.emitError(parser.getCurrentLocation())
<< "expected a valid attribute or operand as value associated "
<< "to key '" << key << "'";

keyValuePairs.push_back(NamedAttribute(key, valueAttr));

return success();
};

Expand All @@ -951,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
if (options.empty())
return;

printer << "{";
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
printer << namedAttribute.getName() << " = ";
Attribute value = namedAttribute.getValue();
if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
if (auto paramOperandAttr =
dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
// Resolve index of param-operand to its actual SSA-value and print that.
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
printer.printOperand(
dynamicOptions[paramOperandAttr.getIndex().getInt()]);
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
// This case is so that ArrayAttr-contained operands are pretty-printed.
printer << "[";
llvm::interleaveComma(arrayAttr, printer, printOptionValue);
printer << "]";
} else {
printer.printAttribute(value);
printer.printAttribute(valueAttr);
}
};

printer << "{";
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
printer << namedAttribute.getName();
printer << " = ";
printOptionValue(namedAttribute.getValue());
});
printer << "}";
}
Expand All @@ -970,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
// and references to dynamic options in the options dictionary.

auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
for (NamedAttribute namedAttr : getOptions())
if (auto paramOperand =
dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {

// Helper for option values to mark seen operands as having been seen (once).
std::function<LogicalResult(Attribute)> checkOptionValue =
[&](Attribute valueAttr) -> LogicalResult {
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
return emitOpError()
Expand All @@ -983,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
return emitOpError() << "dynamic option index " << dynamicOptionIdx
<< " is already used in options";
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
// Recurse into ArrayAttrs as they may contain references to operands.
for (auto eltAttr : arrayAttr)
if (failed(checkOptionValue(eltAttr)))
return failure();
}
return success();
};

for (NamedAttribute namedAttr : getOptions())
if (failed(checkOptionValue(namedAttr.getValue())))
return failure();

// All dynamicOptions-params seen in the dict will have been set to null.
for (Value dynamicOption : dynamicOptions)
if (dynamicOption)
return emitOpError() << "a param operand does not have a corresponding "
Expand Down
41 changes: 21 additions & 20 deletions mlir/python/mlir/dialects/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)


OptionValueTypes = Union[
Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
]


@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
Expand All @@ -227,12 +232,7 @@ def __init__(
target: Union[Operation, Value, OpView],
pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
loc=None,
ip=None,
):
Expand All @@ -243,26 +243,32 @@ def __init__(
context = (loc and loc.context) or Context.current

cur_param_operand_idx = 0
for key, value in options.items() if options is not None else {}:
if isinstance(key, StringAttr):
key = key.value

def option_value_to_attr(value):
nonlocal cur_param_operand_idx
if isinstance(value, (Value, Operation, OpView)):
dynamic_options.append(_get_op_result_or_value(value))
options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
cur_param_operand_idx += 1
return ParamOperandAttr(cur_param_operand_idx - 1, context)
elif isinstance(value, Attribute):
options_dict[key] = value
return value
# The following cases auto-convert Python values to attributes.
elif isinstance(value, bool):
options_dict[key] = BoolAttr.get(value)
return BoolAttr.get(value)
elif isinstance(value, int):
default_int_type = IntegerType.get_signless(64, context)
options_dict[key] = IntegerAttr.get(default_int_type, value)
return IntegerAttr.get(default_int_type, value)
elif isinstance(value, str):
options_dict[key] = StringAttr.get(value)
return StringAttr.get(value)
elif isinstance(value, Sequence):
return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
else:
raise TypeError(f"Unsupported option type: {type(value)}")

for key, value in options.items() if options is not None else {}:
if isinstance(key, StringAttr):
key = key.value
options_dict[key] = option_value_to_attr(value)
super().__init__(
result,
_get_op_result_or_value(target),
Expand All @@ -279,12 +285,7 @@ def apply_registered_pass(
target: Union[Operation, Value, OpView],
pass_name: Union[str, StringAttr],
*,
options: Optional[
Dict[
Union[str, StringAttr],
Union[Attribute, Value, Operation, OpView, str, int, bool],
]
] = None,
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
loc=None,
ip=None,
) -> Value:
Expand Down
Loading
Loading