Skip to content

Commit f64f835

Browse files
[SYCL] Fix spec constants support in integration header (#2896)
Added emissions of forward-declarations of types used as names for specialization constants, which allows to use types declared within namespaces as specialization constant names.
1 parent e71eed0 commit f64f835

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3757,8 +3757,14 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37573757
PrintingPolicy Policy(LO);
37583758
Policy.SuppressTypedefs = true;
37593759
Policy.SuppressUnwrittenScope = true;
3760+
SYCLFwdDeclEmitter FwdDeclEmitter(O, S.getLangOpts());
37603761

37613762
if (SpecConsts.size() > 0) {
3763+
O << "// Forward declarations of templated spec constant types:\n";
3764+
for (const auto &SC : SpecConsts)
3765+
FwdDeclEmitter.Visit(SC.first);
3766+
O << "\n";
3767+
37623768
// Remove duplicates.
37633769
std::sort(SpecConsts.begin(), SpecConsts.end(),
37643770
[](const SpecConstID &SC1, const SpecConstID &SC2) {
@@ -3772,10 +3778,12 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37723778
// Here can do faster comparison of types.
37733779
return SC1.first == SC2.first;
37743780
});
3781+
37753782
O << "// Specialization constants IDs:\n";
37763783
for (const auto &P : llvm::make_range(SpecConsts.begin(), End)) {
37773784
O << "template <> struct sycl::detail::SpecConstantInfo<";
3778-
O << P.first.getAsString(Policy);
3785+
SYCLKernelNameTypePrinter Printer(O, Policy);
3786+
Printer.Visit(P.first);
37793787
O << "> {\n";
37803788
O << " static constexpr const char* getName() {\n";
37813789
O << " return \"" << P.second << "\";\n";
@@ -3786,8 +3794,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37863794

37873795
if (!UnnamedLambdaSupport) {
37883796
O << "// Forward declarations of templated kernel function types:\n";
3789-
3790-
SYCLFwdDeclEmitter FwdDeclEmitter(O, S.getLangOpts());
37913797
for (const KernelDesc &K : KernelDescs)
37923798
FwdDeclEmitter.Visit(K.NameType);
37933799
}

clang/test/CodeGenSYCL/int_header_spec_const.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class MyUInt32Const;
1818
class MyFloatConst;
1919
class MyDoubleConst;
2020

21+
namespace test {
22+
class MySpecConstantWithinANamespace;
23+
};
24+
2125
int main() {
2226
// Create specialization constants.
2327
cl::sycl::ONEAPI::experimental::spec_constant<bool, MyBoolConst> i1(false);
@@ -32,13 +36,31 @@ int main() {
3236
cl::sycl::ONEAPI::experimental::spec_constant<unsigned int, MyUInt32Const> ui32(0);
3337
cl::sycl::ONEAPI::experimental::spec_constant<float, MyFloatConst> f32(0);
3438
cl::sycl::ONEAPI::experimental::spec_constant<double, MyDoubleConst> f64(0);
39+
// Kernel name can be used as a spec constant name
40+
cl::sycl::ONEAPI::experimental::spec_constant<int, SpecializedKernel> spec1(0);
41+
// Spec constant name can be declared within a namespace
42+
cl::sycl::ONEAPI::experimental::spec_constant<int, test::MySpecConstantWithinANamespace> spec2(0);
3543

3644
double val;
3745
double *ptr = &val; // to avoid "unused" warnings
3846

47+
// CHECK: // Forward declarations of templated spec constant types:
48+
// CHECK: class MyInt8Const;
49+
// CHECK: class MyUInt8Const;
50+
// CHECK: class MyInt16Const;
51+
// CHECK: class MyUInt16Const;
52+
// CHECK: class MyInt32Const;
53+
// CHECK: class MyUInt32Const;
54+
// CHECK: class MyFloatConst;
55+
// CHECK: class MyDoubleConst;
56+
// CHECK: class SpecializedKernel;
57+
// CHECK: namespace test {
58+
// CHECK: class MySpecConstantWithinANamespace;
59+
// CHECK: }
60+
3961
cl::sycl::kernel_single_task<SpecializedKernel>([=]() {
4062
*ptr = i1.get() +
41-
// CHECK-DAG: template <> struct sycl::detail::SpecConstantInfo<class MyBoolConst> {
63+
// CHECK-DAG: template <> struct sycl::detail::SpecConstantInfo<::MyBoolConst> {
4264
// CHECK-DAG-NEXT: static constexpr const char* getName() {
4365
// CHECK-DAG-NEXT: return "_ZTS11MyBoolConst";
4466
// CHECK-DAG-NEXT: }
@@ -58,7 +80,11 @@ int main() {
5880
// CHECK-DAG: return "_ZTS13MyUInt32Const";
5981
f32.get() +
6082
// CHECK-DAG: return "_ZTS12MyFloatConst";
61-
f64.get();
62-
// CHECK-DAG: return "_ZTS13MyDoubleConst";
83+
f64.get() +
84+
// CHECK-DAG: return "_ZTS13MyDoubleConst";
85+
spec1.get() +
86+
// CHECK-DAG: return "_ZTS17SpecializedKernel"
87+
spec2.get();
88+
// CHECK-DAG: return "_ZTSN4test30MySpecConstantWithinANamespaceE"
6389
});
6490
}

0 commit comments

Comments
 (0)