Skip to content

Commit 775742d

Browse files
committed
[TIR] End-to-end tests for PrimFunc-to-PrimFunc subroutines
The functionality tested in this commit was added across several recent PRs, each of which tested their features in isolation. This PR adds unit tests to validate the end-to-end behavior of TIR subroutine calls. PRs building up to this point: - TVMScript - apache#14889 - apache#14915 - apache#14919 - apache#14941 - Functionality improvements of existing TIR passes - apache#14913 - apache#14914 - apache#14918 - apache#14951 - Changes to the TIR lowering flow - apache#14942 - apache#14985 - Codegen updates - apache#14958 - apache#14901 - Compatibility updates/fixes - apache#14892 - apache#14950 - apache#14943 - apache#14944 - apache#14945 - apache#14952 - apache#14982 - apache#14949
1 parent 60c866a commit 775742d

1 file changed

Lines changed: 275 additions & 0 deletions

File tree

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
#!/usr/bin/env python3
2+
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
# pylint: disable=missing-function-docstring,missing-module-docstring
20+
21+
import pytest
22+
import numpy as np
23+
24+
import tvm
25+
import tvm.testing
26+
27+
from tvm.script import tir as T, ir as I
28+
29+
30+
@tvm.testing.parametrize_targets("llvm")
31+
def test_call_noop(target, dev):
32+
"""TIR functions on the CPU may call other functions
33+
34+
The simplest test case, where the subroutine is a no-op.
35+
"""
36+
37+
@I.ir_module
38+
class module:
39+
@T.prim_func
40+
def subroutine():
41+
T.evaluate(0)
42+
43+
@T.prim_func
44+
def main(A: T.Buffer(1, "float32")):
45+
T.func_attr({"global_symbol": "main"})
46+
module.subroutine()
47+
A[0] = 42.0
48+
49+
built = tvm.build(module, target=target)
50+
51+
arr = tvm.nd.empty([1], dtype="float32", device=dev)
52+
built(arr)
53+
54+
assert arr.numpy()[0] == 42.0
55+
56+
57+
@tvm.testing.parametrize_targets("llvm")
58+
def test_call_noop_defined_below(target, dev):
59+
"""Calling a subroutine does not depend on the definition order
60+
61+
All GlobalVar instances are in-scope for subroutine calls.
62+
"""
63+
64+
@I.ir_module
65+
class module:
66+
@T.prim_func
67+
def main(A: T.Buffer(1, "float32")):
68+
T.func_attr({"global_symbol": "main"})
69+
module.subroutine()
70+
A[0] = 42.0
71+
72+
@T.prim_func
73+
def subroutine():
74+
T.evaluate(0)
75+
76+
built = tvm.build(module, target=target)
77+
78+
arr = tvm.nd.empty([1], dtype="float32", device=dev)
79+
built(arr)
80+
81+
assert arr.numpy()[0] == 42.0
82+
83+
84+
@tvm.testing.parametrize_targets("llvm")
85+
def test_subroutine_call_with_pointer_param(target, dev):
86+
"""TIR functions on the CPU may call other functions
87+
88+
Buffers may be exposed to subroutines through data pointers.
89+
"""
90+
91+
@I.ir_module
92+
class module:
93+
@T.prim_func
94+
def main(A: T.Buffer(2, "float32")):
95+
T.func_attr({"global_symbol": "main"})
96+
module.subroutine(A.data)
97+
module.subroutine(T.address_of(A[1]))
98+
99+
@T.prim_func
100+
def subroutine(A_data: T.handle("float32")):
101+
A = T.decl_buffer(shape=[1], dtype="float32", data=A_data)
102+
A[0] = 42.0
103+
104+
built = tvm.build(module, target=target)
105+
106+
arr = tvm.nd.empty([2], dtype="float32", device=dev)
107+
built(arr)
108+
109+
assert arr.numpy()[0] == 42.0
110+
assert arr.numpy()[1] == 42.0
111+
112+
113+
@pytest.mark.xfail(reason="Depends on LLVM version")
114+
@tvm.testing.parametrize_targets("llvm")
115+
def test_failed_subroutine_call_for_incorrect_type(target, dev):
116+
"""Calls into a subroutine must have correct argument types
117+
118+
This currently relies on the `llvm::verifyModule` function during
119+
codegen. In the future, this should be moved to a dedicated check
120+
of TIR validity.
121+
"""
122+
123+
@I.ir_module
124+
class module:
125+
@T.prim_func
126+
def main(A: T.Buffer(1, "float32")):
127+
T.func_attr({"global_symbol": "main"})
128+
module.subroutine(A.data)
129+
130+
@T.prim_func
131+
def subroutine(A_data: T.handle("int32")):
132+
A = T.decl_buffer(shape=[1], dtype="int32", data=A_data)
133+
A[0] = -1
134+
135+
lowered = tvm.lower(module)
136+
with pytest.raises(tvm.TVMError):
137+
tvm.build(lowered)
138+
139+
140+
@tvm.testing.parametrize_targets("llvm")
141+
def test_subroutine_call_with_scalar_param(target, dev):
142+
"""Subroutines may also accept scalar parameters"""
143+
144+
@I.ir_module
145+
class module:
146+
@T.prim_func
147+
def main(A: T.Buffer(1, "float32")):
148+
T.func_attr({"global_symbol": "main"})
149+
module.subroutine(A.data, 42.0)
150+
151+
@T.prim_func
152+
def subroutine(A_data: T.handle("float32"), val: T.float32):
153+
A = T.decl_buffer([1], "float32", data=A_data)
154+
A[0] = 2 * val
155+
156+
built = tvm.build(module, target=target)
157+
158+
arr = tvm.nd.empty([1], dtype="float32", device=dev)
159+
built(arr)
160+
161+
assert arr.numpy()[0] == 84.0
162+
163+
164+
@tvm.testing.parametrize_targets("llvm")
165+
def test_internal_subroutine_is_not_exposed_externally(target, dev):
166+
"""An internal subroutine may not be called externally
167+
168+
An internal subroutine is any subroutine without a "global_symbol"
169+
attribute. These are not exposed in the runtime::Module and do
170+
not have an externally linkable symbol.
171+
"""
172+
173+
@I.ir_module
174+
class module:
175+
@T.prim_func
176+
def main(A: T.Buffer(1, "float32")):
177+
T.func_attr({"global_symbol": "main"})
178+
module.subroutine(A.data, 42.0)
179+
180+
@T.prim_func
181+
def subroutine(A_data: T.handle("float32"), val: T.float32):
182+
A = T.decl_buffer([1], "float32", data=A_data)
183+
A[0] = 2 * val
184+
185+
built = tvm.build(module, target=target)
186+
with pytest.raises(AttributeError):
187+
built["subroutine"]
188+
189+
190+
@tvm.testing.parametrize_targets("llvm")
191+
def test_call_to_externally_visible_subroutine(target, dev):
192+
"""Subroutines may be exposed externally.
193+
194+
A subroutine may be exposed externally. Externally-exposed
195+
subroutines may be called by an external API, or may be called by
196+
other functions in the same IRModule.
197+
198+
The current implementation lowers internal subroutine calls to
199+
`T.tvm_call_cpacked`. This avoids the overhead of the global
200+
registry lookup used by `T.tvm_call_packed`, but still requires
201+
the overhead of packing/unpacking the `PackedFunc` interface, and
202+
is limited to callers whose target supports the `PackedFunc`
203+
interface.
204+
"""
205+
206+
@I.ir_module
207+
class module:
208+
@T.prim_func
209+
def main(A: T.Buffer(1, "float32")):
210+
T.func_attr({"global_symbol": "main"})
211+
module.subroutine(A.data, 42.0)
212+
213+
@T.prim_func
214+
def subroutine(A_data: T.handle("float32"), val: T.float32):
215+
T.func_attr({"global_symbol": "subroutine"})
216+
A = T.Buffer([1], "float32", data=A_data)
217+
A[0] = 2 * val
218+
219+
built = tvm.build(module, target=target)
220+
221+
arr = tvm.nd.empty([1], dtype="float32", device=dev)
222+
built["main"](arr)
223+
assert arr.numpy()[0] == 84.0
224+
225+
arr = np.zeros(shape=[1], dtype="float32")
226+
built["subroutine"](arr.ctypes._data, 100.0)
227+
assert arr[0] == 200.0
228+
229+
230+
is_external_subroutine = tvm.testing.parameter(by_dict={"external": True, "internal": False})
231+
232+
233+
@tvm.testing.parametrize_targets("llvm", "cuda")
234+
def test_call_to_device_subroutine(target, dev, is_external_subroutine):
235+
"""Subroutines may be exposed externally.
236+
237+
This feature is currently limited to host-side subroutine calls of
238+
externally-exposed subroutines.
239+
"""
240+
is_gpu = "gpu" in tvm.target.Target(target).keys
241+
242+
if is_gpu and not is_external_subroutine:
243+
pytest.xfail(reason="Not yet implemented.")
244+
245+
if is_external_subroutine:
246+
func_attr = {"global_symbol": "subroutine"}
247+
else:
248+
func_attr = {}
249+
250+
@I.ir_module
251+
class module:
252+
@T.prim_func
253+
def main(A: T.Buffer(1, "float32")):
254+
T.func_attr({"global_symbol": "main"})
255+
module.subroutine(A.data, 42.0)
256+
257+
@T.prim_func
258+
def subroutine(A_data: T.handle("float32"), val: T.float32):
259+
T.func_attr(func_attr)
260+
A = T.Buffer([1], "float32", data=A_data)
261+
iterator = T.meta_var(
262+
T.thread_binding(0, 1, thread="threadIdx.x") if is_gpu else range(1)
263+
)
264+
for i in iterator:
265+
A[0] = 2 * val
266+
267+
built = tvm.build(module, target=target)
268+
269+
arr = tvm.nd.empty([1], dtype="float32", device=dev)
270+
built["main"](arr)
271+
assert arr.numpy()[0] == 84.0
272+
273+
274+
if __name__ == "__main__":
275+
tvm.testing.main()

0 commit comments

Comments
 (0)