From 706bbf5b0a935f3f4a37295d8f6cb06965503a58 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 17:38:12 -0700 Subject: [PATCH 1/8] init --- ffi/include/tvm/ffi/container/ndarray.h | 19 +++++++++++++------ ffi/tests/cpp/test_ndarray.cc | 6 ++++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 6acdbc3a2692..91688fbe3739 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -41,7 +41,6 @@ namespace ffi { * \return The check result. */ inline bool IsContiguous(const DLTensor& arr) { - if (arr.strides == nullptr) return true; int64_t expected_stride = 1; for (int32_t i = arr.ndim; i != 0; --i) { int32_t k = i - 1; @@ -110,6 +109,15 @@ inline size_t GetDataSize(const DLTensor& arr) { return GetDataSize(size, arr.dtype); } +inline Shape InferStrideFromShape(Shape shape) { + size_t ndim = shape.size(); + Array strides(ndim, 1); + for (int i = ndim - 2; i >= 0; --i) { + strides.Set(i, shape[i + 1] * strides[i + 1]); + } + return Shape(strides); +} + /*! \brief An object representing an NDArray. */ class NDArrayObj : public Object, public DLTensor { public: @@ -151,6 +159,7 @@ class NDArrayObj : public Object, public DLTensor { protected: // backs up the shape of the NDArray Optional shape_data_; + Optional stride_data_; static void DLManagedTensorDeleter(DLManagedTensor* tensor) { NDArrayObj* obj = static_cast(tensor->manager_ctx); @@ -184,9 +193,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj { this->ndim = static_cast(shape.size()); this->dtype = dtype; this->shape = const_cast(shape.data()); - this->strides = nullptr; + Shape strides = InferStrideFromShape(shape); + this->strides = const_cast(strides.data());; this->byte_offset = 0; this->shape_data_ = std::move(shape); + this->stride_data_ = std::move(strides); alloc_.AllocData(static_cast(this), std::forward(extra_args)...); } @@ -202,10 +213,6 @@ class NDArrayObjFromDLPack : public NDArrayObj { public: explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { *static_cast(this) = tensor_->dl_tensor; - // set strides to nullptr if the tensor is contiguous. - if (IsContiguous(tensor->dl_tensor)) { - this->strides = nullptr; - } } ~NDArrayObjFromDLPack() { diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc index 3d7b00cd33c3..0196bfc4fb25 100644 --- a/ffi/tests/cpp/test_ndarray.cc +++ b/ffi/tests/cpp/test_ndarray.cc @@ -69,7 +69,9 @@ TEST(NDArray, DLPack) { EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); - EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); + EXPECT_EQ(dlpack->dl_tensor.strides[0], 6); + EXPECT_EQ(dlpack->dl_tensor.strides[1], 3); + EXPECT_EQ(dlpack->dl_tensor.strides[2], 1); EXPECT_EQ(nd.use_count(), 2); { NDArray nd2 = NDArray::FromDLPack(dlpack); @@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) { EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU); EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0); EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0); - EXPECT_EQ(dlpack->dl_tensor.strides, nullptr); + EXPECT_EQ(dlpack->dl_tensor.strides[0], 1); EXPECT_EQ(nd.use_count(), 2); { From c0938184f8eb3ddcc5d5ecae9b9a9fc31af343ae Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 17:44:33 -0700 Subject: [PATCH 2/8] add doc --- ffi/include/tvm/ffi/container/ndarray.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 91688fbe3739..456a880ed527 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -109,6 +109,12 @@ inline size_t GetDataSize(const DLTensor& arr) { return GetDataSize(size, arr.dtype); } +/*! + * \brief Infer the stride from shape + * + * \param shape the input Shape + * \return the inferred stride + */ inline Shape InferStrideFromShape(Shape shape) { size_t ndim = shape.size(); Array strides(ndim, 1); From 72eba166e6c936d7020e577c2e17a892fdf1260a Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 18:16:48 -0700 Subject: [PATCH 3/8] upd --- ffi/include/tvm/ffi/container/ndarray.h | 23 +++++++---------------- ffi/include/tvm/ffi/container/shape.h | 11 +++++++++++ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 456a880ed527..1da707006601 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -41,6 +41,7 @@ namespace ffi { * \return The check result. */ inline bool IsContiguous(const DLTensor& arr) { + if (arr.strides == nullptr) return true; int64_t expected_stride = 1; for (int32_t i = arr.ndim; i != 0; --i) { int32_t k = i - 1; @@ -109,21 +110,6 @@ inline size_t GetDataSize(const DLTensor& arr) { return GetDataSize(size, arr.dtype); } -/*! - * \brief Infer the stride from shape - * - * \param shape the input Shape - * \return the inferred stride - */ -inline Shape InferStrideFromShape(Shape shape) { - size_t ndim = shape.size(); - Array strides(ndim, 1); - for (int i = ndim - 2; i >= 0; --i) { - strides.Set(i, shape[i + 1] * strides[i + 1]); - } - return Shape(strides); -} - /*! \brief An object representing an NDArray. */ class NDArrayObj : public Object, public DLTensor { public: @@ -199,7 +185,7 @@ class NDArrayObjFromNDAlloc : public NDArrayObj { this->ndim = static_cast(shape.size()); this->dtype = dtype; this->shape = const_cast(shape.data()); - Shape strides = InferStrideFromShape(shape); + Shape strides = Shape(details::InferStrideFromShape(this->ndim, this->shape)); this->strides = const_cast(strides.data());; this->byte_offset = 0; this->shape_data_ = std::move(shape); @@ -219,6 +205,11 @@ class NDArrayObjFromDLPack : public NDArrayObj { public: explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { *static_cast(this) = tensor_->dl_tensor; + if (tensor_->dl_tensor.strides == nullptr) { + Shape strides = Shape(details::InferStrideFromShape(ndim, shape)); + this->strides = const_cast(strides.data());; + this->stride_data_ = std::move(strides); + } } ~NDArrayObjFromDLPack() { diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h index 2fccc028a5b3..182ed3349e96 100644 --- a/ffi/include/tvm/ffi/container/shape.h +++ b/ffi/include/tvm/ffi/container/shape.h @@ -91,6 +91,17 @@ TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end return p; } +TVM_FFI_INLINE ObjectPtr InferStrideFromShape(int64_t ndim, int64_t* shape) { + int64_t* strides_data; + ObjectPtr strides = details::MakeEmptyShape(ndim, &strides_data); + int64_t stride = 1; + for (int i = ndim - 1; i >= 0; --i) { + strides_data[i] = stride; + stride *= shape[i]; + } + return strides; +} + } // namespace details /*! From 3b7bbf0a19961ccfe428ee0db8e6330b88a09219 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 18:22:18 -0700 Subject: [PATCH 4/8] rename --- ffi/include/tvm/ffi/container/ndarray.h | 4 ++-- ffi/include/tvm/ffi/container/shape.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 1da707006601..8a0da2d4a21d 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -185,7 +185,7 @@ class NDArrayObjFromNDAlloc : public NDArrayObj { this->ndim = static_cast(shape.size()); this->dtype = dtype; this->shape = const_cast(shape.data()); - Shape strides = Shape(details::InferStrideFromShape(this->ndim, this->shape)); + Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape)); this->strides = const_cast(strides.data());; this->byte_offset = 0; this->shape_data_ = std::move(shape); @@ -206,7 +206,7 @@ class NDArrayObjFromDLPack : public NDArrayObj { explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { *static_cast(this) = tensor_->dl_tensor; if (tensor_->dl_tensor.strides == nullptr) { - Shape strides = Shape(details::InferStrideFromShape(ndim, shape)); + Shape strides = Shape(details::MakeStridesFromShape(ndim, shape)); this->strides = const_cast(strides.data());; this->stride_data_ = std::move(strides); } diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h index 182ed3349e96..6360fcd1e398 100644 --- a/ffi/include/tvm/ffi/container/shape.h +++ b/ffi/include/tvm/ffi/container/shape.h @@ -91,7 +91,7 @@ TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end return p; } -TVM_FFI_INLINE ObjectPtr InferStrideFromShape(int64_t ndim, int64_t* shape) { +TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(int64_t ndim, int64_t* shape) { int64_t* strides_data; ObjectPtr strides = details::MakeEmptyShape(ndim, &strides_data); int64_t stride = 1; From ffa9ce6031dd5f64d168fb570d6a47f49a1a08d4 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 18:41:36 -0700 Subject: [PATCH 5/8] fix lint --- ffi/include/tvm/ffi/container/ndarray.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 8a0da2d4a21d..f65e386c0619 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -186,7 +186,7 @@ class NDArrayObjFromNDAlloc : public NDArrayObj { this->dtype = dtype; this->shape = const_cast(shape.data()); Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape)); - this->strides = const_cast(strides.data());; + this->strides = const_cast(strides.data()); this->byte_offset = 0; this->shape_data_ = std::move(shape); this->stride_data_ = std::move(strides); @@ -207,7 +207,7 @@ class NDArrayObjFromDLPack : public NDArrayObj { *static_cast(this) = tensor_->dl_tensor; if (tensor_->dl_tensor.strides == nullptr) { Shape strides = Shape(details::MakeStridesFromShape(ndim, shape)); - this->strides = const_cast(strides.data());; + this->strides = const_cast(strides.data()); this->stride_data_ = std::move(strides); } } From d84aeac5cf777922d9ea3ae4615a99dc20cad4dc Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 20:38:53 -0700 Subject: [PATCH 6/8] using ffi::IsContiguous --- include/tvm/runtime/ndarray.h | 2 +- src/relax/transform/fold_constant.cc | 2 +- src/runtime/contrib/coreml/coreml_runtime.mm | 2 +- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 2 +- src/runtime/contrib/mps/conv.mm | 6 +++--- src/runtime/contrib/mps/gemm.mm | 6 +++--- src/runtime/contrib/random/mt_random_engine.cc | 4 ++-- src/runtime/contrib/random/random.cc | 2 +- src/runtime/contrib/rocblas/rocblas.cc | 6 +++--- src/runtime/contrib/tflite/tflite_runtime.cc | 2 +- src/runtime/minrpc/rpc_reference.h | 2 +- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 6eebe49ff135..9a295e491e82 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -239,7 +239,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { strm->Write(data_byte_size); if (DMLC_IO_NO_ENDIAN_SWAP && tensor->device.device_type == kDLCPU && - tensor->strides == nullptr && tensor->byte_offset == 0) { + ffi::IsContiguous(*tensor) && tensor->byte_offset == 0) { // quick path strm->Write(tensor->data, data_byte_size); } else { diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index c1aee73cc258..856f5a5b8741 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -270,7 +270,7 @@ class ConstantFolder : public ExprMutator { Constant constant = Downcast(arg); runtime::NDArray ndarray = constant->data; ICHECK_EQ(ndarray->device.device_type, kDLCPU); - ICHECK(ndarray->strides == nullptr); + ICHECK(ffi::IsContiguous(ndarray)); ICHECK_EQ(ndarray->byte_offset, 0); ICHECK_EQ(ndarray->ndim, 1); const int64_t* data = static_cast(ndarray->data); diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm index 8e0b2542b443..fb5faa8621b2 100644 --- a/src/runtime/contrib/coreml/coreml_runtime.mm +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -60,7 +60,7 @@ MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil]; - ICHECK(data_in->strides == NULL); + ICHECK(ffi::IsContiguous(*data_in)); memcpy(dest.dataPointer, data_in->data, size); NSString* nsKey = [NSString stringWithUTF8String:key.c_str()]; diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 686a8048c7b5..59b162e76503 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -821,7 +821,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { TensorRequisite res; if (const_dl_tensor) { ICHECK(const_dl_tensor->data); - ICHECK(const_dl_tensor->strides == nullptr); + ICHECK(ffi::IsContiguous(*const_dl_tensor)); auto mem = dnnl::memory(desc, engine_, const_dl_tensor->data); res = TensorRequisite::AsIs(mem, eid); } else { diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index dfc98388d372..2bf38796fd66 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -91,9 +91,9 @@ ICHECK_EQ(data->ndim, 4); ICHECK_EQ(weight->ndim, 4); ICHECK_EQ(output->ndim, 4); - ICHECK(output->strides == nullptr); - ICHECK(weight->strides == nullptr); - ICHECK(data->strides == nullptr); + ICHECK(ffi::IsContiguous(*output)); + ICHECK(ffi::IsContiguous(*weight)); + ICHECK(ffi::IsContiguous(*data)); ICHECK_EQ(data->shape[0], 1); ICHECK_EQ(output->shape[0], 1); diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index 9f5270f38fec..7f386172f642 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -37,9 +37,9 @@ ICHECK_EQ(A->ndim, 2); ICHECK_EQ(B->ndim, 2); ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); + ICHECK(ffi::IsContiguous(*C)); + ICHECK(ffi::IsContiguous(*B)); + ICHECK(ffi::IsContiguous(*A)); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 04b53d74b404..eeffbd1ee0e2 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -75,7 +75,7 @@ class RandomEngine { */ void SampleUniform(DLTensor* data, float low, float high) { ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(data->strides == nullptr); + ICHECK(ffi::IsContiguous(*data)); DLDataType dtype = data->dtype; int64_t size = 1; @@ -99,7 +99,7 @@ class RandomEngine { */ void SampleNormal(DLTensor* data, float loc, float scale) { ICHECK_GT(scale, 0) << "standard deviation must be positive"; - ICHECK(data->strides == nullptr); + ICHECK(ffi::IsContiguous(data)); DLDataType dtype = data->dtype; int64_t size = 1; diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 580ed1073a47..b7ca1f8fd705 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -80,7 +80,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ int64_t high = args[1].cast(); auto out = args[2].cast(); ICHECK_GT(high, low) << "high must be bigger than low"; - ICHECK(out->strides == nullptr); + ICHECK(ffi::IsContiguous(*out)); DLDataType dtype = out->dtype; int64_t size = 1; diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 8fdce7e43bf0..be3c49e12196 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -81,9 +81,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK_EQ(A->ndim, 2); ICHECK_EQ(B->ndim, 2); ICHECK_EQ(C->ndim, 2); - ICHECK(C->strides == nullptr); - ICHECK(B->strides == nullptr); - ICHECK(A->strides == nullptr); + ICHECK(ffi::IsContiguous(*C)); + ICHECK(ffi::IsContiguous(*B)); + ICHECK(ffi::IsContiguous(*A)); ICHECK(TypeMatch(A->dtype, kDLFloat, 32)); ICHECK(TypeMatch(B->dtype, kDLFloat, 32)); ICHECK(TypeMatch(C->dtype, kDLFloat, 32)); diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index c35af35eae13..d65f2ad65b63 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -118,7 +118,7 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { TVM_DTYPE_DISPATCH(dtype, DType, { DType* dest = interpreter_->typed_input_tensor(index); DType* src = static_cast(data_in->data); - ICHECK(data_in->strides == NULL); + ICHECK(ffi::IsContiguous(*data_in)); int64_t size = 1; for (int64_t i = 0; i < data_in->ndim; ++i) { size *= data_in->shape[i]; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index b5f1e6995f83..7b27539f2c43 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -255,7 +255,7 @@ struct RPCReference { channel->Write(arr->ndim); channel->Write(arr->dtype); channel->WriteArray(arr->shape, arr->ndim); - if (arr->strides != nullptr) { + if (!ffi::IsContiguous(*arr)) { channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride); } channel->Write(arr->byte_offset); From 8791500bbf9e558b34d60239ec4a67241696e4cb Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Fri, 5 Sep 2025 21:34:27 -0700 Subject: [PATCH 7/8] fix --- src/relax/transform/fold_constant.cc | 2 +- src/runtime/contrib/random/mt_random_engine.cc | 2 +- src/runtime/minrpc/rpc_reference.h | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 856f5a5b8741..33e077d72641 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -270,7 +270,7 @@ class ConstantFolder : public ExprMutator { Constant constant = Downcast(arg); runtime::NDArray ndarray = constant->data; ICHECK_EQ(ndarray->device.device_type, kDLCPU); - ICHECK(ffi::IsContiguous(ndarray)); + ICHECK(ffi::IsContiguous(*ndarray.get())); ICHECK_EQ(ndarray->byte_offset, 0); ICHECK_EQ(ndarray->ndim, 1); const int64_t* data = static_cast(ndarray->data); diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index eeffbd1ee0e2..3ab0309630cf 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -99,7 +99,7 @@ class RandomEngine { */ void SampleNormal(DLTensor* data, float loc, float scale) { ICHECK_GT(scale, 0) << "standard deviation must be positive"; - ICHECK(ffi::IsContiguous(data)); + ICHECK(ffi::IsContiguous(*data)); DLDataType dtype = data->dtype; int64_t size = 1; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 7b27539f2c43..dfca27c8c3ed 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -24,6 +24,8 @@ #ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ #define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ +#include + namespace tvm { namespace ffi { // Forward declare TVM Object to use `Object*` in RPC protocol. From 46885a68dfe1560a2eb0eadf24eb03091c06eafc Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 6 Sep 2025 08:14:45 +0000 Subject: [PATCH 8/8] fix --- src/runtime/vm/rnn_state.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/vm/rnn_state.cc b/src/runtime/vm/rnn_state.cc index 8963df065258..085860348e2f 100644 --- a/src/runtime/vm/rnn_state.cc +++ b/src/runtime/vm/rnn_state.cc @@ -396,6 +396,7 @@ class RNNStateImpObj : public RNNStateObj { _state.byte_offset = elem_offset * state->dtype.bits / 8; _state.ndim = state->ndim - 2; _state.shape = const_cast(_state.shape + 2); + _state.strides = const_cast(_state.strides + 2); return _state; } @@ -411,6 +412,7 @@ class RNNStateImpObj : public RNNStateObj { _state.byte_offset = elem_offset * state->dtype.bits / 8; _state.ndim = state->ndim - 1; _state.shape = const_cast(_state.shape + 1); + _state.strides = const_cast(_state.strides + 1); return _state; } @@ -428,7 +430,7 @@ class RNNStateImpObj : public RNNStateObj { copy_src.ndim = 1; copy_src.dtype = array->dtype; copy_src.shape = array->shape; - copy_src.strides = nullptr; + copy_src.strides = array->strides; copy_src.byte_offset = 0; NDArray::CopyFromTo(©_src, ©_dst); };