Skip to content

Commit 56bd37f

Browse files
committed
Refactor separate can_access_peer methods into a single method
1 parent 321582f commit 56bd37f

File tree

2 files changed

+71
-51
lines changed

2 files changed

+71
-51
lines changed

dpctl/_sycl_device.pyx

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1855,7 +1855,7 @@ cdef class SyclDevice(_SyclDevice):
18551855
raise ValueError("Internal error: NULL device vector encountered")
18561856
return _get_devices(cDVRef)
18571857

1858-
def can_access_peer_access_supported(self, peer):
1858+
def can_access_peer(self, peer, value="access_supported"):
18591859
""" Returns ``True`` if this device (``self``) can enable peer access
18601860
to USM device memory on ``peer``, ``False`` otherwise.
18611861
@@ -1869,60 +1869,50 @@ cdef class SyclDevice(_SyclDevice):
18691869
peer (:class:`dpctl.SyclDevice`):
18701870
The :class:`dpctl.SyclDevice` instance to check for peer access
18711871
by this device.
1872+
value (str, optional):
1873+
Specifies the kind of peer access being queried
18721874
1873-
Returns:
1874-
bool:
1875-
``True`` if this device may access USM device memory on
1876-
``peer`` when peer access is enabled, otherwise ``False``.
1875+
- ``"access_supported"``
1876+
Returns ``True`` if it is possible for this device to
1877+
enable peer access to USM device memory on ``peer``.
18771878
1878-
Raises:
1879-
TypeError:
1880-
If ``peer`` is not :class:`dpctl.SyclDevice`.
1881-
"""
1882-
cdef SyclDevice p_dev
1879+
- ``"atomics_supported"``
1880+
Returns ``True`` if it is possible for this device to
1881+
concurrently access and atomically modify USM device
1882+
memory on ``peer`` when enabled.
18831883
1884-
if not isinstance(peer, SyclDevice):
1885-
raise TypeError(
1886-
"peer device must be a `dpctl.SyclDevice`, got "
1887-
f"{type(peer)}"
1888-
)
1889-
p_dev = <SyclDevice>peer
1890-
if _check_peer_access(self, p_dev):
1891-
return DPCTLDevice_CanAccessPeer(
1892-
self._device_ref,
1893-
p_dev.get_device_ref(),
1894-
_peer_access._access_supported
1895-
)
1896-
return False
1884+
If ``False`` is returned, these operations result in
1885+
undefined behavior.
18971886
1898-
def can_access_peer_atomics_supported(self, peer):
1899-
""" Returns ``True`` if this device (``self``) can concurrently access
1900-
and modify USM device memory on ``peer`` when peer access is enabled,
1901-
``False`` otherwise.
1887+
Note: atomics must have ``memory_scope::system`` when
1888+
modifying memory on a peer device.
19021889
1903-
If peer access is supported, it may be enabled by calling
1904-
:meth:`.enable_peer_access`.
1905-
1906-
For details, see
1907-
:oneapi_peer_access:`DPC++ peer access SYCL extension <>`.
1908-
1909-
Args:
1910-
peer (:class:`dpctl.SyclDevice`):
1911-
The :class:`dpctl.SyclDevice` instance to check for concurrent
1912-
peer access and modification by this device.
1890+
Default: ``"access_supported"``
19131891
19141892
Returns:
19151893
bool:
1916-
``True`` if this device may concurrently access and modify USM
1917-
device memory on ``peer`` when peer access is enabled,
1918-
otherwise ``False``.
1894+
``True`` if this device may access USM device memory on
1895+
``peer`` when peer access is enabled, otherwise ``False``.
19191896
19201897
Raises:
19211898
TypeError:
19221899
If ``peer`` is not :class:`dpctl.SyclDevice`.
19231900
"""
19241901
cdef SyclDevice p_dev
19251902

