forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathschedule.py
More file actions
656 lines (514 loc) · 20.4 KB
/
schedule.py
File metadata and controls
656 lines (514 loc) · 20.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The computation schedule api of TVM."""
import collections
import inspect
from typing import Callable, List
import tvm._ffi
from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import IterVar, Buffer, Var, IndexMap
from . import tensor as _tensor
from . import _ffi_api
@tvm._ffi.register_object
class Split(Object):
    """Split operation applied to an iteration axis."""
@tvm._ffi.register_object
class Fuse(Object):
    """Fuse operation applied to iteration axes."""
@tvm._ffi.register_object
class Singleton(Object):
    """Singleton iteration axis."""
def create_schedule(ops):
    """Create a schedule for a list of operations.

    Parameters
    ----------
    ops : Operation, Tensor, or list / tuple / Array of them
        The operations to schedule. A single item is treated as a
        one-element list. Any Tensor is replaced by the Operation
        that produces it (``tensor.op``).

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    # Tuples are sequences too. Wrapping one in a list would produce a
    # nested sequence that cannot be converted to an array of operations.
    if not isinstance(ops, (list, tuple, _container.Array)):
        ops = [ops]
    # Accept tensors as well as operations, matching Schedule.__getitem__
    # and Schedule.cache_read, which also map a Tensor to its producing op.
    ops = [op.op if isinstance(op, _tensor.Tensor) else op for op in ops]
    return _ffi_api.CreateSchedule(ops)
