Skip to content

Commit 8d96808

Browse files
tqchen, sergei-mironov
authored and committed
[COMPILER] Initial compiler infra (apache#12)
1 parent 53c3b92 commit 8d96808

39 files changed

Lines changed: 2148 additions & 119 deletions

nnvm/Makefile

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ include $(config)
1111

1212
export LDFLAGS = -pthread -lm
1313
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
14-
CFLAGS += -Itvm/include -Itvm/dlpack/include
14+
CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src
1515

1616
ifdef DMLC_CORE_PATH
1717
CFLAGS += -I$(DMLC_CORE_PATH)/include
@@ -38,7 +38,7 @@ PLUGIN_OBJ =
3838
include $(NNVM_PLUGINS)
3939

4040
# specify tensor path
41-
.PHONY: clean all test lint doc cython cython3 cyclean
41+
.PHONY: clean all test lint pylint doc cython cython3 cyclean
4242

4343
UNAME_S := $(shell uname -s)
4444

@@ -55,7 +55,7 @@ endif
5555
all: lib/libnnvm.a lib/libnnvm_top.$(SHARED_LIBRARY_SUFFIX) lib/libnnvm_top_runtime.$(SHARED_LIBRARY_SUFFIX)
5656

5757
SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
58-
SRC_TOP = $(wildcard src/top/*.cc, src/top/*/*.cc src/runtime/*.cc)
58+
SRC_TOP = $(wildcard src/top/*/*.cc src/runtime/*.cc src/compiler/*.cc src/compiler/*/*.cc)
5959
ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC))
6060
TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_TOP))
6161
ALL_DEP = $(ALL_OBJ)
@@ -90,9 +90,12 @@ cython3:
9090
cyclean:
9191
rm -rf python/nnvm/*/*.so python/nnvm/*/*.dylib python/nnvm/*/*.cpp
9292

93-
lint:
93+
lint: pylint
9494
python dmlc-core/scripts/lint.py nnvm cpp include src
9595

96+
pylint:
97+
pylint python/nnvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc
98+
9699
doc:
97100
doxygen docs/Doxyfile
98101

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*!
 *  Copyright (c) 2017 by Contributors
 * \file contrib_op_param.h
 * \brief Additional parameters for compiler optimized operators.
 */
#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_

#include <dmlc/parameter.h>
#include <string>

namespace nnvm {
namespace compiler {

/*!
 * \brief Parameters of the layout transform operator.
 *
 * Describes a data-layout conversion (e.g. "NCHW" -> "NHWC")
 * inserted by the compiler during layout optimization.
 */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
  /*! \brief Layout of the input tensor. */
  std::string src_layout;
  /*! \brief Layout the output tensor is converted to. */
  std::string dst_layout;

  DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
    // .describe() populates the auto-generated operator documentation;
    // fields without it show up undocumented in the parameter docs.
    DMLC_DECLARE_FIELD(src_layout)
        .describe("Layout of the source tensor.");
    DMLC_DECLARE_FIELD(dst_layout)
        .describe("Layout of the destination tensor.");
  }
};

}  // namespace compiler
}  // namespace nnvm

#endif  // NNVM_COMPILER_CONTRIB_OP_PARAM_H_
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*!
 *  Copyright (c) 2017 by Contributors
 * \file op_attr_types.h
 * \brief Operator attribute types used by the NNVM-TVM compiler.
 */
#ifndef NNVM_COMPILER_OP_ATTR_TYPES_H_
#define NNVM_COMPILER_OP_ATTR_TYPES_H_

#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>

