diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 2013d3623711b..b3c658821c74a 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -570,6 +570,19 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { if (ComplexType complexTy = dyn_cast(eltType)) { eltType = complexTy.getElementType(); isComplex = true; + // Complex types have 2 elements. + if (shape.empty() && storage.size() != 2) { + p.emitError(loc) << "parsed " << storage.size() << " elements, but type (" + << complexTy << ") expected 2 elements"; + return nullptr; + } + if (!shape.empty() && + storage.size() != static_cast(type.getNumElements()) * 2) { + p.emitError(loc) << "parsed " << storage.size() << " elements, but type (" + << type << ") expected " << type.getNumElements() * 2 + << " elements"; + return nullptr; + } } // Handle integer and index types. diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir index 10988be91d84a..83aa0b3525f3f 100644 --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -63,6 +63,21 @@ func.func @elementsattr_toolarge1() -> () { // ----- +// expected-error@+1 {{parsed 1 elements, but type ('complex') expected 2 elements}} +#attr = dense<0> : tensor<2xcomplex> + +// ----- + +// expected-error@+1 {{parsed 2 elements, but type ('tensor<2xcomplex>') expected 4 elements}} +#attr = dense<[0, 1]> : tensor<2xcomplex> + +// ----- + +// expected-error@+1 {{parsed 3 elements, but type ('tensor<2xcomplex>') expected 4 elements}} +#attr = dense<[0, (0, 1)]> : tensor<2xcomplex> + +// ----- + func.func @elementsattr_toolarge2() -> () { "foo"(){bar = dense<[-777]> : tensor<1xi8>} : () -> () // expected-error {{integer constant out of range}} }