diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h
index 6acdbc3a2692..f65e386c0619 100644
--- a/ffi/include/tvm/ffi/container/ndarray.h
+++ b/ffi/include/tvm/ffi/container/ndarray.h
@@ -151,6 +151,7 @@ class NDArrayObj : public Object, public DLTensor {
  protected:
  // backs up the shape of the NDArray
  Optional<Shape> shape_data_;
+ Optional<Shape> stride_data_;

  static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
    NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
@@ -184,9 +185,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj {
    this->ndim = static_cast<int>(shape.size());
    this->dtype = dtype;
    this->shape = const_cast<int64_t*>(shape.data());
-   this->strides = nullptr;
+   Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape));
+   this->strides = const_cast<int64_t*>(strides.data());
    this->byte_offset = 0;
    this->shape_data_ = std::move(shape);
+   this->stride_data_ = std::move(strides);
    alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
  }
@@ -202,9 +205,10 @@ class NDArrayObjFromDLPack : public NDArrayObj {
  public:
  explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
    *static_cast<DLTensor*>(this) = tensor_->dl_tensor;
-   // set strides to nullptr if the tensor is contiguous.
-   if (IsContiguous(tensor->dl_tensor)) {
-     this->strides = nullptr;
+   if (tensor_->dl_tensor.strides == nullptr) {
+     Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
+     this->strides = const_cast<int64_t*>(strides.data());
+     this->stride_data_ = std::move(strides);
    }
  }

diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h
index 2fccc028a5b3..6360fcd1e398 100644
--- a/ffi/include/tvm/ffi/container/shape.h
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -91,6 +91,17 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
   return p;
 }

