Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions ffi/include/tvm/ffi/container/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ namespace ffi {
* \return The check result.
*/
inline bool IsContiguous(const DLTensor& arr) {
if (arr.strides == nullptr) return true;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep this check for now: this function operates on a bare DLTensor, where `strides == nullptr` is a legitimate encoding meaning the tensor is compact, so the early return is still needed here.

int64_t expected_stride = 1;
for (int32_t i = arr.ndim; i != 0; --i) {
int32_t k = i - 1;
Expand Down Expand Up @@ -110,6 +109,21 @@ inline size_t GetDataSize(const DLTensor& arr) {
return GetDataSize(size, arr.dtype);
}

/*!
 * \brief Infer compact (row-major contiguous) strides from a shape.
 *
 * The innermost dimension gets stride 1 and each outer dimension's stride
 * is the product of all inner extents, matching the DLPack convention for
 * a compact tensor.
 *
 * \param shape the input Shape
 * \return the inferred strides, one entry per dimension of \p shape
 */
inline Shape InferStrideFromShape(Shape shape) {
  size_t ndim = shape.size();
  Array<int64_t> strides(ndim, 1);
  // Use a signed induction variable computed via an explicit cast:
  // `ndim - 2` on the unsigned size_t wraps around when ndim < 2
  // (scalar or 1-D shapes), and the subsequent conversion to int is
  // narrowing/implementation-defined. Casting first makes the "loop
  // does not run for ndim < 2" behavior explicit and well-defined.
  for (int64_t i = static_cast<int64_t>(ndim) - 2; i >= 0; --i) {
    strides.Set(i, shape[i + 1] * strides[i + 1]);
  }
  return Shape(strides);
}

/*! \brief An object representing an NDArray. */
class NDArrayObj : public Object, public DLTensor {
public:
Expand Down Expand Up @@ -151,6 +165,7 @@ class NDArrayObj : public Object, public DLTensor {
protected:
// backs up the shape of the NDArray
Optional<Shape> shape_data_;
Optional<Shape> stride_data_;

static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
Expand Down Expand Up @@ -184,9 +199,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj {
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
this->shape = const_cast<int64_t*>(shape.data());
this->strides = nullptr;
Shape strides = InferStrideFromShape(shape);
this->strides = const_cast<int64_t*>(strides.data());;
this->byte_offset = 0;
this->shape_data_ = std::move(shape);
this->stride_data_ = std::move(strides);
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
}

Expand All @@ -202,10 +219,6 @@ class NDArrayObjFromDLPack : public NDArrayObj {
public:
explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
// set strides to nullptr if the tensor is contiguous.
if (IsContiguous(tensor->dl_tensor)) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to check `tensor->dl_tensor.strides == nullptr` (note: `dl_tensor` is a struct member, not a pointer) and act accordingly if needed.

this->strides = nullptr;
}
}

~NDArrayObjFromDLPack() {
Expand Down
6 changes: 4 additions & 2 deletions ffi/tests/cpp/test_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ TEST(NDArray, DLPack) {
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
EXPECT_EQ(dlpack->dl_tensor.strides[0], 6);
EXPECT_EQ(dlpack->dl_tensor.strides[1], 3);
EXPECT_EQ(dlpack->dl_tensor.strides[2], 1);
EXPECT_EQ(nd.use_count(), 2);
{
NDArray nd2 = NDArray::FromDLPack(dlpack);
Expand All @@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) {
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
EXPECT_EQ(dlpack->dl_tensor.strides[0], 1);

EXPECT_EQ(nd.use_count(), 2);
{
Expand Down
Loading