1903+
if not isinstance(value, str):
1904+
raise TypeError(
1905+
f"Expected `value` to be of type str, got {type(value)}"
1906+
)
1907+
if value == "access_supported":
1908+
access_type = _peer_access._access_supported
1909+
elif value == "atomics_supported":
1910+
access_type = _peer_access._atomics_supported
1911+
else:
1912+
raise ValueError(
1913+
"`value` must be 'access_supported' or 'atomics_supported', "
1914+
f"got {value}"
1915+
)
19261916
if not isinstance(peer, SyclDevice):
19271917
raise TypeError(
19281918
"peer device must be a `dpctl.SyclDevice`, got "
@@ -1933,7 +1923,7 @@ cdef class SyclDevice(_SyclDevice):
19331923
return DPCTLDevice_CanAccessPeer(
19341924
self._device_ref,
19351925
p_dev.get_device_ref(),
1936-
_peer_access._atomics_supported
1926+
access_type
19371927
)
19381928
return False
19391929

@@ -2002,11 +1992,10 @@ cdef class SyclDevice(_SyclDevice):
20021992
)
20031993
p_dev = <SyclDevice>peer
20041994
_raise_invalid_peer_access(self, p_dev)
2005-
if _check_peer_access(self, p_dev):
2006-
DPCTLDevice_DisablePeerAccess(
2007-
self._device_ref,
2008-
p_dev.get_device_ref()
2009-
)
1995+
DPCTLDevice_DisablePeerAccess(
1996+
self._device_ref,
1997+
p_dev.get_device_ref()
1998+
)
20101999
return
20112000

20122001
@property

dpctl/tests/test_sycl_device.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,10 @@ def test_can_access_peer(platform_name):
360360
)
361361
dev0 = devices[0]
362362
dev1 = devices[1]
363-
assert isinstance(dev0.can_access_peer_access_supported(dev1), bool)
364-
assert isinstance(dev0.can_access_peer_atomics_supported(dev1), bool)
363+
assert isinstance(dev0.can_access_peer(dev1), bool)
364+
assert isinstance(
365+
dev0.can_access_peer(dev1, value="atomics_supported"), bool
366+
)
365367

366368

367369
@pytest.mark.parametrize("platform_name", ["level_zero", "cuda", "hip"])
@@ -381,7 +383,7 @@ def test_enable_disable_peer_access(platform_name):
381383
)
382384
dev0 = devices[0]
383385
dev1 = devices[1]
384-
if dev0.can_access_peer_access_supported(dev1):
386+
if dev0.can_access_peer(dev1):
385387
dev0.enable_peer_access(dev1)
386388
dev0.disable_peer_access(dev1)
387389
else:
@@ -393,8 +395,7 @@ def test_enable_disable_peer_access(platform_name):
393395
@pytest.mark.parametrize(
394396
"method",
395397
[
396-
"can_access_peer_access_supported",
397-
"can_access_peer_atomics_supported",
398+
"can_access_peer",
398399
"enable_peer_access",
399400
"disable_peer_access",
400401
],
@@ -427,3 +428,33 @@ def test_peer_access_to_self(platform_name):
427428
dev.enable_peer_access(dev)
428429
with pytest.raises(ValueError):
429430
dev.disable_peer_access(dev)
431+
432+
433+
def test_peer_access_value_keyword_validation():
434+
"""
435+
Validate behavior of `can_access_peer` for invalid `value` keyword.
436+
"""
437+
# we pick an arbitrary platform that supports peer access
438+
platforms = dpctl.get_platforms()
439+
peer_access_backends = [
440+
dpctl.backend_type.cuda,
441+
dpctl.backend_type.hip,
442+
dpctl.backend_type.hip,
443+
]
444+
devs = None
445+
for p in platforms:
446+
if p.backend in peer_access_backends:
447+
p_devs = p.get_devices()
448+
if len(p_devs) >= 2:
449+
devs = p_devs
450+
break
451+
if devs is None:
452+
pytest.skip("No platform available with enough devices")
453+
dev0 = devs[0]
454+
dev1 = devs[1]
455+
bad_type = 2
456+
with pytest.raises(TypeError):
457+
dev0.can_access_peer(dev1, value=bad_type)
458+
bad_value = "wrong"
459+
with pytest.raises(ValueError):
460+
dev0.can_access_peer(dev1, value=bad_value)

0 commit comments

Comments
 (0)