namespace nnvm {
namespace compiler {

using ::tvm::Array;
using ::tvm::Tensor;
using ::tvm::Schedule;

/*! \brief Operator pattern used in graph fusion. */
enum OpPatternKind : int {
  /*! \brief Elementwise operation. */
  kElemWise = 0,
  /*! \brief Broadcast operation. */
  kBroadcast = 1,
  /*!
   * \brief Complex operation: can fuse broadcast ops on its
   *  inputs/outputs, but cannot chain another complex op.
   */
  kComplex = 2,
  /*! \brief Extern operation: cannot fuse anything. */
  kExtern = 3
};

/*! \brief The operator pattern, stored as a plain int attribute. */
using TOpPattern = int;

/*!
 * \brief Computation description interface.
 * \param attrs The attributes of the node.
 * \param inputs The input tensors (placeholders).
 * \return The output description of the tensor.
 */
using FTVMCompute = std::function<
  Array<Tensor>
  (const NodeAttrs& attrs, const Array<Tensor>& inputs)>;

/*!
 * \brief Build the computation schedule for the op whose
 *  root is at the current op.
 * \param attrs The attributes of the node.
 * \param outs The output tensors.
 * \param target The build target.
 * \return schedule The computation schedule.
 */
using FTVMSchedule = std::function<
  Schedule(const NodeAttrs& attrs,
           const Array<Tensor>& outs,
           const std::string& target)>;

/*! \brief Layout information about an entry. */
using TLayoutInfo = std::string;

/*!
 * \brief The producer/consumer function of node layout.
 * \param attrs The attributes of the node.
 * \param ilayouts The input layouts that the node requests.
 * \param olayouts The output layouts that the node produces.
 * \return bool The success flag.
 */
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
                                              std::vector<TLayoutInfo> *ilayouts,
                                              std::vector<TLayoutInfo> *olayouts)>;

/*!
 * \brief Transform from a normal operator to a vectorized operator.
 * \param node The source node.
 * \return Transformed vectorized op.
 */
using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node* node)>;

}  // namespace compiler
}  // namespace nnvm
#endif  // NNVM_COMPILER_OP_ATTR_TYPES_H_
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*!
 *  Copyright (c) 2017 by Contributors
 * \file packed_func_ext.h
 * \brief Extension to enable packed function for nnvm types.
 */
#ifndef NNVM_COMPILER_PACKED_FUNC_EXT_H_
#define NNVM_COMPILER_PACKED_FUNC_EXT_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <string>
#include <unordered_map>

namespace nnvm {
namespace compiler {

using tvm::runtime::PackedFunc;

/*! \brief String-to-string attribute dictionary exchanged over PackedFunc. */
using AttrDict = std::unordered_map<std::string, std::string>;

/*!
 * \brief Get a PackedFunc from the global registry and
 *  report an error if it does not exist.
 * \param name The name of the function.
 * \return The created PackedFunc.
 */
inline const PackedFunc& GetPackedFunc(const std::string& name) {
  const PackedFunc* pf = tvm::runtime::Registry::Get(name);
  CHECK(pf != nullptr) << "Cannot find function " << name << " in registry";
  return *pf;
}

}  // namespace compiler
}  // namespace nnvm

// Enable the graph and symbol object exchange across the TVM FFI.
// The type codes below must stay unique among registered extensions.
namespace tvm {
namespace runtime {

template<>
struct extension_class_info<nnvm::Symbol> {
  static const int code = 16;
};

template<>
struct extension_class_info<nnvm::Graph> {
  static const int code = 17;
};

template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
  static const int code = 18;
};

}  // namespace runtime
}  // namespace tvm
#endif  // NNVM_COMPILER_PACKED_FUNC_EXT_H_

