
[Bug] [Relay][PyTorch] aten::new_zeros support is broken with torch>=2.1.0 #16126

@mshr-h

Description


It seems that aten::new_zeros support is broken with torch>=2.1.0 (see #14747).

Steps to reproduce

from tvm import relay
import torch

torch.set_grad_enabled(False)

class NewZeros1(torch.nn.Module):
  def forward(self, x):
    return x.new_zeros((2, 3))

input_data = torch.tensor((), dtype=torch.float)
module = NewZeros1().float().eval()
scripted_module = torch.jit.trace(module, input_data)
input_infos = [('x', (input_data.shape, 'float32'))]
mod, params = relay.frontend.from_pytorch(scripted_module, input_infos)
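Not part of the original report, but as a possible interim workaround: constructing the constant tensor with torch.zeros and an explicit dtype traces to aten::zeros rather than aten::new_zeros, which may sidestep the broken converter path. A sketch (NewZerosWorkaround is a hypothetical name):

```python
import torch

class NewZerosWorkaround(torch.nn.Module):
    def forward(self, x):
        # torch.zeros with an explicit dtype traces to aten::zeros
        # instead of aten::new_zeros (hypothetical workaround).
        return torch.zeros((2, 3), dtype=x.dtype)
```

Whether this avoids the failure depends on the aten::zeros converter being unaffected, which the report does not confirm.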

Actual behavior

Traceback (most recent call last):
  File "/home/ubuntu/workspace/sandbox/tvm_/frontend/new_zeros.py", line 17, in <module>
    mod, params = relay.frontend.from_pytorch(scripted_module, input_infos)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 5183, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 4402, in convert_operators
    relay_out = relay_op(
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 851, in new_zeros
    return self.full_impl(data, 0, dtype)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/frontend/pytorch.py", line 755, in full_impl
    out = _op.full(fill_value, size, dtype=dtype)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/relay/op/transform.py", line 535, in full
    return _make.full(fill_value, shape, dtype)
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  2: _ZN3tvm7runtime13PackedFuncObj
  1: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  0: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::DataType<tvm::runtime::DataType>() const
  3: _ZN3tvm7runtime13PackedFuncObj
  2: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::runtime::Array<tvm::Integer, void>, tvm::runtime::DataType), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  1: tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::runtime::DataType<tvm::runtime::DataType>() const
  0: tvm::runtime::TVMArgValue::operator DLDataType() const
  File "/home/ubuntu/workspace/sandbox/.dep/tvm/include/tvm/runtime/packed_func.h", line 777
TVMError: In function relay.op._make.full(0: RelayExpr, 1: Array<IntImm>, 2: DataType) -> RelayExpr: error while converting argument 2: [19:47:16] /home/ubuntu/workspace/sandbox/.dep/tvm/include/tvm/runtime/packed_func.h:2210: InternalError: Check failed: type_code_ == kTVMDataType (8 vs. 5) : expected DLDataType but got Object
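Not part of the original report, but one reading of the error: the FFI check "expected DLDataType but got Object" suggests the converter forwarded an unconverted IR object to relay.op._make.full where a dtype string/DLDataType is expected. TorchScript encodes dtypes as integer ScalarType codes (e.g. 6 for float32), so a defensive normalization along these lines (hypothetical helper, not TVM's actual code) illustrates the kind of mapping the converter needs before calling full:

```python
# PyTorch ScalarType codes -> TVM dtype strings (subset, for illustration).
_TORCH_SCALAR_TYPE_TO_TVM = {
    3: "int32",
    4: "int64",
    5: "float16",
    6: "float32",
    7: "float64",
}

def to_tvm_dtype(val, default="float32"):
    """Normalize a traced dtype value to a TVM dtype string (sketch).

    Strings pass through; integer ScalarType codes are looked up;
    anything else (None, unconverted IR objects) falls back to default.
    """
    if isinstance(val, str):
        return val
    if isinstance(val, int):
        return _TORCH_SCALAR_TYPE_TO_TVM.get(val, default)
    return default
```

If the converter applied such a mapping, the Object would never reach the C++ FFI boundary and the type-code check (8 vs. 5) above would not trip.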

Environment

TVM: c8ef902
LLVM=ON
CUDA=OFF
Torch==2.1.0

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • frontend:pytorch

cc @shingjan @yelite

Metadata

Assignees: no one assigned
Labels: needs-triage, type: bug
