diff --git a/arrayfire/array.py b/arrayfire/array.py index 13a5fc8fa..40a0fcd9d 100644 --- a/arrayfire/array.py +++ b/arrayfire/array.py @@ -31,6 +31,27 @@ def _create_array(buf, numdims, idims, dtype, is_device): numdims, ct.pointer(c_dims), dtype.value)) return out_arr +def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides): + out_arr = ct.c_void_p(0) + c_dims = dim4(idims[0], idims[1], idims[2], idims[3]) + if offset is None: + offset = 0 + offset = ct.c_ulonglong(offset) + if strides is None: + strides = (1, idims[0], idims[0]*idims[1], idims[0]*idims[1]*idims[2]) + while len(strides) < 4: + strides = strides + (strides[-1],) + strides = dim4(strides[0], strides[1], strides[2], strides[3]) + if is_device: + location = Source.device + else: + location = Source.host + safe_call(backend.get().af_create_strided_array(ct.pointer(out_arr), ct.c_void_p(buf), + offset, numdims, ct.pointer(c_dims), + ct.pointer(strides), dtype.value, + location.value)) + return out_arr + def _create_empty_array(numdims, idims, dtype): out_arr = ct.c_void_p(0) c_dims = dim4(idims[0], idims[1], idims[2], idims[3]) @@ -352,7 +373,7 @@ class Array(BaseArray): """ - def __init__(self, src=None, dims=(0,), dtype=None, is_device=False): + def __init__(self, src=None, dims=(0,), dtype=None, is_device=False, offset=None, strides=None): super(Array, self).__init__() @@ -409,8 +430,10 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False): if (type_char is not None and type_char != _type_char): raise TypeError("Can not create array of requested type from input data type") - - self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device) + if(offset is None and strides is None): + self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device) + else: + self.arr = _create_strided_array(buf, numdims, idims, to_dtype[_type_char], is_device, offset, strides) else: @@ -454,6 +477,26 @@ def __del__(self): backend.get().af_release_array(self.arr) def device_ptr(self): + """ + Return the device pointer exclusively held by the array. + + Returns + ------ + ptr : int + Contains location of the device pointer + + Note + ---- + - This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL. + - No other arrays will share the same device pointer. + - A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory. + - In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array. + """ + ptr = ct.c_void_p(0) + backend.get().af_get_device_ptr(ct.pointer(ptr), self.arr) + return ptr.value + + def raw_ptr(self): """ Return the device pointer held by the array. @@ -466,11 +509,45 @@ def device_ptr(self): ---- - This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL. - No mem copy is peformed, this function returns the raw device pointer. + - This pointer may be shared with other arrays. Use this function with caution. + - In particular the JIT compiler will not be aware of the shared arrays. + - This results in JITed operations not being immediately visible through the other array. """ ptr = ct.c_void_p(0) - backend.get().af_get_device_ptr(ct.pointer(ptr), self.arr) + backend.get().af_get_raw_ptr(ct.pointer(ptr), self.arr) return ptr.value + def offset(self): + """ + Return the offset, of the first element relative to the raw pointer. + + Returns + ------ + offset : int + The offset in number of elements + """ + offset = ct.c_longlong(0) + safe_call(backend.get().af_get_offset(ct.pointer(offset), self.arr)) + return offset.value + + def strides(self): + """ + Return the distance in bytes between consecutive elements for each dimension. + + Returns + ------ + strides : tuple + The strides for each dimension + """ + s0 = ct.c_longlong(0) + s1 = ct.c_longlong(0) + s2 = ct.c_longlong(0) + s3 = ct.c_longlong(0) + safe_call(backend.get().af_get_strides(ct.pointer(s0), ct.pointer(s1), + ct.pointer(s2), ct.pointer(s3), self.arr)) + strides = (s0.value,s1.value,s2.value,s3.value) + return strides[:self.numdims()] + def elements(self): """ Return the number of elements in the array. @@ -622,6 +699,22 @@ def is_bool(self): safe_call(backend.get().af_is_bool(ct.pointer(res), self.arr)) return res.value + def is_linear(self): + """ + Check if all elements of the array are contiguous. + """ + res = ct.c_bool(False) + safe_call(backend.get().af_is_linear(ct.pointer(res), self.arr)) + return res.value + + def is_owner(self): + """ + Check if the array owns the raw pointer or is a derived array. + """ + res = ct.c_bool(False) + safe_call(backend.get().af_is_owner(ct.pointer(res), self.arr)) + return res.value + def __add__(self, other): """ Return self + other.