Skip to content

Commit 9f880e9

Browse files
committed
[microNPU] Add support for scalar values
PR apache#9515 enabled support for scalar constants, but didn't consider the case of a scalar value where the underlying constant data does not have a shape, i.e. `constant.shape == []`. See the test case for a visual difference when the scalar value is 1. Change-Id: Id7a238cb5bf999dd5a8428c097202f9fb940a5f0
1 parent 133bb9c commit 9f880e9

4 files changed

Lines changed: 21 additions & 11 deletions

File tree

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,8 @@ def callback(
652652
ifm2_zero_point=int(params.ifm2.q_params.zero_point),
653653
ofm_scale=float(params.ofm.q_params.scale_f32),
654654
ofm_zero_point=int(params.ofm.q_params.zero_point),
655-
ifm_channels=params.ifm.shape[-1],
656-
ifm2_channels=params.ifm2.shape[-1],
655+
ifm_channels=params.ifm.shape[-1] if params.ifm.shape else 1,
656+
ifm2_channels=params.ifm2.shape[-1] if params.ifm2.shape else 1,
657657
reversed_operands=params.reversed_operands,
658658
ofm_dtype=params.ofm.dtype,
659659
activation=activation,

python/tvm/relay/backend/contrib/ethosu/tir/compiler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,11 @@ def __init__(self):
123123

124124
def visit_constant(self, const):
125125
if isinstance(const.checked_type, relay.ty.TensorType):
126-
if const.checked_type.concrete_shape != ():
127-
self.constants.append(const.data.asnumpy())
128-
name = "p" + str(len(self.constants))
129-
var = relay.var(type_annotation=const.checked_type, name_hint=name)
130-
self.const_vars.append(var)
131-
return var
126+
self.constants.append(const.data.asnumpy())
127+
name = "p" + str(len(self.constants))
128+
var = relay.var(type_annotation=const.checked_type, name_hint=name)
129+
self.const_vars.append(var)
130+
return var
132131

133132
return const
134133

python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,17 @@ def _visit(tensor, reader, lut):
136136
if tensor not in planned:
137137
planned.add(tensor)
138138
if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor != lut:
139-
index = list(cached_func.inputs).index(tensor)
139+
# Find index of input using 'same_as' check to prevent equality
140+
# ambiguity when encountering a scalar.
141+
index = -1
142+
for i, var in enumerate(cached_func.inputs):
143+
if var.same_as(tensor):
144+
index = i
145+
break
146+
assert (
147+
index >= 0
148+
), f"Tensor {tensor} was not found in inputs: {cached_func.inputs}"
149+
140150
if index in const_dict:
141151
sch.cache_read(tensor, "global", [reader])
142152

tests/python/contrib/test_ethosu/test_codegen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,13 @@ def create_mod_from_relay():
629629

630630
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
631631
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
632-
def test_elementwise_add_from_constant_scalar(accel_type, dtype):
632+
@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])
633+
def test_elementwise_add_from_constant_scalar(accel_type, dtype, constant):
633634
ifm_shape = (1, 4, 4, 8)
634635

635636
def create_relay_graph():
636637
inp = relay.var("input", shape=ifm_shape, dtype=dtype)
637-
scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
638+
scalar = relay.const(constant, dtype=dtype)
638639
add = relay.qnn.op.add(
639640
inp,
640641
scalar,

0 commit comments

Comments
 (0)