Skip to content

Commit 9ba0324

Browse files
cyx-6Siyuan Feng
authored andcommitted
basic types and exprs (apache#41)
1 parent ee36fa8 commit 9ba0324

11 files changed

Lines changed: 421 additions & 11 deletions

File tree

include/tvm/tir/op.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
796796
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
797797
Span span = Span());
798798

799+
/*!
800+
* \brief Calculate fmod(x, y)
801+
* \param x Left operand.
802+
* \param y Right operand.
803+
* \param span The location of this operation in the source.
804+
* \return The result expression.
805+
*/
806+
TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
807+
799808
/*!
800809
* \brief Calculate floor(x)
801810
* \param x The input expression.
@@ -896,6 +905,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
896905
TVM_DECLARE_INTRIN_UNARY(log);
897906
TVM_DECLARE_INTRIN_UNARY(log2);
898907
TVM_DECLARE_INTRIN_UNARY(log10);
908+
TVM_DECLARE_INTRIN_UNARY(log1p);
899909
TVM_DECLARE_INTRIN_UNARY(popcount);
900910
TVM_DECLARE_INTRIN_UNARY(tan);
901911
TVM_DECLARE_INTRIN_UNARY(cos);

python/tvm/script/builder/_ffi_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"""FFI APIs for tvm.script.builder"""
1818
import tvm._ffi
1919

20-
tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access
20+
tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access

python/tvm/script/builder/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@
3131
)
3232
from .prim_func_frame import arg, prim_func
3333
from .var import Buffer
34+
from .op import *

python/tvm/script/builder/tir/_ffi_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
"""FFI APIs for tvm.script.builder.tir"""
1818
import tvm._ffi
1919

20-
tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access
20+
tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""TVM Script TIR Op"""
18+
19+
from . import _ffi_api
20+
21+
22+
from tvm.tir.op import abs, popcount, nextafter, copysign, fmod
23+
from tvm.tir.op import (
24+
floor,
25+
floordiv,
26+
floormod,
27+
ceil,
28+
round,
29+
trunc,
30+
truncdiv,
31+
truncmod,
32+
nearbyint,
33+
)
34+
from tvm.tir.op import (
35+
hypot,
36+
ldexp,
37+
power,
38+
exp,
39+
exp2,
40+
exp10,
41+
erf,
42+
sqrt,
43+
rsqrt,
44+
log,
45+
log2,
46+
log10,
47+
log1p,
48+
sigmoid,
49+
)
50+
from tvm.tir.op import isnan, isfinite, isinf
51+
from tvm.tir.op import cos, cosh, sin, sinh, tan, tanh
52+
from tvm.tir.op import acos, acosh, asin, asinh, atan, atanh
53+
from tvm.tir.op import atan2, clz, comm_reducer, infinity, reinterpret
54+
from tvm.tir.op import min_value, max_value, if_then_else
55+
from tvm.tir.op import call_packed, call_extern
56+
from tvm.tir.expr import Select, Ramp, Broadcast, Shuffle
57+
from tvm.tir.generic import cast
58+
59+
60+
def boolean(expr):
61+
return _ffi_api.PrimType("bool", expr)
62+
63+
64+
def int8(expr):
65+
return _ffi_api.PrimType("int8", expr)
66+
67+
68+
def int16(expr):
69+
return _ffi_api.PrimType("int16", expr)
70+
71+
72+
def int32(expr):
73+
return _ffi_api.PrimType("int32", expr)
74+
75+
76+
def int64(expr):
77+
return _ffi_api.PrimType("int64", expr)
78+
79+
80+
def uint8(expr):
81+
return _ffi_api.PrimType("uint8", expr)
82+
83+
84+
def uint16(expr):
85+
return _ffi_api.PrimType("uint16", expr)
86+
87+
88+
def uint32(expr):
89+
return _ffi_api.PrimType("uint32", expr)
90+
91+
92+
def uint64(expr):
93+
return _ffi_api.PrimType("uint64", expr)
94+
95+
96+
def float8(expr):
97+
return _ffi_api.PrimType("float8", expr)
98+
99+
100+
def float16(expr):
101+
return _ffi_api.PrimType("float16", expr)
102+
103+
104+
def float32(expr):
105+
return _ffi_api.PrimType("float32", expr)
106+
107+
108+
def float64(expr):
109+
return _ffi_api.PrimType("float64", expr)
110+
111+
112+
def min(a, b, span=None):
113+
"""Compute the minimum value of two expressions.
114+
115+
Parameters
116+
----------
117+
a : PrimExpr
118+
The left hand operand
119+
120+
b : PrimExpr
121+
The right hand operand
122+
123+
span : Optional[Span]
124+
The location of this operator in the source.
125+
126+
Returns
127+
-------
128+
res : PrimExpr
129+
The result expression.
130+
131+
Note
132+
----
133+
This is the default integer division behavior in C.
134+
"""
135+
return _ffi_api.min(a, b, span) # type: ignore
136+
137+
138+
def max(a, b, span=None):
139+
"""Compute the maximum value of two expressions.
140+
141+
Parameters
142+
----------
143+
a : PrimExpr
144+
The left hand operand
145+
146+
b : PrimExpr
147+
The right hand operand
148+
149+
span : Optional[Span]
150+
The location of this operator in the source.
151+
152+
Returns
153+
-------
154+
res : PrimExpr
155+
The result expression.
156+
157+
Note
158+
----
159+
This is the default integer division behavior in C.
160+
"""
161+
return _ffi_api.max(a, b, span) # type: ignore

python/tvm/script/builder/tir/var.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
from . import _ffi_api
2121

2222

23-
def Buffer( # pylint: disable=invalid-name
23+
def Buffer( # pylint: disable=invalid-name
2424
shape,
2525
dtype,
2626
name="buffer",
2727
storage_scope="",
2828
) -> tir.Buffer:
29-
return _ffi_api.Buffer(shape, dtype, name, storage_scope) # pylint: disable=no-member # type: ignore
29+
return _ffi_api.Buffer(
30+
shape, dtype, name, storage_scope
31+
) # pylint: disable=no-member # type: ignore

python/tvm/tir/__init__.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,36 @@
4444

4545
from .function import PrimFunc, TensorIntrin, IndexMap
4646

47-
from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
48-
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
47+
from .op import call_packed, call_intrin, call_pure_extern, call_extern
48+
from .op import (
49+
call_llvm_intrin,
50+
call_llvm_pure_intrin,
51+
ret,
52+
all,
53+
any,
54+
min_value,
55+
max_value,
56+
trace,
57+
)
4958
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
5059
from .op import sin, sinh, asin, asinh
5160
from .op import cos, cosh, acos, acosh
5261
from .op import tan, tanh, atan, atan2, atanh
5362
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
54-
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
63+
from .op import (
64+
trunc,
65+
abs,
66+
round,
67+
nextafter,
68+
nearbyint,
69+
power,
70+
popcount,
71+
fmod,
72+
if_then_else,
73+
)
5574
from .op import isnan, isfinite, isinf, copysign
5675
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
57-
from .op import comm_reducer, min, max, sum
76+
from .op import comm_reducer, min, max, sum, infinity, reinterpret
5877
from .op import q_multiply_shift
5978

6079
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

python/tvm/tir/op.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,10 @@ def call_pure_extern(dtype, func_name, *args, span=None):
151151
The call expression.
152152
"""
153153
return Call(
154-
dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span
154+
dtype,
155+
Op.get("tir.call_pure_extern"),
156+
convert((StringImm(func_name),) + args),
157+
span,
155158
)
156159

157160

@@ -178,7 +181,10 @@ def call_extern(dtype, func_name, *args, span=None):
178181
The call expression.
179182
"""
180183
return Call(
181-
dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span
184+
dtype,
185+
Op.get("tir.call_extern"),
186+
convert((StringImm(func_name),) + args),
187+
span=span,
182188
)
183189

184190

@@ -210,7 +216,11 @@ def call_llvm_intrin(dtype, name, *args, span=None):
210216
llvm_id = codegen.llvm_lookup_intrinsic_id(name)
211217
assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
212218
return call_intrin(
213-
dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
219+
dtype,
220+
Op.get("tir.call_llvm_intrin"),
221+
tvm.tir.const(llvm_id, "uint32"),
222+
*args,
223+
span=span,
214224
)
215225

216226

@@ -394,6 +404,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
394404
return _ffi_api.max_value(dtype, span) # type: ignore
395405

396406

407+
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
408+
"""infinity value of dtype
409+
410+
Parameters
411+
----------
412+
dtype : str
413+
The data type.
414+
415+
span : Optional[Span]
416+
The location of this operator in the source code.
417+
418+
Returns
419+
-------
420+
value : tvm.Expr
421+
The infinity value of dtype.
422+
"""
423+
return _ffi_api.infinity(dtype, span) # type: ignore
424+
425+
426+
def reinterpret(dtype, value, span=None) -> Any:
427+
"""infinity value of dtype
428+
429+
Parameters
430+
----------
431+
dtype : str
432+
The data type.
433+
434+
value : PrimExpr
435+
The input value.
436+
437+
span : Optional[Span]
438+
The location of this operator in the source code.
439+
440+
Returns
441+
-------
442+
value : tvm.Expr
443+
The reinterpret cast value of dtype.
444+
"""
445+
return _ffi_api.reinterpret(dtype, value, span) # type: ignore
446+
447+
397448
def exp(x):
398449
"""Take exponential of input x.
399450

src/script/builder/tir/op.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "./op.h"
20+
21+
namespace tvm {
22+
namespace script {
23+
namespace builder {
24+
namespace tir {
25+
26+
PrimExpr prim_type(String type_name, PrimExpr expr) {
27+
return cast(DataType(runtime::String2DLDataType(type_name)), expr);
28+
}
29+
30+
TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(prim_type);
31+
TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
32+
return tvm::min(a, b, span);
33+
});
34+
TVM_REGISTER_GLOBAL("script.builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) {
35+
return tvm::max(a, b, span);
36+
});
37+
38+
} // namespace tir
39+
} // namespace builder
40+
} // namespace script
41+
} // namespace tvm

0 commit comments

Comments
 (0)