@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
28
28
from libc.stdint cimport uint32_t
29
29
30
30
from dpctl._backend cimport ( # noqa: E211, E402;
31
+ DPCTLBuildOptionList_Append,
32
+ DPCTLBuildOptionList_Create,
33
+ DPCTLBuildOptionList_Delete,
34
+ DPCTLBuildOptionListRef,
31
35
DPCTLKernel_Copy,
32
36
DPCTLKernel_Delete,
33
37
DPCTLKernel_GetCompileNumSubGroups,
@@ -38,16 +42,31 @@ from dpctl._backend cimport ( # noqa: E211, E402;
38
42
DPCTLKernel_GetPreferredWorkGroupSizeMultiple,
39
43
DPCTLKernel_GetPrivateMemSize,
40
44
DPCTLKernel_GetWorkGroupSize,
45
+ DPCTLKernelBuildLog_Create,
46
+ DPCTLKernelBuildLog_Delete,
47
+ DPCTLKernelBuildLog_Get,
48
+ DPCTLKernelBuildLogRef,
41
49
DPCTLKernelBundle_Copy,
42
50
DPCTLKernelBundle_CreateFromOCLSource,
43
51
DPCTLKernelBundle_CreateFromSpirv,
52
+ DPCTLKernelBundle_CreateFromSYCLSource,
44
53
DPCTLKernelBundle_Delete,
45
54
DPCTLKernelBundle_GetKernel,
55
+ DPCTLKernelBundle_GetSyclKernel,
46
56
DPCTLKernelBundle_HasKernel,
57
+ DPCTLKernelBundle_HasSyclKernel,
58
+ DPCTLKernelNameList_Append,
59
+ DPCTLKernelNameList_Create,
60
+ DPCTLKernelNameList_Delete,
61
+ DPCTLKernelNameListRef,
47
62
DPCTLSyclContextRef,
48
63
DPCTLSyclDeviceRef,
49
64
DPCTLSyclKernelBundleRef,
50
65
DPCTLSyclKernelRef,
66
+ DPCTLVirtualHeaderList_Append,
67
+ DPCTLVirtualHeaderList_Create,
68
+ DPCTLVirtualHeaderList_Delete,
69
+ DPCTLVirtualHeaderListRef,
51
70
)
52
71
53
72
__all__ = [
@@ -196,9 +215,11 @@ cdef class SyclProgram:
196
215
"""
197
216
198
217
@staticmethod
199
- cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
218
+ cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef,
219
+ bint is_sycl_source):
200
220
cdef SyclProgram ret = SyclProgram.__new__ (SyclProgram)
201
221
ret._program_ref = KBRef
222
+ ret._is_sycl_source = is_sycl_source
202
223
return ret
203
224
204
225
def __dealloc__ (self ):
@@ -209,13 +230,19 @@ cdef class SyclProgram:
209
230
210
231
cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
211
232
name = kernel_name.encode(" utf8" )
233
+ if self ._is_sycl_source:
234
+ return SyclKernel._create(
235
+ DPCTLKernelBundle_GetSyclKernel(self ._program_ref, name),
236
+ kernel_name)
212
237
return SyclKernel._create(
213
238
DPCTLKernelBundle_GetKernel(self ._program_ref, name),
214
239
kernel_name
215
240
)
216
241
217
242
def has_sycl_kernel (self , str kernel_name ):
218
243
name = kernel_name.encode(" utf8" )
244
+ if self ._is_sycl_source:
245
+ return DPCTLKernelBundle_HasSyclKernel(self ._program_ref, name)
219
246
return DPCTLKernelBundle_HasKernel(self ._program_ref, name)
220
247
221
248
def addressof_ref (self ):
@@ -271,7 +298,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
271
298
if KBref is NULL :
272
299
raise SyclProgramCompilationError()
273
300
274
- return SyclProgram._create(KBref)
301
+ return SyclProgram._create(KBref, False )
275
302
276
303
277
304
cpdef create_program_from_spirv(SyclQueue q, const unsigned char [:] IL,
@@ -317,7 +344,120 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
317
344
if KBref is NULL :
318
345
raise SyclProgramCompilationError()
319
346
320
- return SyclProgram._create(KBref)
347
+ return SyclProgram._create(KBref, False )
348
+
349
+
350
+ cpdef create_program_from_sycl_source(SyclQueue q, unicode source,
351
+ list headers = None ,
352
+ list registered_names = None ,
353
+ list copts = None ):
354
+ """
355
+ Creates an executable SYCL kernel_bundle from SYCL source code.
356
+
357
+ This uses the DPC++ ``kernel_compiler`` extension to create a
358
+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
359
+ SYCL source code.
360
+
361
+ Parameters:
362
+ q (:class:`dpctl.SyclQueue`)
363
+ The :class:`dpctl.SyclQueue` for which the
364
+ :class:`.SyclProgram` is going to be built.
365
+ source (unicode)
366
+ SYCL source code string.
367
+ headers (list)
368
+ Optional list of virtual headers, where each entry in the list
369
+ needs to be a tuple of header name and header content. See the
370
+ documentation of the ``include_files`` property in the DPC++
371
+ ``kernel_compiler`` extension for more information.
372
+ Default: []
373
+ registered_names (list, optional)
374
+ Optional list of kernel names to register. See the
375
+ documentation of the ``registered_names`` property in the DPC++
376
+ ``kernel_compiler`` extension for more information.
377
+ Default: []
378
+ copts (list)
379
+ Optional list of compilation flags that will be used
380
+ when compiling the program. Default: ``""``.
381
+
382
+ Returns:
383
+ program (:class:`.SyclProgram`)
384
+ A :class:`.SyclProgram` object wrapping the
385
+ ``sycl::kernel_bundle<sycl::bundle_state::executable>``
386
+ returned by the C API.
387
+
388
+ Raises:
389
+ SyclProgramCompilationError
390
+ If a SYCL kernel bundle could not be created. The exception
391
+ message contains the build log for more details.
392
+ """
393
+ cdef DPCTLSyclKernelBundleRef KBref
394
+ cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
395
+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
396
+ cdef bytes bSrc = source.encode(" utf8" )
397
+ cdef const char * Src = < const char * > bSrc
398
+ cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
399
+ cdef bytes bOpt
400
+ cdef const char * sOpt
401
+ cdef bytes bName
402
+ cdef const char * sName
403
+ cdef bytes bContent
404
+ cdef const char * sContent
405
+ cdef const char * buildLogContent
406
+ for opt in copts:
407
+ if not isinstance (opt, unicode ):
408
+ DPCTLBuildOptionList_Delete(BuildOpts)
409
+ raise SyclProgramCompilationError()
410
+ bOpt = opt.encode(" utf8" )
411
+ sOpt = < const char * > bOpt
412
+ DPCTLBuildOptionList_Append(BuildOpts, sOpt)
413
+
414
+ cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
415
+ for name in registered_names:
416
+ if not isinstance (name, unicode ):
417
+ DPCTLBuildOptionList_Delete(BuildOpts)
418
+ DPCTLKernelNameList_Delete(KernelNames)
419
+ raise SyclProgramCompilationError()
420
+ bName = name.encode(" utf8" )
421
+ sName = < const char * > bName
422
+ DPCTLKernelNameList_Append(KernelNames, sName)
423
+
424
+ cdef DPCTLVirtualHeaderListRef VirtualHeaders
425
+ VirtualHeaders = DPCTLVirtualHeaderList_Create()
426
+
427
+ for name, content in headers:
428
+ if not isinstance (name, unicode ) or not isinstance (content, unicode ):
429
+ DPCTLBuildOptionList_Delete(BuildOpts)
430
+ DPCTLKernelNameList_Delete(KernelNames)
431
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
432
+ raise SyclProgramCompilationError()
433
+ bName = name.encode(" utf8" )
434
+ sName = < const char * > bName
435
+ bContent = content.encode(" utf8" )
436
+ sContent = < const char * > bContent
437
+ DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
438
+
439
+ cdef DPCTLKernelBuildLogRef BuildLog
440
+ BuildLog = DPCTLKernelBuildLog_Create()
441
+
442
+ KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
443
+ VirtualHeaders, KernelNames,
444
+ BuildOpts, BuildLog)
445
+
446
+ if KBref is NULL :
447
+ buildLogContent = DPCTLKernelBuildLog_Get(BuildLog)
448
+ buildLogStr = str (buildLogContent, " utf-8" )
449
+ DPCTLBuildOptionList_Delete(BuildOpts)
450
+ DPCTLKernelNameList_Delete(KernelNames)
451
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
452
+ DPCTLKernelBuildLog_Delete(BuildLog)
453
+ raise SyclProgramCompilationError(buildLogStr)
454
+
455
+ DPCTLBuildOptionList_Delete(BuildOpts)
456
+ DPCTLKernelNameList_Delete(KernelNames)
457
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
458
+ DPCTLKernelBuildLog_Delete(BuildLog)
459
+
460
+ return SyclProgram._create(KBref, True )
321
461
322
462
323
463
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(
@@ -336,4 +476,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
336
476
reference.
337
477
"""
338
478
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
339
- return SyclProgram._create(copied_KBRef)
479
+ return SyclProgram._create(copied_KBRef, False )
0 commit comments