Skip to content

Commit 6620ca4

Browse files
committed
Update Python-bindings
1 parent fb370bd commit 6620ca4

File tree

3 files changed

+65
-37
lines changed

3 files changed

+65
-37
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def __init__(
219219
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
220220

221221

222+
OptionValueTypes = Union[
223+
Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
224+
]
225+
226+
222227
@_ods_cext.register_operation(_Dialect, replace=True)
223228
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
224229
def __init__(
@@ -227,12 +232,7 @@ def __init__(
227232
target: Union[Operation, Value, OpView],
228233
pass_name: Union[str, StringAttr],
229234
*,
230-
options: Optional[
231-
Dict[
232-
Union[str, StringAttr],
233-
Union[Attribute, Value, Operation, OpView, str, int, bool],
234-
]
235-
] = None,
235+
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
236236
loc=None,
237237
ip=None,
238238
):
@@ -243,26 +243,32 @@ def __init__(
243243
context = (loc and loc.context) or Context.current
244244

245245
cur_param_operand_idx = 0
246-
for key, value in options.items() if options is not None else {}:
247-
if isinstance(key, StringAttr):
248-
key = key.value
249246

247+
def option_value_to_attr(value):
248+
nonlocal cur_param_operand_idx
250249
if isinstance(value, (Value, Operation, OpView)):
251250
dynamic_options.append(_get_op_result_or_value(value))
252-
options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
253251
cur_param_operand_idx += 1
252+
return ParamOperandAttr(cur_param_operand_idx - 1, context)
254253
elif isinstance(value, Attribute):
255-
options_dict[key] = value
254+
return value
256255
# The following cases auto-convert Python values to attributes.
257256
elif isinstance(value, bool):
258-
options_dict[key] = BoolAttr.get(value)
257+
return BoolAttr.get(value)
259258
elif isinstance(value, int):
260259
default_int_type = IntegerType.get_signless(64, context)
261-
options_dict[key] = IntegerAttr.get(default_int_type, value)
260+
return IntegerAttr.get(default_int_type, value)
262261
elif isinstance(value, str):
263-
options_dict[key] = StringAttr.get(value)
262+
return StringAttr.get(value)
263+
elif isinstance(value, Sequence):
264+
return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
264265
else:
265266
raise TypeError(f"Unsupported option type: {type(value)}")
267+
268+
for key, value in options.items() if options is not None else {}:
269+
if isinstance(key, StringAttr):
270+
key = key.value
271+
options_dict[key] = option_value_to_attr(value)
266272
super().__init__(
267273
result,
268274
_get_op_result_or_value(target),
@@ -279,12 +285,7 @@ def apply_registered_pass(
279285
target: Union[Operation, Value, OpView],
280286
pass_name: Union[str, StringAttr],
281287
*,
282-
options: Optional[
283-
Dict[
284-
Union[str, StringAttr],
285-
Union[Attribute, Value, Operation, OpView, str, int, bool],
286-
]
287-
] = None,
288+
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
288289
loc=None,
289290
ip=None,
290291
) -> Value:

mlir/test/Dialect/Transform/test-pass-application.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ module attributes {transform.with_named_sequence} {
284284
}
285285
}
286286

287-
288287
// -----
289288

290289
func.func @invalid_options_as_str() {

mlir/test/python/dialects/transform.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -256,30 +256,45 @@ def testReplicateOp(module: Module):
256256
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
257257

258258

259+
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
259260
@run
260261
def testApplyRegisteredPassOp(module: Module):
262+
# CHECK: transform.sequence
261263
sequence = transform.SequenceOp(
262264
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
263265
)
264266
with InsertionPoint(sequence.body):
267+
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
265268
mod = transform.ApplyRegisteredPassOp(
266269
transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
267270
)
271+
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
272+
# CHECK-SAME: with options = {"top-down" = false}
273+
# CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
268274
mod = transform.ApplyRegisteredPassOp(
269275
transform.AnyOpType.get(),
270276
mod.result,
271277
"canonicalize",
272278
options={"top-down": BoolAttr.get(False)},
273279
)
280+
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
274281
max_iter = transform.param_constant(
275282
transform.AnyParamType.get(),
276283
IntegerAttr.get(IntegerType.get_signless(64), 10),
277284
)
285+
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
278286
max_rewrites = transform.param_constant(
279287
transform.AnyParamType.get(),
280288
IntegerAttr.get(IntegerType.get_signless(64), 1),
281289
)
282-
transform.apply_registered_pass(
290+
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
291+
# NB: MLIR has sorted the dict lexicographically by key:
292+
# CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
293+
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
294+
# CHECK-SAME: "test-convergence" = true,
295+
# CHECK-SAME: "top-down" = false}
296+
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
297+
mod = transform.apply_registered_pass(
283298
transform.AnyOpType.get(),
284299
mod,
285300
"canonicalize",
@@ -290,19 +305,32 @@ def testApplyRegisteredPassOp(module: Module):
290305
"max-rewrites": max_rewrites,
291306
},
292307
)
308+
# CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
309+
# CHECK-SAME: with options = {"exclude" = ["a", "b"]}
310+
# CHECK-SAME: to %{{.*}} : (!transform.any_op) -> !transform.any_op
311+
mod = transform.apply_registered_pass(
312+
transform.AnyOpType.get(),
313+
mod,
314+
"symbol-privatize",
315+
options={ "exclude": ("a", "b") },
316+
)
317+
# CHECK: %[[SYMBOL_A:.+]] = transform.param.constant
318+
symbol_a = transform.param_constant(
319+
transform.AnyParamType.get(),
320+
StringAttr.get("a")
321+
)
322+
# CHECK: %[[SYMBOL_B:.+]] = transform.param.constant
323+
symbol_b = transform.param_constant(
324+
transform.AnyParamType.get(),
325+
StringAttr.get("b")
326+
)
327+
# CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
328+
# CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
329+
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
330+
mod = transform.apply_registered_pass(
331+
transform.AnyOpType.get(),
332+
mod,
333+
"symbol-privatize",
334+
options={ "exclude": (symbol_a, symbol_b) },
335+
)
293336
transform.YieldOp()
294-
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
295-
# CHECK: transform.sequence
296-
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
297-
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
298-
# CHECK-SAME: with options = {"top-down" = false}
299-
# CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
300-
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
301-
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
302-
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
303-
# NB: MLIR has sorted the dict lexicographically by key:
304-
# CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
305-
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
306-
# CHECK-SAME: "test-convergence" = true,
307-
# CHECK-SAME: "top-down" = false}
308-
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op

0 commit comments

Comments
 (0)