nnvm/include/nnvm/op_attr_types.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ template<typename AttrType>
7272
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
7373
std::vector<AttrType> *in_attrs,
7474
std::vector<AttrType> *out_attrs)>;
75+
76+
/*!
77+
* \brief Get attribute dictionary from node.
78+
*
79+
* \param attrs The attributes of the node.
80+
* \return The attribute dict.
81+
* \note Register under "FUpdateAttrDict"
82+
*/
83+
using FGetAttrDict = std::function<
84+
std::unordered_map<std::string, std::string>
85+
(const NodeAttrs& attrs)>;
86+
7587
/*!
7688
* \brief Shape inference function.
7789
* Update the shapes given the input shape information.

nnvm/include/nnvm/top/README

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
NNVM Core Operator Specs
1+
NNVM Core Operator and Compiler

nnvm/python/nnvm/_base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
# coding: utf-8
2-
# pylint: disable=invalid-name
2+
# pylint: disable=invalid-name, unused-import
33
""" ctypes library of nnvm and helper functions """
44
from __future__ import absolute_import
55

66
import sys
7-
import os
87
import ctypes
98
import numpy as np
109
from . import libinfo
1110

12-
__all__ = ['NNNetError']
11+
try:
12+
import tvm
13+
except ImportError:
14+
pass
15+
1316
#----------------------------
1417
# library loading
1518
#----------------------------
@@ -181,7 +184,7 @@ def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True)
181184
param_keys.add(key)
182185
type_info = py_str(arg_types[i])
183186
ret = '%s : %s' % (key, type_info)
184-
if len(arg_descs[i]) != 0:
187+
if arg_descs[i]:
185188
ret += '\n ' + py_str(arg_descs[i])
186189
param_str.append(ret)
187190
doc_str = ('Parameters\n' +

nnvm/python/nnvm/_ctypes/symbol.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# coding: utf-8
2-
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
2+
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines,
3+
# pylint: disable=len-as-condition, consider-iterating-dictionary
34
"""Symbolic configuration API."""
45
from __future__ import absolute_import as _abs
56

67
import copy
78
import ctypes
89
import sys
910
from .._base import _LIB
10-
from .._base import c_array, c_str, nn_uint, py_str, string_types
11+
from .._base import c_array, c_str, nn_uint, py_str
1112
from .._base import SymbolHandle, OpHandle
1213
from .._base import check_call, ctypes2docstring
1314
from ..name import NameManager
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Namespace for NNVM-TVM compiler toolchain"""
2+
from __future__ import absolute_import
3+
4+
import tvm
5+
6+
from . import build_module
7+
from . build_module import build
8+
9+
from .. import symbol as _symbol
10+
from .. import graph as _graph
11+
12+
from .registry import OpPattern
13+
from .registry import register_compute, register_schedule, register_pattern
14+
15+
from .. import top as _top
16+
17+
tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
18+
tvm.register_extension(_graph.Graph, _graph.Graph)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# pylint: disable=invalid-name
2+
"""Namespace for building operators."""
3+
from __future__ import absolute_import as _abs
4+
5+
import tvm
6+
from . import graph_attr
7+
from .. import graph as _graph
8+
9+
@tvm.register_func("nnvm.compiler.lower")
def _lower(sch, inputs, func_name):
    """Lower a schedule into a list of lowered functions.

    Registered as a global packed function so the C++ compiler
    pipeline can call back into Python during graph compilation.
    Always returns a list-like collection, even for a single function.
    """
    lowered = tvm.lower(sch, inputs, name=func_name)
    if isinstance(lowered, (tvm.container.Array, tuple, list)):
        return lowered
    return [lowered]
14+
15+
16+
@tvm.register_func("nnvm.compiler.build_target")
def _build(funcs, target):
    """Compile lowered functions for the given target.

    Registered as a global packed function used by the C++ pipeline.
    """
    module = tvm.build(funcs, target=target)
    return module
19+
20+
21+
# Global packed function "nnvm.compiler._move_module"; used by build()
# to take the compiled module out of a fused graph. Registered elsewhere
# (presumably on the C++ side) -- confirm against the compiler sources.
_move_module = tvm.get_global_func("nnvm.compiler._move_module")
22+
23+
24+
def optimize(graph):
    """Perform graph optimization.

    Currently a placeholder: no passes are applied and the input
    graph is returned unchanged.

    Parameters
    ----------
    graph : Graph
        The graph to be used in lowering.

    Returns
    -------
    graph : Graph
        The optimized execution graph.
    """
    # No optimization passes yet; identity transform.
    return graph
38+
39+
40+
def build(graph, target, shape, dtype="float32"):
41+
"""Build graph into runtime library.
42+
43+
This is the final step of graph compilation.
44+
45+
Parameters
46+
----------
47+
graph : Graph
48+
The graph to be used in lowering
49+
50+
target : str
51+
The build target
52+
53+
shape : dict of str to tuple
54+
The input shape to the graph
55+
56+
dtype : str or dict of str to str
57+
The input types to the graph
58+
59+
Returns
60+
-------
61+
graph : Graph
62+
The final execution graph.
63+
64+
libmod : tvm.Module
65+
The modue that comes with the execution graph
66+
"""
67+
if not isinstance(target, str):
68+
raise TypeError("require target to be str")
69+
if not isinstance(shape, dict):
70+
raise TypeError("require shape to be dict")
71+
72+
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
73+
graph = graph_attr.set_shape(graph, shape)
74+
graph = graph_attr.set_dtype(graph, dtype)
75+
graph._set_json_attr("target", target, "str")
76+
graph = graph.apply("InferShape").apply("InferType")
77+
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
78+
libmod = _move_module(graph)
79+
return graph, libmod

0 commit comments

Comments
 (0)