diff --git a/dpctl/dptensor/numpy_usm_shared.py b/dpctl/dptensor/numpy_usm_shared.py
index e15d537222..21355bb3de 100644
--- a/dpctl/dptensor/numpy_usm_shared.py
+++ b/dpctl/dptensor/numpy_usm_shared.py
@@ -78,6 +78,19 @@ def _get_usm_base(ary):
     return None
 
 
+def convert_ndarray_to_np_ndarray(x, require_ndarray=False):
+    if isinstance(x, ndarray):
+        return np.array(x, copy=False, subok=False)
+    elif isinstance(x, tuple):
+        return tuple(
+            convert_ndarray_to_np_ndarray(y, require_ndarray=require_ndarray) for y in x
+        )
+    elif require_ndarray:
+        raise TypeError
+    else:
+        return x
+
+
 class ndarray(np.ndarray):
     """
     numpy.ndarray subclass whose underlying memory buffer is allocated
@@ -234,7 +247,7 @@ def __array_finalize__(self, obj):
 
     # Convert to a NumPy ndarray.
     def as_ndarray(self):
-        return np.copy(np.ndarray(self.shape, self.dtype, self))
+        return np.array(self, copy=True, subok=False)
 
     def __array__(self):
         return self
@@ -267,23 +280,51 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
             # USM memory. However, if kwarg has numpy_usm_shared-typed out then
             # array_ufunc is called recursively so we cast out as regular
             # NumPy ndarray (having a USM data pointer).
-            if kwargs.get("out", None) is None:
+            out_arg = kwargs.get("out", None)
+            if out_arg is None:
                 # maybe copy?
                 # deal with multiple returned arrays, so kwargs['out'] can be tuple
                 res_type = np.result_type(*typing)
-                out = empty(inputs[0].shape, dtype=res_type)
-                out_as_np = np.ndarray(out.shape, out.dtype, out)
+                out_arg = empty(inputs[0].shape, dtype=res_type)
+                out_as_np = convert_ndarray_to_np_ndarray(out_arg)
                 kwargs["out"] = out_as_np
             else:
                 # If they manually gave numpy_usm_shared as out kwarg then we
                 # have to also cast as regular NumPy ndarray to avoid recursion.
-                if isinstance(kwargs["out"], ndarray):
-                    out = kwargs["out"]
-                    kwargs["out"] = np.ndarray(out.shape, out.dtype, out)
+                try:
+                    kwargs["out"] = convert_ndarray_to_np_ndarray(
+                        out_arg, require_ndarray=True
+                    )
+                except TypeError:
+                    raise TypeError(
+                        "Return arrays must each be {}".format(self.__class__)
+                    )
+            ufunc(*scalars, **kwargs)
+            return out_arg
+        elif method == "reduce":
+            N = None
+            scalars = []
+            typing = []
+            for inp in inputs:
+                if isinstance(inp, Number):
+                    scalars.append(inp)
+                    typing.append(inp)
+                elif isinstance(inp, (self.__class__, np.ndarray)):
+                    if isinstance(inp, self.__class__):
+                        scalars.append(np.ndarray(inp.shape, inp.dtype, inp))
+                        typing.append(np.ndarray(inp.shape, inp.dtype))
+                    else:
+                        scalars.append(inp)
+                        typing.append(inp)
+                    if N is not None:
+                        if N != inp.shape:
+                            raise TypeError("inconsistent sizes")
+                    else:
+                        N = inp.shape
                 else:
-                    out = kwargs["out"]
-            ret = ufunc(*scalars, **kwargs)
-            return out
+                    return NotImplemented
+            assert "out" not in kwargs
+            return super().__array_ufunc__(ufunc, method, *scalars, **kwargs)
         else:
             return NotImplemented
 
@@ -295,7 +336,11 @@ def __array_function__(self, func, types, args, kwargs):
             cm = sys.modules[__name__]
             affunc = getattr(cm, fname)
             fargs = [x.view(np.ndarray) if isinstance(x, ndarray) else x for x in args]
-            return affunc(*fargs, **kwargs)
+            fkwargs = {
+                key: convert_ndarray_to_np_ndarray(val) for key, val in kwargs.items()
+            }
+            res = affunc(*fargs, **fkwargs)
+            return kwargs["out"] if "out" in kwargs else res
         return NotImplemented
 
 
diff --git a/dpctl/tests/test_dparray.py b/dpctl/tests/test_dparray.py
index 9938a5dbc6..76e5c8a0d9 100644
--- a/dpctl/tests/test_dparray.py
+++ b/dpctl/tests/test_dparray.py
@@ -47,6 +47,9 @@ def test_multiplication_dparray(self):
         C = self.X * 5
         self.assertIsInstance(C, dparray.ndarray)
 
+    def test_inplace_sub(self):
+        self.X -= 1
+
     def test_dparray_through_python_func(self):
         def func_operation_with_const(dpctl_array):
             return dpctl_array * 2.0 + 13
@@ -58,6 +61,7 @@ def func_operation_with_const(dpctl_array):
     def test_dparray_mixing_dpctl_and_numpy(self):
         dp_numpy = numpy.ones((256, 4), dtype="d")
         res = dp_numpy * self.X
+        self.assertIsInstance(self.X, dparray.ndarray)
         self.assertIsInstance(res, dparray.ndarray)
 
     def test_dparray_shape(self):
@@ -76,6 +80,20 @@ def test_numpy_sum_with_dparray(self):
         res = numpy.sum(self.X)
         self.assertEqual(res, 1024.0)
 
+    def test_numpy_sum_with_dparray_out(self):
+        res = dparray.empty((self.X.shape[1],), dtype=self.X.dtype)
+        res2 = numpy.sum(self.X, axis=0, out=res)
+        self.assertTrue(res is res2)
+        self.assertIsInstance(res2, dparray.ndarray)
+
+    def test_frexp_with_out(self):
+        X = dparray.array([0.5, 4.7])
+        mant = dparray.empty((2,), dtype="d")
+        exp = dparray.empty((2,), dtype="i4")
+        res = numpy.frexp(X, out=(mant, exp))
+        self.assertTrue(res[0] is mant)
+        self.assertTrue(res[1] is exp)
+
 
 if __name__ == "__main__":
     unittest.main()