@tvm._ffi.register_object
class Schedule(Object):
    """Schedule for all the stages."""

    def __getitem__(self, k):
        """Return the Stage that schedules ``k``.

        Parameters
        ----------
        k : Tensor or Operation
            A Tensor is looked up through the Operation that produces it.

        Returns
        -------
        stage : Stage
            The stage registered for the operation in ``stage_map``.

        Raises
        ------
        ValueError
            If ``k`` is neither a Tensor nor an Operation, or if the
            operation is not part of this schedule.
        """
        key = k.op if isinstance(k, _tensor.Tensor) else k
        if not isinstance(key, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        stage_map = self.stage_map
        if key not in stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(key)))
        return stage_map[key]

    def normalize(self):
        """Build a normalized copy of the current schedule.

        Inserts the rebases needed so that certain iteration variables
        start from 0. Call this before bound inference and the steps
        that follow it.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        normalized = _ffi_api.ScheduleNormalize(self)
        return normalized

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create a stage group bounded by the given outputs and inputs.

        Operators between ``outputs`` and ``inputs`` become members of the
        group. The outputs are included in the group; the inputs are not.

        Parameters
        ----------
        outputs : Tensor or list of Tensors
            The outputs of the group.
        inputs : Tensor or list of Tensors
            The inputs of the group.
        include_inputs : boolean, optional
            Whether to include input operations in the group if they are
            used by the outputs.

        Returns
        -------
        group : Stage
            A virtual stage representing the group. Call ``compute_at`` on
            it to move the attachment point of the group.
        """
        outputs = [outputs] if isinstance(outputs, _tensor.Tensor) else outputs
        inputs = [inputs] if isinstance(inputs, _tensor.Tensor) else inputs
        return _ffi_api.ScheduleCreateGroup(self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of ``tensor`` for the given readers.

        This mutates the body of the readers and creates a new cache stage
        for the tensor. Call it before any split/fuse scheduling.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of the cache.
        readers : Tensor, Operation, or list of them
            The readers that should read from the cache. Tensors are
            replaced by their producing operations.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        single_reader = isinstance(readers, (_tensor.Tensor, _tensor.Operation))
        candidates = [readers] if single_reader else readers
        reader_ops = []
        for reader in candidates:
            if isinstance(reader, _tensor.Tensor):
                reader = reader.op
            reader_ops.append(reader)
        return _ffi_api.ScheduleCacheRead(self, tensor, scope, reader_ops)

    def cache_write(self, tensor, scope):
        """Create a cache write of ``tensor``, before storing into it.

        This mutates the body of the tensor: a new cache stage is created
        and fed into the tensor. It can be used for data layout
        transformation. If a split/fuse/reorder is applied to the
        data-parallel axes of the tensor before ``cache_write`` is called,
        the intermediate cache stores the data in the iteration order of
        the leaf axes, and the original tensor transforms it back to the
        original layout. Call ``compute_inline`` on the original stage to
        inline the original layout and keep the data in the transformed
        layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to be fed. All of them must be produced by one
            ComputeOp.
        scope : str
            The scope of the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        cache = _ffi_api.ScheduleCacheWrite(self, tensor, scope)
        return cache

    def rfactor(self, tensor, axis, factor_axis=0):
        """Factor a reduction axis of ``tensor`` into an explicit axis.

        A new stage is created that produces a new tensor, with ``axis``
        placed at position ``factor_axis``. The body of ``tensor`` is then
        rewritten as a reduction over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.
        factor_axis : int
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The factored tensor; a bare Tensor when exactly one is produced.
        """
        result = _ffi_api.ScheduleRFactor(self, tensor, axis, factor_axis)
        if len(result) == 1:
            return result[0]
        return result
@tvm._ffi.register_object
class Stage(Object):
    """A Stage represents schedule for one operation."""

    def split(self, parent, factor=None, nparts=None):
        """Split ``parent`` either by ``factor`` or by ``nparts``.

        Exactly one of ``factor`` and ``nparts`` must be provided.

        Parameters
        ----------
        parent : IterVar
            The parent iter var.
        factor : Expr, optional
            The splitting factor.
        nparts : Expr, optional
            The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.
        inner : IterVar
            The inner variable of iteration.

        Raises
        ------
        ValueError
            If both or neither of ``factor`` and ``nparts`` are given.
        """
        if nparts is not None:
            if factor is not None:
                # The conflicting arguments are ``factor`` and ``nparts``;
                # there is no ``outer`` parameter.
                raise ValueError("Do not need to provide both factor and nparts")
            outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into one.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])

        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            Consecutive IterVars, ordered from outer to inner.

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
        """
        return _ffi_api.StageFuse(self, args)

    def set_scope(self, scope):
        """Set the thread scope of this stage.

        Parameters
        ----------
        scope : str
            The thread scope of this stage.
        """
        return _ffi_api.StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ``ivar`` to the thread index ``thread_ivar``.

        Parameters
        ----------
        ivar : IterVar
            The iteration to be bound to the thread.
        thread_ivar : IterVar
            The thread to be bound.
        """
        _ffi_api.StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of a composed op.

        Parameters
        ----------
        threads : IterVar or list of IterVar
            The threads to be launched. A single IterVar is wrapped in a list.
        """
        if isinstance(threads, IterVar):
            threads = [threads]
        _ffi_api.StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set the predicate under which the store to the array is performed.

        Use this when several threads would do the same store and only one
        of them needs to perform it.

        Parameters
        ----------
        predicate : Expr
            The guard condition for the store.
        """
        _ffi_api.StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at the scope of ``parent``.

        Parameters
        ----------
        parent : Stage
            The parent stage.
        scope : IterVar
            The loop scope to be attached to.
        """
        _ffi_api.StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark the stage as inline.

        Takes no parameters besides the stage itself.
        """
        _ffi_api.StageComputeInline(self)

    def compute_root(self):
        """Mark the stage as a root stage.

        Takes no parameters besides the stage itself.
        """
        _ffi_api.StageComputeRoot(self)

    def reorder(self, *args):
        """Reorder the given iteration variables into the specified order.

        Parameters
        ----------
        args : list of IterVar
            The IterVars, listed in their new order.
        """
        _ffi_api.StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions.

        The final loop order from outermost to innermost is
        [x_outer, y_outer, x_inner, y_inner].

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension.
        y_parent : IterVar
            The original y dimension.
        x_factor : Expr
            The stride factor on the x axis.
        y_factor : Expr
            The stride factor on the y axis.

        Returns
        -------
        x_outer : IterVar
            Outer axis of the x dimension.
        y_outer : IterVar
            Outer axis of the y dimension.
        x_inner : IterVar
            Inner axis of the x dimension.
        y_inner : IterVar
            Inner axis of the y dimension.
        """
        x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
            self, x_parent, y_parent, x_factor, y_factor
        )
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorized.
        """
        _ffi_api.StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by ``var`` with ``tensor_intrin``.

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.
        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _ffi_api.StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _ffi_api.StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _ffi_api.StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with a pragma.

        This is translated to a pragma_scope surrounding the corresponding
        generated loop. It is useful for experimental features and
        extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated.
        pragma_type : str
            The pragma string to be annotated.
        pragma_value : Expr, optional
            The pragma value to pass along with the pragma. A Python string
            is converted with ``convert`` before being passed on.

        Note
        ----
        Most pragmas are advanced/experimental features and may be subject
        to change. Supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into a no-op.
          This is useful for debugging.

        - **parallel_launch_point**

          Launch the parallel threads outside the specified iteration loop.
          By default the threads launch at the point of the parallel
          construct; this pragma moves the launch point to an even outer
          scope. The threads are launched once and reused across multiple
          parallel constructs, as in a BSP-style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads after
          the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint the parallel loop to execute in a strided pattern:
          :code:`for (int i = task_id; i < end; i += num_task)`
        """
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _ffi_api.StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched.
        var : IterVar
            The loop point at which the prefetching is applied.
        offset : Expr
            The number of iterations to be prefetched before actual execution.
        """
        _ffi_api.StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set the alignment requirement for a specific axis.

        This ensures that stride[axis] == k * factor + offset for some k.
        Use it to make the memory layout friendlier to the access pattern.
        For example, setting factor=2, offset=1 avoids bank conflicts for
        thread access on a higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in the alignment specification.
        offset : int
            The offset in the alignment specification.
        """
        _ffi_api.StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to an intermediate stage. It doubles the
        storage cost of the current stage and can help hide load latency.
        """
        _ffi_api.StageDoubleBuffer(self)

    def rolling_buffer(self):
        """Compute the current stage via rolling buffering.

        This can only be applied to an intermediate stage. It changes the
        storage cost of the current stage.
        """
        _ffi_api.StageRollingBuffer(self)

    def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr]]):
        """Define the layout transformation for the current stage's tensor.

        The map from initial_indices to final_indices must be an invertible
        affine transformation. This method may be called more than once for
        a given tensor; each transformation is then applied in sequence.

        If the stage is a ComputeOp, the iteration order of the compute
        stage is rewritten as a row-major traversal of the tensor, and the
        new loop iteration variables are returned. For all other stages the
        loop iteration order is unchanged and the return value is None.

        Parameters
        ----------
        mapping_function : Callable[..., List[tvm.tir.PrimExpr]]
            A callable that takes N arguments of type tvm.tir.Var and
            returns a list of PrimExpr. The arguments give the location of
            a value in the current stage's tensor in the pre-transformation
            layout. The returned list gives the location of that value in
            the post-transformation layout. The callable is passed to
            ``IndexMap.from_func_with_separators``, which also extracts the
            axis separators.

        Returns
        -------
        new_iter_vars : Optional[List[tvm.tir.IterVar]]
            If the stage is a ComputeOp, the updated loop iteration
            variables over the data array, in the same order as the outputs
            of ``mapping_function``. Otherwise None.

        Examples
        --------
        .. code-block:: python

            # ``A`` is a tensor whose compute definition is in NHWC
            # format, and should be transformed into NCHWc format.
            s[A].transform_layout(
                lambda n,h,w,c: [n, c//4, h, w, c%4]
            )

        .. code-block:: python

            # ``A`` is a tensor whose compute definition is in an
            # arbitrary format, and should be transformed such that
            # the last index is split, with the slower-changing index
            # of the split placed at the slowest changing dimension.
            s[A].transform_layout(
                lambda *indices, i: [i//4, *indices, i%4]
            )

        .. code-block:: python

            # ``B`` is a tensor defined by te.compute to be a copy of
            # ``A``, and should be transformed such that ``B``'s layout
            # is a transpose of ``A``'s layout. The loop iteration
            # that computes ``B`` will correspond to ``B``'s memory
            # layout.
            A = te.placeholder([n,m])
            B = te.compute(A.shape, lambda i,j: A[i,j])
            s = te.create_schedule(B.op)
            s[B].transform_layout(lambda i,j: [j,i])
        """
        ndim = len(self.op.output(0).shape)
        index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
        new_iter_vars = _ffi_api.StageTransformLayout(
            self, index_map.initial_indices, index_map.final_indices
        )
        _ffi_api.StageSetAxisSeparators(self, axis_separators)
        # An empty result (non-ComputeOp stage) is reported as None.
        return new_iter_vars or None
@tvm._ffi.register_object
class SpecializedCondition(Object):
    """Specialized condition to enable op specialization."""

    def __init__(self, conditions):
        """Create a specialized condition.

        .. note::
            The conditions are combined in conjunctive normal form (CNF).
            Each condition should be a simple expression such as n > 16 or
            m % 8 == 0, where n and m are tvm.Var objects representing
            dimensions of a tensor shape.

        Parameters
        ----------
        conditions : tvm.Expr or list of tvm.Expr
            The conditions, combined conjunctively. A single expression is
            wrapped in a list.
        """
        is_sequence = isinstance(conditions, (list, _container.Array))
        condition_list = conditions if is_sequence else [conditions]
        self.__init_handle_by_constructor__(_ffi_api.CreateSpecializedCondition, condition_list)

    @staticmethod
    def current():
        """Return the current specialized condition."""
        return _ffi_api.GetCurrentSpecialization()

    def __enter__(self):
        # Enter the specialization scope for this condition and hand the
        # condition itself back to the ``with`` statement.
        _ffi_api.EnterSpecializationScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        # Leave the scope. Returning None lets any in-flight exception
        # propagate.
        _ffi_api.ExitSpecializationScope(self)
# Sentinel value used to indicate how groups of pre-flattening axes are
# combined into post-flattening axes. Moved from te.AXIS_SEPARATOR to
# tir.IndexMap.AXIS_SEPARATOR for general use; re-exported here for
# backwards compatibility.
AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR

# NOTE(review): presumably populates this module with the FFI functions
# registered under the "schedule" prefix; confirm against
# tvm._ffi._init_api. It runs at import time, after the class definitions
# above.
tvm._ffi._init_api("schedule", __name__)