+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t* shape) {
+  int64_t* strides_data;
+  ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
+  int64_t stride = 1;
+  for (int i = ndim - 1; i >= 0; --i) {
+    strides_data[i] = stride;
+    stride *= shape[i];
+  }
+  return strides;
+}
+
 }  // namespace details

 /*!
diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc
index 3d7b00cd33c3..0196bfc4fb25 100644
--- a/ffi/tests/cpp/test_ndarray.cc
+++ b/ffi/tests/cpp/test_ndarray.cc
@@ -69,7 +69,9 @@ TEST(NDArray, DLPack) {
   EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
   EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
   EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
-  EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+  EXPECT_EQ(dlpack->dl_tensor.strides[0], 6);
+  EXPECT_EQ(dlpack->dl_tensor.strides[1], 3);
+  EXPECT_EQ(dlpack->dl_tensor.strides[2], 1);
   EXPECT_EQ(nd.use_count(), 2);
   {
     NDArray nd2 = NDArray::FromDLPack(dlpack);
@@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) {
   EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
   EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
   EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
-  EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
+  EXPECT_EQ(dlpack->dl_tensor.strides[0], 1);
   EXPECT_EQ(nd.use_count(), 2);

   {
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 6eebe49ff135..9a295e491e82 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -239,7 +239,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) {
   strm->Write(data_byte_size);

   if (DMLC_IO_NO_ENDIAN_SWAP && tensor->device.device_type == kDLCPU &&
-      tensor->strides == nullptr && tensor->byte_offset == 0) {
+      ffi::IsContiguous(*tensor) && tensor->byte_offset == 0) {
     // quick path
     strm->Write(tensor->data, data_byte_size);
   } else {
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index c1aee73cc258..33e077d72641 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -270,7 +270,7 @@ class ConstantFolder : public ExprMutator {
       Constant constant = Downcast<Constant>(arg);
       runtime::NDArray ndarray = constant->data;
       ICHECK_EQ(ndarray->device.device_type, kDLCPU);
-      ICHECK(ndarray->strides == nullptr);
+      ICHECK(ffi::IsContiguous(*ndarray.get()));
       ICHECK_EQ(ndarray->byte_offset, 0);
       ICHECK_EQ(ndarray->ndim, 1);
       const int64_t* data = static_cast<const int64_t*>(ndarray->data);
diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm
index 8e0b2542b443..fb5faa8621b2 100644
--- a/src/runtime/contrib/coreml/coreml_runtime.mm
+++ b/src/runtime/contrib/coreml/coreml_runtime.mm
@@ -60,7 +60,7 @@
   MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape
                                                   dataType:dataType
                                                      error:nil];
-  ICHECK(data_in->strides == NULL);
+  ICHECK(ffi::IsContiguous(*data_in));
   memcpy(dest.dataPointer, data_in->data, size);

   NSString* nsKey = [NSString stringWithUTF8String:key.c_str()];
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 686a8048c7b5..59b162e76503 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -821,7 +821,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     TensorRequisite res;
     if (const_dl_tensor) {
       ICHECK(const_dl_tensor->data);
-      ICHECK(const_dl_tensor->strides == nullptr);
+      ICHECK(ffi::IsContiguous(*const_dl_tensor));
       auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data);
       res = TensorRequisite::AsIs(mem, eid);
     } else {
diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm
index dfc98388d372..2bf38796fd66 100644
--- a/src/runtime/contrib/mps/conv.mm
+++ b/src/runtime/contrib/mps/conv.mm
@@ -91,9 +91,9 @@
   ICHECK_EQ(data->ndim, 4);
   ICHECK_EQ(weight->ndim, 4);
   ICHECK_EQ(output->ndim, 4);
-  ICHECK(output->strides == nullptr);
-  ICHECK(weight->strides == nullptr);
-  ICHECK(data->strides == nullptr);
+  ICHECK(ffi::IsContiguous(*output));
+  ICHECK(ffi::IsContiguous(*weight));
+  ICHECK(ffi::IsContiguous(*data));
   ICHECK_EQ(data->shape[0], 1);
   ICHECK_EQ(output->shape[0], 1);

diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm
index 9f5270f38fec..7f386172f642 100644
--- a/src/runtime/contrib/mps/gemm.mm
+++ b/src/runtime/contrib/mps/gemm.mm
@@ -37,9 +37,9 @@
   ICHECK_EQ(A->ndim, 2);
   ICHECK_EQ(B->ndim, 2);
   ICHECK_EQ(C->ndim, 2);
-  ICHECK(C->strides == nullptr);
-  ICHECK(B->strides == nullptr);
-  ICHECK(A->strides == nullptr);
+  ICHECK(ffi::IsContiguous(*C));
+  ICHECK(ffi::IsContiguous(*B));
+  ICHECK(ffi::IsContiguous(*A));
   ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
   ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
   ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc
index 04b53d74b404..3ab0309630cf 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -75,7 +75,7 @@ class RandomEngine {
    */
   void SampleUniform(DLTensor* data, float low, float high) {
     ICHECK_GT(high, low) << "high must be bigger than low";
-    ICHECK(data->strides == nullptr);
+    ICHECK(ffi::IsContiguous(*data));

     DLDataType dtype = data->dtype;
     int64_t size = 1;
@@ -99,7 +99,7 @@ class RandomEngine {
    */
   void SampleNormal(DLTensor* data, float loc, float scale) {
     ICHECK_GT(scale, 0) << "standard deviation must be positive";
-    ICHECK(data->strides == nullptr);
+    ICHECK(ffi::IsContiguous(*data));

     DLDataType dtype = data->dtype;
     int64_t size = 1;
diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc
index 580ed1073a47..b7ca1f8fd705 100644
--- a/src/runtime/contrib/random/random.cc
+++ b/src/runtime/contrib/random/random.cc
@@ -80,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
        int64_t high = args[1].cast<int64_t>();
        auto out = args[2].cast<DLTensor*>();
        ICHECK_GT(high, low) << "high must be bigger than low";
-       ICHECK(out->strides == nullptr);
+       ICHECK(ffi::IsContiguous(*out));

        DLDataType dtype = out->dtype;
        int64_t size = 1;
diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc
index 8fdce7e43bf0..be3c49e12196 100644
--- a/src/runtime/contrib/rocblas/rocblas.cc
+++ b/src/runtime/contrib/rocblas/rocblas.cc
@@ -81,9 +81,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
        ICHECK_EQ(A->ndim, 2);
        ICHECK_EQ(B->ndim, 2);
        ICHECK_EQ(C->ndim, 2);
-       ICHECK(C->strides == nullptr);
-       ICHECK(B->strides == nullptr);
-       ICHECK(A->strides == nullptr);
+       ICHECK(ffi::IsContiguous(*C));
+       ICHECK(ffi::IsContiguous(*B));
+       ICHECK(ffi::IsContiguous(*A));
        ICHECK(TypeMatch(A->dtype, kDLFloat, 32));
        ICHECK(TypeMatch(B->dtype, kDLFloat, 32));
        ICHECK(TypeMatch(C->dtype, kDLFloat, 32));
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index c35af35eae13..d65f2ad65b63 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
   TVM_DTYPE_DISPATCH(dtype, DType, {
     DType* dest = interpreter_->typed_input_tensor<DType>(index);
     DType* src = static_cast<DType*>(data_in->data);
-    ICHECK(data_in->strides == NULL);
+    ICHECK(ffi::IsContiguous(*data_in));
     int64_t size = 1;
     for (int64_t i = 0; i < data_in->ndim; ++i) {
       size *= data_in->shape[i];
diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h
index b5f1e6995f83..dfca27c8c3ed 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -24,6 +24,8 @@
 #ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_
 #define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_

+#include <tvm/ffi/container/ndarray.h>
+
 namespace tvm {
 namespace ffi {
 // Forward declare TVM Object to use `Object*` in RPC protocol.
@@ -255,7 +257,7 @@ struct RPCReference {
       channel->Write(arr->ndim);
       channel->Write(arr->dtype);
       channel->WriteArray(arr->shape, arr->ndim);
-      if (arr->strides != nullptr) {
+      if (!ffi::IsContiguous(*arr)) {
         channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride);
       }
       channel->Write(arr->byte_offset);
diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc
index 8963df065258..085860348e2f 100644
--- a/src/runtime/vm/rnn_state.cc
+++ b/src/runtime/vm/rnn_state.cc
@@ -396,6 +396,7 @@ class RNNStateImpObj : public RNNStateObj {
     _state.byte_offset = elem_offset * state->dtype.bits / 8;
     _state.ndim = state->ndim - 2;
     _state.shape = const_cast<int64_t*>(_state.shape + 2);
+    _state.strides = const_cast<int64_t*>(_state.strides + 2);
     return _state;
   }

@@ -411,6 +412,7 @@ class RNNStateImpObj : public RNNStateObj {
     _state.byte_offset = elem_offset * state->dtype.bits / 8;
     _state.ndim = state->ndim - 1;
     _state.shape = const_cast<int64_t*>(_state.shape + 1);
+    _state.strides = const_cast<int64_t*>(_state.strides + 1);
     return _state;
   }

@@ -428,7 +430,7 @@ class RNNStateImpObj : public RNNStateObj {
     copy_src.ndim = 1;
     copy_src.dtype = array->dtype;
     copy_src.shape = array->shape;
-    copy_src.strides = nullptr;
+    copy_src.strides = array->strides;
     copy_src.byte_offset = 0;
     NDArray::CopyFromTo(&copy_src, &copy_dst);
   };