
Commit 3a2cc94

add _ prefix
1 parent 7351c42 commit 3a2cc94

1 file changed: 83 additions & 82 deletions

python/tvm/relay/frontend/pytorch.py
@@ -733,54 +733,49 @@ def _convert_elemwise_input(data, input_type):
 }


-def run_jit_passes(graph):
+def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
     if version.parse(torch.__version__) >= version.parse("1.4.0"):
         torch._C._jit_pass_inline(graph)


-def is_int_seq(seq):
+def _is_int_seq(seq):
     return len(seq) > 0 and all([isinstance(i, int) for i in seq])


-def get_tensor_and_var(torch_tensor, name):
+def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape)
     return tensor, var


-def get_output_name(node):
+def _get_output_name(node):
     assert node.outputsSize() == 1
     return node.output().debugName()


-def get_output_names(node):
+def _get_output_names(node):
     return [output.debugName() for output in node.outputs()]


-def get_input_names(node_or_graph):
+def _get_input_names(node_or_graph):
     return [inp.debugName() for inp in node_or_graph.inputs()]


-def get_op_inputs(op_node, outputs, output_index_map):
+def _get_op_inputs(op_node, outputs, output_index_map):
     input_names = [output_index_map[name]
-                   for name in get_input_names(op_node)]
+                   for name in _get_input_names(op_node)]
     return [outputs[name] for name in input_names]


-def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
+def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
     for output_name, output in name_output_pairs:
         output_index_map[output_name] = len(outputs)
         outputs.append(output)


-def get_all_op_names(graph):
-    nodes = list(graph.nodes())
-    return set(node.kind() for node in nodes)
-
-
-def report_missing_conversion(op_names):
+def _report_missing_conversion(op_names):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
@@ -795,66 +790,26 @@ def report_missing_conversion(op_names):
     raise NotImplementedError(msg)


-def getattr_attr_name(node):
+def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
     attr_name = node.s(attribute_names[0])
     return attr_name


-def get_full_attr_name(getattrs):
-    return ".".join([getattr_attr_name(node) for node in getattrs])
-
-
-def get_use_chains(root_node, terminate=lambda _: False):
-    """
-    Track a chain of users of this node forward, returning a list of chains
-    See get_attr_chains below for its usage
-    """
-    def concat_lists(lists):
-        return itertools.chain.from_iterable(lists)
-
-    def inner(current, accum):
-        users = []
-        for output in current.outputs():
-            users += [use.user for use in output.uses()]
-
-        if not users or terminate(users):
-            return [accum]
-
-        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
-
-    return inner(root_node, [root_node])
-
-
-def get_attr_chains(root_getattr_node):
-    """ Returns chains of attribute access starting from root_getattr_node
-
-    For example, given attribute "block", as in "self.block" when "self" points
-    to the top level torch.nn.Module, it returns lists of attribute "chains",
-    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
-
-    These sets of attributes form full attribute accessors. For example,
-    "self.block.1", "self.block.2" will return the second and third submodule,
-    and "self.block.0._packed_params" will return the parameters of the first
-    submodule.
-    """
-    def terminate(users):
-        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
-        return len(next_attrs) == 0
-
-    return get_use_chains(root_getattr_node, terminate)
+def _getattr_full_name(getattrs):
+    return ".".join([_getattr_attr_name(node) for node in getattrs])


-def get_input_types(op_node):
+def _get_input_types(op_node):
     """ Returns a torch type for each input nodes """
     input_list_types = []
     for input_node in op_node.inputs():
         in_ty = input_node.type()
         input_node_kind = in_ty.kind()
         if input_node_kind == 'TensorType':
             if in_ty.scalarType() is None:
-                input_list_types.append('float')
+                input_list_types.append(None)
             else:
                 input_list_types.append(in_ty.scalarType().lower())
         elif input_node_kind == 'ListType':
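
The 'float' -> None change in _get_input_types is the one behavioral change hidden among the renames: a tensor input whose scalar type is not annotated in the Torch IR is now reported as None rather than assumed to be 'float', leaving the default up to the caller. A hypothetical sketch of how a caller might consume that (the helper name and the float32 default are assumptions, not from this file):

def dtype_or_default(input_type, default="float32"):
    """input_type is one entry of _get_input_types(op_node)."""
    if input_type is None:  # scalar type unknown in the Torch IR
        return default      # assumed fallback; pick whatever the op requires
    return input_type       # e.g. 'float', 'double', 'long'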
@@ -874,7 +829,7 @@ def get_input_types(op_node):
     return input_list_types


-def get_constant(node):
+def _get_constant(node):
     """ Retrieve a constant associated with this prim::Constant node """
     attribute_names = node.attributeNames()
     num_attributes = len(attribute_names)
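
_get_constant reads the literal attached to a prim::Constant node. A quick way to see the kind of nodes it handles (a sketch, not from this commit):

import torch

traced = torch.jit.trace(lambda x: x + 1, torch.randn(2))
for node in traced.graph.nodes():
    if node.kind() == "prim::Constant":
        print(node)  # e.g. the scalar 1 and the alpha argument of aten::add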
@@ -903,15 +858,15 @@ def get_constant(node):
     return None


-def get_operator_nodes(nodes):
+def _get_operator_nodes(nodes):
     """ Returns torch IR nodes that need conversion to Relay """
     ops = {}
     # Traverse nodes and add to graph
     for node in nodes:
         if node.outputsSize() > 1:
-            node_name = "_".join(get_output_names(node))
+            node_name = "_".join(_get_output_names(node))
         else:
-            node_name = get_output_name(node)
+            node_name = _get_output_name(node)

         if node.kind() != "prim::GetAttr":
             ops[node_name] = node
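
_get_operator_nodes keys a single-output node by its output's debugName() and a multi-output node by all of its output names joined with "_". For example (illustrative; the actual debug names vary by graph), aten::topk produces one node with two outputs:

import torch

traced = torch.jit.trace(lambda x: torch.topk(x, 2), torch.randn(4))
for node in traced.graph.nodes():
    if node.outputsSize() > 1:
        names = [out.debugName() for out in node.outputs()]
        print(node.kind(), "->", "_".join(names))  # e.g. aten::topk -> 5_6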
@@ -930,6 +885,46 @@ def parse_inputs(graph_inputs, input_shapes):
     return input_vars


+def get_use_chains(root_node, terminate=lambda _: False):
+    """
+    Track a chain of users of this node forward, returning a list of chains
+    See get_attr_chains below for its usage
+    """
+    def concat_lists(lists):
+        return itertools.chain.from_iterable(lists)
+
+    def inner(current, accum):
+        users = []
+        for output in current.outputs():
+            users += [use.user for use in output.uses()]
+
+        if not users or terminate(users):
+            return [accum]
+
+        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
+
+    return inner(root_node, [root_node])
+
+
+def get_attr_chains(root_getattr_node):
+    """ Returns chains of attribute access starting from root_getattr_node
+
+    For example, given attribute "block", as in "self.block" when "self" points
+    to the top level torch.nn.Module, it returns lists of attribute "chains",
+    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
+
+    These sets of attributes form full attribute accessors. For example,
+    "self.block.1", "self.block.2" will return the second and third submodule,
+    and "self.block.0._packed_params" will return the parameters of the first
+    submodule.
+    """
+    def terminate(users):
+        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
+        return len(next_attrs) == 0
+
+    return get_use_chains(root_getattr_node, terminate)
+
+
 def parse_params(graph, state_dict):
     """
     Return Relay vars and TVM NDArrays for input parameters
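
get_use_chains and get_attr_chains are unchanged by this commit; the hunk only moves them next to their caller, parse_params, and leaves them un-prefixed. Following their docstring, a sketch of the chains they yield for a small container module, using the helpers in this file (the printed chains are assumed, per the docstring's example):

import torch

block = torch.jit.script(torch.nn.Sequential(
    torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 4)))
graph = block.graph.copy()
torch._C._jit_pass_inline(graph)
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
for getattrs in get_attr_chains(getattr_nodes[0]):
    # e.g. ['0', 'weight'], ['0', 'bias'] for the first Linear submodule
    print([_getattr_attr_name(node) for node in getattrs])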
@@ -941,19 +936,19 @@ def parse_params(graph, state_dict):
     seen = set()

     for node in getattr_nodes:
-        if get_output_name(node) in seen:
+        if _get_output_name(node) in seen:
             continue

         for getattrs in get_attr_chains(node):
-            seen.update(map(get_output_name, getattrs))
+            seen.update(map(_get_output_name, getattrs))

-            full_attr = get_full_attr_name(getattrs)
-            full_attr_node_name = get_output_name(getattrs[-1])
+            full_attr = _getattr_full_name(getattrs)
+            full_attr_node_name = _get_output_name(getattrs[-1])

             if full_attr in state_dict:
                 torch_tensor = state_dict[full_attr]
-                tensor, var = get_tensor_and_var(torch_tensor,
-                                                 full_attr_node_name)
+                tensor, var = _get_tensor_and_var(torch_tensor,
+                                                  full_attr_node_name)
                 param_tensors[full_attr_node_name] = tensor
                 params[full_attr_node_name] = var

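The lookup in parse_params works because _getattr_full_name reconstructs exactly the dotted keys that state_dict() uses. A small check (a sketch, not from this commit):

import torch

block = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
print(list(block.state_dict().keys()))  # ['0.weight', '0.bias']
# joining the GetAttr chain ['0', 'weight'] with "." gives "0.weight",
# which is found in state_dict above
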
@@ -964,35 +959,41 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators.items():
         operator = op_node.kind()
-        inputs = get_op_inputs(op_node, outputs, output_index_map)
+        inputs = _get_op_inputs(op_node, outputs, output_index_map)

         if operator == "prim::Constant":
             output_index_map[node_name] = len(outputs)
-            outputs.append(get_constant(op_node))
-        elif operator == 'prim::ListConstruct' and is_int_seq(inputs):
+            outputs.append(_get_constant(op_node))
+        elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
             output_index_map[node_name] = len(outputs)
             outputs.append(_expr.var(node_name, shape=inputs))
         elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
             output_index_map[node_name] = len(outputs)
             outputs.append(inputs)
         elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
             assert len(inputs) == 1
-            unpacked_names = get_output_names(op_node)
-            update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
-                                      outputs, output_index_map)
+            unpacked_names = _get_output_names(op_node)
+            _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
+                                       outputs, output_index_map)
         else:
             output_index_map[node_name] = len(outputs)
             relay_op = _convert_map[operator]
-            outputs.append(relay_op(inputs, get_input_types(op_node)))
+            outputs.append(relay_op(inputs, _get_input_types(op_node)))

     return outputs[output_index_map[ret_name]]


+def get_all_op_names(graph):
+    """ Return all operator names in the input graph """
+    nodes = list(graph.nodes())
+    return set(node.kind() for node in nodes)
+
+
 def get_graph_input_names(script_module):
     """ Use this function to set the keys for input_shapes"""
     # It seems variable names could change the first time a copy is made
     # Use the copy of the graph here to prevent troubles later
-    ir_inputs = get_input_names(script_module.graph.copy())
+    ir_inputs = _get_input_names(script_module.graph.copy())
     return ir_inputs[1:]  # remove self at the 0th arg


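get_graph_input_names, get_all_op_names, and from_pytorch keep their public, un-prefixed names; everything they call is now private. A typical call sequence per the docstrings here (the model and shape are placeholders, and the return value assumes from_pytorch hands back the function and params it builds at the end of this diff):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
script_module = torch.jit.trace(model, torch.randn(1, 3, 224, 224))

# "Use this function to set the keys for input_shapes"
input_name = get_graph_input_names(script_module)[0]
func, tvm_params = from_pytorch(script_module, {input_name: (1, 3, 224, 224)})
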
@@ -1019,9 +1020,9 @@ def from_pytorch(script_module, input_shapes):
         Dict of converted parameters stored in tvm.runtime.ndarray format
     """
     graph = script_module.graph.copy()
-    run_jit_passes(graph)
+    _run_jit_passes(graph)
     op_names = get_all_op_names(graph)
-    report_missing_conversion(op_names)
+    _report_missing_conversion(op_names)

     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)
@@ -1030,9 +1031,9 @@ def from_pytorch(script_module, input_shapes):
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
-    ret_name = get_input_names(graph.return_node())[0]
+    ret_name = _get_input_names(graph.return_node())[0]

-    body = parse_operators(get_operator_nodes(graph.nodes()), outputs,
+    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
                            output_index_map, ret_name)
     func = tvm.relay.Function(_analysis.free_vars(body), body)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
