@@ -733,54 +733,49 @@ def _convert_elemwise_input(data, input_type):
733733}
734734
735735
736- def run_jit_passes (graph ):
736+ def _run_jit_passes (graph ):
737737 """ The inline pass is necessary to unwrap prim::CallMethod """
738738 import torch
739739 if version .parse (torch .__version__ ) >= version .parse ("1.4.0" ):
740740 torch ._C ._jit_pass_inline (graph )
741741
742742
743- def is_int_seq (seq ):
743+ def _is_int_seq (seq ):
744744 return len (seq ) > 0 and all ([isinstance (i , int ) for i in seq ])
745745
746746
747- def get_tensor_and_var (torch_tensor , name ):
747+ def _get_tensor_and_var (torch_tensor , name ):
748748 tensor = tvm .nd .array (torch_tensor .cpu ().numpy ())
749749 var = _expr .var (name , shape = tensor .shape )
750750 return tensor , var
751751
752752
753- def get_output_name (node ):
753+ def _get_output_name (node ):
754754 assert node .outputsSize () == 1
755755 return node .output ().debugName ()
756756
757757
758- def get_output_names (node ):
758+ def _get_output_names (node ):
759759 return [output .debugName () for output in node .outputs ()]
760760
761761
762- def get_input_names (node_or_graph ):
762+ def _get_input_names (node_or_graph ):
763763 return [inp .debugName () for inp in node_or_graph .inputs ()]
764764
765765
766- def get_op_inputs (op_node , outputs , output_index_map ):
766+ def _get_op_inputs (op_node , outputs , output_index_map ):
767767 input_names = [output_index_map [name ]
768- for name in get_input_names (op_node )]
768+ for name in _get_input_names (op_node )]
769769 return [outputs [name ] for name in input_names ]
770770
771771
772- def update_outputs_from_pairs (name_output_pairs , outputs , output_index_map ):
772+ def _update_outputs_from_pairs (name_output_pairs , outputs , output_index_map ):
773773 for output_name , output in name_output_pairs :
774774 output_index_map [output_name ] = len (outputs )
775775 outputs .append (output )
776776
777777
778- def get_all_op_names (graph ):
779- nodes = list (graph .nodes ())
780- return set (node .kind () for node in nodes )
781-
782-
783- def report_missing_conversion (op_names ):
778+ def _report_missing_conversion (op_names ):
784779 """ Check if all ops in an input graph are supported by TVM """
785780 known_ops = ["prim::Constant" , "prim::GetAttr" ,
786781 "prim::ListConstruct" , "prim::ListUnpack" ,
@@ -795,66 +790,26 @@ def report_missing_conversion(op_names):
795790 raise NotImplementedError (msg )
796791
797792
798- def getattr_attr_name (node ):
793+ def _getattr_attr_name (node ):
799794 attribute_names = node .attributeNames ()
800795 assert len (attribute_names ) == 1
801796 attr_name = node .s (attribute_names [0 ])
802797 return attr_name
803798
804799
805- def get_full_attr_name (getattrs ):
806- return "." .join ([getattr_attr_name (node ) for node in getattrs ])
807-
808-
809- def get_use_chains (root_node , terminate = lambda _ : False ):
810- """
811- Track a chain of users of this node forward, returning a list of chains
812- See get_attr_chains below for its usage
813- """
814- def concat_lists (lists ):
815- return itertools .chain .from_iterable (lists )
816-
817- def inner (current , accum ):
818- users = []
819- for output in current .outputs ():
820- users += [use .user for use in output .uses ()]
821-
822- if not users or terminate (users ):
823- return [accum ]
824-
825- return concat_lists ([inner (nxt , accum + [nxt ]) for nxt in users ])
826-
827- return inner (root_node , [root_node ])
828-
829-
830- def get_attr_chains (root_getattr_node ):
831- """ Returns chains of attribute access starting from root_getattr_node
832-
833- For example, given attribute "block", as in "self.block" when "self" points
834- to the top level torch.nn.Module, it returns lists of attribute "chains",
835- e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
836-
837- These sets of attributes form full attribute accessors. For example,
838- "self.block.1", "self.block.2" will return the second and third submodule,
839- and "self.block.0._packed_params" will return the parameters of the first
840- submodule.
841- """
842- def terminate (users ):
843- next_attrs = [user for user in users if user .kind () == "prim::GetAttr" ]
844- return len (next_attrs ) == 0
845-
846- return get_use_chains (root_getattr_node , terminate )
800+ def _getattr_full_name (getattrs ):
801+ return "." .join ([_getattr_attr_name (node ) for node in getattrs ])
847802
848803
849- def get_input_types (op_node ):
804+ def _get_input_types (op_node ):
850805 """ Returns a torch type for each input nodes """
851806 input_list_types = []
852807 for input_node in op_node .inputs ():
853808 in_ty = input_node .type ()
854809 input_node_kind = in_ty .kind ()
855810 if input_node_kind == 'TensorType' :
856811 if in_ty .scalarType () is None :
857- input_list_types .append ('float' )
812+ input_list_types .append (None )
858813 else :
859814 input_list_types .append (in_ty .scalarType ().lower ())
860815 elif input_node_kind == 'ListType' :
@@ -874,7 +829,7 @@ def get_input_types(op_node):
874829 return input_list_types
875830
876831
877- def get_constant (node ):
832+ def _get_constant (node ):
878833 """ Retrieve a constant associated with this prim::Constant node """
879834 attribute_names = node .attributeNames ()
880835 num_attributes = len (attribute_names )
@@ -903,15 +858,15 @@ def get_constant(node):
903858 return None
904859
905860
906- def get_operator_nodes (nodes ):
861+ def _get_operator_nodes (nodes ):
907862 """ Returns torch IR nodes that need conversion to Relay """
908863 ops = {}
909864 # Traverse nodes and add to graph
910865 for node in nodes :
911866 if node .outputsSize () > 1 :
912- node_name = "_" .join (get_output_names (node ))
867+ node_name = "_" .join (_get_output_names (node ))
913868 else :
914- node_name = get_output_name (node )
869+ node_name = _get_output_name (node )
915870
916871 if node .kind () != "prim::GetAttr" :
917872 ops [node_name ] = node
@@ -930,6 +885,46 @@ def parse_inputs(graph_inputs, input_shapes):
930885 return input_vars
931886
932887
888+ def get_use_chains (root_node , terminate = lambda _ : False ):
889+ """
890+ Track a chain of users of this node forward, returning a list of chains
891+ See get_attr_chains below for its usage
892+ """
893+ def concat_lists (lists ):
894+ return itertools .chain .from_iterable (lists )
895+
896+ def inner (current , accum ):
897+ users = []
898+ for output in current .outputs ():
899+ users += [use .user for use in output .uses ()]
900+
901+ if not users or terminate (users ):
902+ return [accum ]
903+
904+ return concat_lists ([inner (nxt , accum + [nxt ]) for nxt in users ])
905+
906+ return inner (root_node , [root_node ])
907+
908+
909+ def get_attr_chains (root_getattr_node ):
910+ """ Returns chains of attribute access starting from root_getattr_node
911+
912+ For example, given attribute "block", as in "self.block" when "self" points
913+ to the top level torch.nn.Module, it returns lists of attribute "chains",
914+ e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
915+
916+ These sets of attributes form full attribute accessors. For example,
917+ "self.block.1", "self.block.2" will return the second and third submodule,
918+ and "self.block.0._packed_params" will return the parameters of the first
919+ submodule.
920+ """
921+ def terminate (users ):
922+ next_attrs = [user for user in users if user .kind () == "prim::GetAttr" ]
923+ return len (next_attrs ) == 0
924+
925+ return get_use_chains (root_getattr_node , terminate )
926+
927+
933928def parse_params (graph , state_dict ):
934929 """
935930 Return Relay vars and TVM NDArrays for input parameters
@@ -941,19 +936,19 @@ def parse_params(graph, state_dict):
941936 seen = set ()
942937
943938 for node in getattr_nodes :
944- if get_output_name (node ) in seen :
939+ if _get_output_name (node ) in seen :
945940 continue
946941
947942 for getattrs in get_attr_chains (node ):
948- seen .update (map (get_output_name , getattrs ))
943+ seen .update (map (_get_output_name , getattrs ))
949944
950- full_attr = get_full_attr_name (getattrs )
951- full_attr_node_name = get_output_name (getattrs [- 1 ])
945+ full_attr = _getattr_full_name (getattrs )
946+ full_attr_node_name = _get_output_name (getattrs [- 1 ])
952947
953948 if full_attr in state_dict :
954949 torch_tensor = state_dict [full_attr ]
955- tensor , var = get_tensor_and_var (torch_tensor ,
956- full_attr_node_name )
950+ tensor , var = _get_tensor_and_var (torch_tensor ,
951+ full_attr_node_name )
957952 param_tensors [full_attr_node_name ] = tensor
958953 params [full_attr_node_name ] = var
959954
@@ -964,35 +959,41 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
964959 """ Convert each Torch IR operators to Relay equivalent """
965960 for node_name , op_node in operators .items ():
966961 operator = op_node .kind ()
967- inputs = get_op_inputs (op_node , outputs , output_index_map )
962+ inputs = _get_op_inputs (op_node , outputs , output_index_map )
968963
969964 if operator == "prim::Constant" :
970965 output_index_map [node_name ] = len (outputs )
971- outputs .append (get_constant (op_node ))
972- elif operator == 'prim::ListConstruct' and is_int_seq (inputs ):
966+ outputs .append (_get_constant (op_node ))
967+ elif operator == 'prim::ListConstruct' and _is_int_seq (inputs ):
973968 output_index_map [node_name ] = len (outputs )
974969 outputs .append (_expr .var (node_name , shape = inputs ))
975970 elif operator in ['prim::ListConstruct' , 'prim::TupleConstruct' ]:
976971 output_index_map [node_name ] = len (outputs )
977972 outputs .append (inputs )
978973 elif operator in ["prim::ListUnpack" , 'prim::TupleUnpack' ]:
979974 assert len (inputs ) == 1
980- unpacked_names = get_output_names (op_node )
981- update_outputs_from_pairs (zip (unpacked_names , inputs [0 ]),
982- outputs , output_index_map )
975+ unpacked_names = _get_output_names (op_node )
976+ _update_outputs_from_pairs (zip (unpacked_names , inputs [0 ]),
977+ outputs , output_index_map )
983978 else :
984979 output_index_map [node_name ] = len (outputs )
985980 relay_op = _convert_map [operator ]
986- outputs .append (relay_op (inputs , get_input_types (op_node )))
981+ outputs .append (relay_op (inputs , _get_input_types (op_node )))
987982
988983 return outputs [output_index_map [ret_name ]]
989984
990985
986+ def get_all_op_names (graph ):
987+ """ Return all operator names in the input graph """
988+ nodes = list (graph .nodes ())
989+ return set (node .kind () for node in nodes )
990+
991+
991992def get_graph_input_names (script_module ):
992993 """ Use this function to set the keys for input_shapes"""
993994 # It seems variable names could change the first time a copy is made
994995 # Use the copy of the graph here to prevent troubles later
995- ir_inputs = get_input_names (script_module .graph .copy ())
996+ ir_inputs = _get_input_names (script_module .graph .copy ())
996997 return ir_inputs [1 :] # remove self at the 0th arg
997998
998999
@@ -1019,9 +1020,9 @@ def from_pytorch(script_module, input_shapes):
10191020 Dict of converted parameters stored in tvm.runtime.ndarray format
10201021 """
10211022 graph = script_module .graph .copy ()
1022- run_jit_passes (graph )
1023+ _run_jit_passes (graph )
10231024 op_names = get_all_op_names (graph )
1024- report_missing_conversion (op_names )
1025+ _report_missing_conversion (op_names )
10251026
10261027 params = script_module .state_dict ()
10271028 input_vars = parse_inputs (graph .inputs (), input_shapes )
@@ -1030,9 +1031,9 @@ def from_pytorch(script_module, input_shapes):
10301031 input_vars .update (param_vars )
10311032 outputs = list (input_vars .values ())
10321033 output_index_map = dict (zip (input_vars .keys (), range (len (outputs ))))
1033- ret_name = get_input_names (graph .return_node ())[0 ]
1034+ ret_name = _get_input_names (graph .return_node ())[0 ]
10341035
1035- body = parse_operators (get_operator_nodes (graph .nodes ()), outputs ,
1036+ body = parse_operators (_get_operator_nodes (graph .nodes ()), outputs ,
10361037 output_index_map , ret_name )
10371038 func = tvm .relay .Function (_analysis .free_vars (body ), body )
10381039 tvm_params = {k : tvm .nd .array (v ) for k , v in tensors .items ()}
0 commit comments