diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 93a24dee29ad2..2474e88373e04 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -566,8 +566,10 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { if (ComplexType complexTy = dyn_cast(eltType)) { eltType = complexTy.getElementType(); isComplex = true; - // Complex types have 2 elements. - if (shape.empty() && storage.size() != 2) { + // Complex types have N*2 elements or complex splat. + // Empty shape may mean a splat or empty literal, only validate splats. + bool isSplat = shape.empty() && type.getNumElements() != 0; + if (isSplat && storage.size() != 2) { p.emitError(loc) << "parsed " << storage.size() << " elements, but type (" << complexTy << ") expected 2 elements"; return nullptr; diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index cace1fefa43d6..8b192ff11d573 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -730,6 +730,10 @@ func.func @densetensorattr() -> () { "complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor>} : () -> () // CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex> "complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex>} : () -> () + // CHECK: dense<> : tensor<0xcomplex> + "complex_attr"(){bar = dense<> : tensor<0xcomplex>} : () -> () + // CHECK: dense<> : tensor<2x0xcomplex> + "complex_attr"(){bar = dense<> : tensor<2x0xcomplex>} : () -> () return }