diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 9768d9ea7d..e3136b54dc 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -48,6 +48,7 @@ broadcast_to, can_cast, concat, + device_result_type, expand_dims, finfo, flip, @@ -137,4 +138,5 @@ "get_print_options", "set_print_options", "print_options", + "device_result_type", ] diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 094b959efe..0417d55808 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -436,7 +436,7 @@ def stack(arrays, axis=0): return res -def can_cast(from_, to, casting="safe"): +def can_cast(from_, to, casting="safe", device=None): """ can_cast(from: usm_ndarray or dtype, to: dtype) -> bool @@ -454,6 +454,25 @@ def can_cast(from_, to, casting="safe"): _supported_dtype([dtype_from, dtype_to]) + if device is not None: + if isinstance(device, (dpctl.SyclQueue, dpt.Device)): + device = device.sycl_device + if not isinstance(device, dpctl.SyclDevice): + raise TypeError(f"Expected sycl_device type, got {type(device)}.") + if ( + not device.has_aspect_fp16 + and dtype_to == dpt.float16 + or not device.has_aspect_fp64 + and (dtype_to == dpt.float64 or dtype_to == dpt.complex128) + ): + return False + if not device.has_aspect_fp64 and ( + dtype_to == dpt.complex64 + or dtype_to == dpt.float32 + and dtype_from is not complex + ): + return True + return np.can_cast(dtype_from, dtype_to, casting) @@ -475,6 +494,34 @@ def result_type(*arrays_and_dtypes): return np.result_type(*dtypes) +def device_result_type(device, *arrays_and_dtypes): + """ + device_result_type(device: sycl_device, arrays_and_dtypes: an arbitrary \ + number usm_ndarrays or dtypes) -> dtype + + Returns the dtype that results from applying the Type Promotion Rules to \ + the arguments on current device. + """ + dt = result_type(*arrays_and_dtypes) + + if device is not None: + if isinstance(device, (dpctl.SyclQueue, dpt.Device)): + device = device.sycl_device + if not isinstance(device, dpctl.SyclDevice): + raise TypeError(f"Expected sycl_device type, got {type(device)}.") + if ( + dt == dpt.float16 + and not device.has_aspect_fp16 + or dt == dpt.float64 + and not device.has_aspect_fp64 + ): + return dpt.float32 + if dt == dpt.complex128 and not device.has_aspect_fp64: + return dpt.complex64 + + return dt + + def iinfo(dtype): """ iinfo(dtype: integer data-type) -> iinfo_object