@@ -788,46 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
788
788
// Obtain a single options-string to pass to the pass(-pipeline) from options
789
789
// passed in as a dictionary of keys mapping to values which are either
790
790
// attributes or param-operands pointing to attributes.
791
+ OperandRange dynamicOptions = getDynamicOptions ();
791
792
792
793
std::string options;
793
794
llvm::raw_string_ostream optionsStream (options); // For "printing" attrs.
794
795
795
- OperandRange dynamicOptions = getDynamicOptions ();
796
- for (auto [idx, namedAttribute] : llvm::enumerate (getOptions ())) {
797
- if (idx > 0 )
798
- optionsStream << " " ; // Interleave options separator.
799
- optionsStream << namedAttribute.getName ().str (); // Append the key.
800
- optionsStream << " =" ; // And the key-value separator.
801
-
802
- Attribute valueAttrToAppend;
803
- if (auto paramOperandIndex =
804
- dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue ())) {
805
- // The corresponding value attribute is passed in via a param.
796
+ // A helper to convert an option's attribute value into a corresponding
797
+ // string representation, with the ability to obtain the attr(s) from a param.
798
+ std::function<void (Attribute)> appendValueAttr = [&](Attribute valueAttr) {
799
+ if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
800
+ // The corresponding value attribute(s) is/are passed in via a param.
806
801
// Obtain the param-operand via its specified index.
807
- size_t dynamicOptionIdx = paramOperandIndex .getIndex ().getInt ();
802
+ size_t dynamicOptionIdx = paramOperand .getIndex ().getInt ();
808
803
assert (dynamicOptionIdx < dynamicOptions.size () &&
809
- " number of dynamic option markers (UnitAttr) in options ArrayAttr "
804
+ " the number of ParamOperandAttrs in the options DictionaryAttr "
810
805
" should be the same as the number of options passed as params" );
811
- ArrayRef<Attribute> dynamicOption =
806
+ ArrayRef<Attribute> attrsAssociatedToParam =
812
807
state.getParams (dynamicOptions[dynamicOptionIdx]);
813
- if (dynamicOption.size () != 1 )
814
- return emitSilenceableError ()
815
- << " options passed as a param must have "
816
- " a single value associated, param "
817
- << dynamicOptionIdx << " associates " << dynamicOption.size ();
818
- valueAttrToAppend = dynamicOption[0 ];
819
- } else {
820
- // Value is a static attribute.
821
- valueAttrToAppend = namedAttribute.getValue ();
822
- }
823
-
824
- // Append string representation of value attribute.
825
- if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
808
+ // Recursive so as to append all attrs associated to the param.
809
+ llvm::interleave (attrsAssociatedToParam, optionsStream, appendValueAttr,
810
+ " ," );
811
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
812
+ // Recursive so as to append all nested attrs of the array.
813
+ llvm::interleave (arrayAttr, optionsStream, appendValueAttr, " ," );
814
+ } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
815
+ // Convert to unquoted string.
826
816
optionsStream << strAttr.getValue ().str ();
827
817
} else {
828
- valueAttrToAppend.print (optionsStream, /* elideType=*/ true );
818
+ // For all other attributes, ask the attr to print itself (without type).
819
+ valueAttr.print (optionsStream, /* elideType=*/ true );
829
820
}
830
- }
821
+ };
822
+
823
+ // Convert the options DictionaryAttr into a single string.
824
+ llvm::interleave (
825
+ getOptions (), optionsStream,
826
+ [&](auto namedAttribute) {
827
+ optionsStream << namedAttribute.getName ().str (); // Append the key.
828
+ optionsStream << " =" ; // And the key-value separator.
829
+ appendValueAttr (namedAttribute.getValue ()); // And the attr's str repr.
830
+ },
831
+ " " );
831
832
optionsStream.flush ();
832
833
833
834
// Get pass or pass pipeline from registry.
@@ -878,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
878
879
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
879
880
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
880
881
SmallVector<NamedAttribute> keyValuePairs;
881
-
882
882
size_t dynamicOptionsIdx = 0 ;
883
- auto parseKeyValuePair = [&]() -> ParseResult {
884
- // Parse items of the form `key = value` where `key` is a bare identifier or
885
- // a string and `value` is either an attribute or an operand.
886
883
887
- std::string key;
888
- Attribute valueAttr;
889
- if (parser.parseOptionalKeywordOrString (&key))
890
- return parser.emitError (parser.getCurrentLocation ())
891
- << " expected key to either be an identifier or a string" ;
892
- if (key.empty ())
893
- return failure ();
884
+ // Helper for allowing parsing of option values which can be of the form:
885
+ // - a normal attribute
886
+ // - an operand (which would be converted to an attr referring to the operand)
887
+ // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
888
+ std::function<ParseResult (Attribute &)> parseValue =
889
+ [&](Attribute &valueAttr) -> ParseResult {
890
+ // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
891
+ if (succeeded (parser.parseOptionalLSquare ())) {
892
+ SmallVector<Attribute> attrs;
894
893
895
- if (parser.parseEqual ())
896
- return parser.emitError (parser.getCurrentLocation ())
897
- << " expected '=' after key in key-value pair" ;
894
+ // Recursively parse the array's elements, which might be operands.
895
+ if (parser.parseCommaSeparatedList (
896
+ AsmParser::Delimiter::None,
897
+ [&]() -> ParseResult { return parseValue (attrs.emplace_back ()); },
898
+ " in options dictionary" ) ||
899
+ parser.parseRSquare ())
900
+ return failure (); // NB: Attempted parse should've output error message.
901
+
902
+ valueAttr = ArrayAttr::get (parser.getContext (), attrs);
903
+
904
+ return success ();
905
+ }
898
906
899
907
// Parse the value, which can be either an attribute or an operand.
900
908
OptionalParseResult parsedValueAttr =
@@ -903,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
903
911
OpAsmParser::UnresolvedOperand operand;
904
912
ParseResult parsedOperand = parser.parseOperand (operand);
905
913
if (failed (parsedOperand))
906
- return parser.emitError (parser.getCurrentLocation ())
907
- << " expected a valid attribute or operand as value associated "
908
- << " to key '" << key << " '" ;
914
+ return failure (); // NB: Attempted parse should've output error message.
909
915
// To make use of the operand, we need to store it in the options dict.
910
916
// As SSA-values cannot occur in attributes, what we do instead is store
911
917
// an attribute in its place that contains the index of the param-operand,
@@ -924,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
924
930
<< " in the generic print format" ;
925
931
}
926
932
933
+ return success ();
934
+ };
935
+
936
+ // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
937
+ // string and `value` looks like either an attribute or an operand-in-an-attr.
938
+ std::function<ParseResult ()> parseKeyValuePair = [&]() -> ParseResult {
939
+ std::string key;
940
+ Attribute valueAttr;
941
+
942
+ if (failed (parser.parseOptionalKeywordOrString (&key)) || key.empty ())
943
+ return parser.emitError (parser.getCurrentLocation ())
944
+ << " expected key to either be an identifier or a string" ;
945
+
946
+ if (failed (parser.parseEqual ()))
947
+ return parser.emitError (parser.getCurrentLocation ())
948
+ << " expected '=' after key in key-value pair" ;
949
+
950
+ if (failed (parseValue (valueAttr)))
951
+ return parser.emitError (parser.getCurrentLocation ())
952
+ << " expected a valid attribute or operand as value associated "
953
+ << " to key '" << key << " '" ;
954
+
927
955
keyValuePairs.push_back (NamedAttribute (key, valueAttr));
956
+
928
957
return success ();
929
958
};
930
959
@@ -951,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
951
980
if (options.empty ())
952
981
return ;
953
982
954
- printer << " {" ;
955
- llvm::interleaveComma (options, printer, [&](NamedAttribute namedAttribute) {
956
- printer << namedAttribute.getName () << " = " ;
957
- Attribute value = namedAttribute.getValue ();
958
- if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
983
+ std::function<void (Attribute)> printOptionValue = [&](Attribute valueAttr) {
984
+ if (auto paramOperandAttr =
985
+ dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
959
986
// Resolve index of param-operand to its actual SSA-value and print that.
960
- printer.printOperand (dynamicOptions[indexAttr.getIndex ().getInt ()]);
987
+ printer.printOperand (
988
+ dynamicOptions[paramOperandAttr.getIndex ().getInt ()]);
989
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990
+ // This case is so that ArrayAttr-contained operands are pretty-printed.
991
+ printer << " [" ;
992
+ llvm::interleaveComma (arrayAttr, printer, printOptionValue);
993
+ printer << " ]" ;
961
994
} else {
962
- printer.printAttribute (value );
995
+ printer.printAttribute (valueAttr );
963
996
}
997
+ };
998
+
999
+ printer << " {" ;
1000
+ llvm::interleaveComma (options, printer, [&](NamedAttribute namedAttribute) {
1001
+ printer << namedAttribute.getName ();
1002
+ printer << " = " ;
1003
+ printOptionValue (namedAttribute.getValue ());
964
1004
});
965
1005
printer << " }" ;
966
1006
}
@@ -970,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
970
1010
// and references to dynamic options in the options dictionary.
971
1011
972
1012
auto dynamicOptions = SmallVector<Value>(getDynamicOptions ());
973
- for (NamedAttribute namedAttr : getOptions ())
974
- if (auto paramOperand =
975
- dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue ())) {
1013
+
1014
+ // Helper for option values to mark seen operands as having been seen (once).
1015
+ std::function<LogicalResult (Attribute)> checkOptionValue =
1016
+ [&](Attribute valueAttr) -> LogicalResult {
1017
+ if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
976
1018
size_t dynamicOptionIdx = paramOperand.getIndex ().getInt ();
977
1019
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size ())
978
1020
return emitOpError ()
@@ -983,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
983
1025
return emitOpError () << " dynamic option index " << dynamicOptionIdx
984
1026
<< " is already used in options" ;
985
1027
dynamicOptions[dynamicOptionIdx] = nullptr ; // Mark this option as used.
1028
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1029
+ // Recurse into ArrayAttrs as they may contain references to operands.
1030
+ for (auto eltAttr : arrayAttr)
1031
+ if (failed (checkOptionValue (eltAttr)))
1032
+ return failure ();
986
1033
}
1034
+ return success ();
1035
+ };
1036
+
1037
+ for (NamedAttribute namedAttr : getOptions ())
1038
+ if (failed (checkOptionValue (namedAttr.getValue ())))
1039
+ return failure ();
987
1040
1041
+ // All dynamicOptions-params seen in the dict will have been set to null.
988
1042
for (Value dynamicOption : dynamicOptions)
989
1043
if (dynamicOption)
990
1044
return emitOpError () << " a param operand does not have a corresponding "
0 commit comments