diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 6d62f035a9674..8c00bc68a0779 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -832,8 +832,8 @@ class KernelObjVisitor { // Implements the 'for-each-visitor' pattern. template - void VisitElement(CXXRecordDecl *Owner, FieldDecl *ArrayField, - QualType ElementTy, Handlers &... handlers) { + void VisitElementImpl(CXXRecordDecl *Owner, FieldDecl *ArrayField, + QualType ElementTy, Handlers &... handlers) { if (Util::isSyclAccessorType(ElementTy)) KF_FOR_EACH(handleSyclAccessorType, ArrayField, ElementTy); else if (Util::isSyclStreamType(ElementTy)) @@ -854,6 +854,16 @@ class KernelObjVisitor { KF_FOR_EACH(handleScalarType, ArrayField, ElementTy); } + template + void VisitFirstElement(CXXRecordDecl *Owner, FieldDecl *ArrayField, + QualType ElementTy, Handlers &... handlers) { + VisitElementImpl(Owner, ArrayField, ElementTy, handlers...); + } + + template + void VisitNthElement(CXXRecordDecl *Owner, FieldDecl *ArrayField, + QualType ElementTy, Handlers &... handlers); + template void VisitArrayElements(FieldDecl *FD, QualType FieldTy, Handlers &... handlers) { @@ -863,17 +873,35 @@ class KernelObjVisitor { QualType ET = CAT->getElementType(); int64_t ElemCount = CAT->getSize().getSExtValue(); std::initializer_list{(handlers.enterArray(), 0)...}; - for (int64_t Count = 0; Count < ElemCount; Count++) { - VisitElement(nullptr, FD, ET, handlers...); + + assert(ElemCount > 0 && "SYCL prohibits 0 sized arrays"); + VisitFirstElement(nullptr, FD, ET, handlers...); + (void)std::initializer_list{(handlers.nextElement(ET), 0)...}; + + for (int64_t Count = 1; Count < ElemCount; Count++) { + VisitNthElement(nullptr, FD, ET, handlers...); (void)std::initializer_list{(handlers.nextElement(ET), 0)...}; } + (void)std::initializer_list{ (handlers.leaveArray(FD, ET, ElemCount), 0)...}; } + // Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter + // the Wrapper structure that we're currently visiting. Owner is the parent + // type (which doesn't exist in cases where it is a FieldDecl in the + // 'root'), and Wrapper is the current struct being unwrapped. template void VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, - const CXXRecordDecl *Wrapper, Handlers &... handlers); + const CXXRecordDecl *Wrapper, Handlers &... handlers) { + (void)std::initializer_list{ + (handlers.enterStruct(Owner, Parent), 0)...}; + VisitRecordHelper(Wrapper, Wrapper->bases(), handlers...); + VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...); + (void)std::initializer_list{ + (handlers.leaveStruct(Owner, Parent), 0)...}; + } + template void VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, const CXXRecordDecl *Wrapper, Handlers &... handlers); @@ -988,25 +1016,13 @@ class KernelObjVisitor { } #undef KF_FOR_EACH }; -// Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter -// the Wrapper structure that we're currently visiting. Owner is the parent -// type (which doesn't exist in cases where it is a FieldDecl in the -// 'root'), and Wrapper is the current struct being unwrapped. -template -void KernelObjVisitor::VisitRecord(const CXXRecordDecl *Owner, ParentTy &Parent, - const CXXRecordDecl *Wrapper, - Handlers &... handlers) { - (void)std::initializer_list{(handlers.enterStruct(Owner, Parent), 0)...}; - VisitRecordHelper(Wrapper, Wrapper->bases(), handlers...); - VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...); - (void)std::initializer_list{(handlers.leaveStruct(Owner, Parent), 0)...}; -} // A base type that the SYCL OpenCL Kernel construction task uses to implement // individual tasks. class SyclKernelFieldHandlerBase { public: static constexpr const bool VisitUnionBody = false; + static constexpr const bool VisitNthElement = true; // Mark these virtual so that we can use override in the implementer classes, // despite virtual dispatch never being used. @@ -1115,6 +1131,21 @@ void KernelObjVisitor::VisitUnion(const CXXRecordDecl *Owner, ParentTy &Parent, HandlerFilter(handlers).Handler...); } +template +void KernelObjVisitor::VisitNthElement(CXXRecordDecl *Owner, + FieldDecl *ArrayField, + QualType ElementTy, + Handlers &... handlers) { + // Don't continue descending if none of the handlers 'care'. This could be 'if + // constexpr' starting in C++17. Until then, we have to count on the + // optimizer to realize "if (false)" is a dead branch. + if (AnyTrue::Value) + VisitElementImpl( + Owner, ArrayField, ElementTy, + HandlerFilter(handlers) + .Handler...); +} + // A type to check the validity of all of the argument types. class SyclKernelFieldChecker : public SyclKernelFieldHandler { bool IsInvalid = false; @@ -1237,6 +1268,7 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler { public: SyclKernelFieldChecker(Sema &S) : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {} + static constexpr const bool VisitNthElement = false; bool isValid() { return !IsInvalid; } bool handleReferenceType(FieldDecl *FD, QualType FieldTy) final { @@ -1285,6 +1317,7 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler { : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {} bool isValid() { return !IsInvalid; } static constexpr const bool VisitUnionBody = true; + static constexpr const bool VisitNthElement = false; bool checkType(SourceLocation Loc, QualType Ty) { if (UnionCount) {