Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 97 additions & 4 deletions arrayfire/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ def _create_array(buf, numdims, idims, dtype, is_device):
numdims, ct.pointer(c_dims), dtype.value))
return out_arr

def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides):
out_arr = ct.c_void_p(0)
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
if offset is None:
offset = 0
offset = ct.c_ulonglong(offset)
if strides is None:
strides = (1, idims[0], idims[0]*idims[1], idims[0]*idims[1]*idims[2])
while len(strides) < 4:
strides = strides + (strides[-1],)
strides = dim4(strides[0], strides[1], strides[2], strides[3])
if is_device:
location = Source.device
else:
location = Source.host
safe_call(backend.get().af_create_strided_array(ct.pointer(out_arr), ct.c_void_p(buf),
offset, numdims, ct.pointer(c_dims),
ct.pointer(strides), dtype.value,
location.value))
return out_arr

def _create_empty_array(numdims, idims, dtype):
out_arr = ct.c_void_p(0)
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
Expand Down Expand Up @@ -352,7 +373,7 @@ class Array(BaseArray):

"""

def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
def __init__(self, src=None, dims=(0,), dtype=None, is_device=False, offset=None, strides=None):

super(Array, self).__init__()

Expand Down Expand Up @@ -409,8 +430,10 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
if (type_char is not None and
type_char != _type_char):
raise TypeError("Can not create array of requested type from input data type")

self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device)
if(offset is None and strides is None):
self.arr = _create_array(buf, numdims, idims, to_dtype[_type_char], is_device)
else:
self.arr = _create_strided_array(buf, numdims, idims, to_dtype[_type_char], is_device, offset, strides)

else:

Expand Down Expand Up @@ -454,6 +477,26 @@ def __del__(self):
backend.get().af_release_array(self.arr)

def device_ptr(self):
"""
Return the device pointer exclusively held by the array.

Returns
------
ptr : int
Contains location of the device pointer

Note
----
- This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
- No other arrays will share the same device pointer.
- A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory.
- In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array.
"""
ptr = ct.c_void_p(0)
backend.get().af_get_device_ptr(ct.pointer(ptr), self.arr)
return ptr.value

def raw_ptr(self):
"""
Return the device pointer held by the array.

Expand All @@ -466,11 +509,45 @@ def device_ptr(self):
----
- This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
- No mem copy is peformed, this function returns the raw device pointer.
- This pointer may be shared with other arrays. Use this function with caution.
- In particular the JIT compiler will not be aware of the shared arrays.
- This results in JITed operations not being immediately visible through the other array.
"""
ptr = ct.c_void_p(0)
backend.get().af_get_device_ptr(ct.pointer(ptr), self.arr)
backend.get().af_get_raw_ptr(ct.pointer(ptr), self.arr)
return ptr.value

def offset(self):
"""
Return the offset, of the first element relative to the raw pointer.

Returns
------
offset : int
The offset in number of elements
"""
offset = ct.c_longlong(0)
safe_call(backend.get().af_get_offset(ct.pointer(offset), self.arr))
return offset.value

def strides(self):
"""
Return the distance in bytes between consecutive elements for each dimension.

Returns
------
strides : tuple
The strides for each dimension
"""
s0 = ct.c_longlong(0)
s1 = ct.c_longlong(0)
s2 = ct.c_longlong(0)
s3 = ct.c_longlong(0)
safe_call(backend.get().af_get_strides(ct.pointer(s0), ct.pointer(s1),
ct.pointer(s2), ct.pointer(s3), self.arr))
strides = (s0.value,s1.value,s2.value,s3.value)
return strides[:self.numdims()]

def elements(self):
"""
Return the number of elements in the array.
Expand Down Expand Up @@ -622,6 +699,22 @@ def is_bool(self):
safe_call(backend.get().af_is_bool(ct.pointer(res), self.arr))
return res.value

def is_linear(self):
"""
Check if all elements of the array are contiguous.
"""
res = ct.c_bool(False)
safe_call(backend.get().af_is_linear(ct.pointer(res), self.arr))
return res.value

def is_owner(self):
"""
Check if the array owns the raw pointer or is a derived array.
"""
res = ct.c_bool(False)
safe_call(backend.get().af_is_owner(ct.pointer(res), self.arr))
return res.value

def __add__(self, other):
"""
Return self + other.
Expand Down