[Frontend][Torch] Fix up graph input handling#5204
Conversation
|
@jjohnson-arm need to update the tutorial too. |
| for output_name, output in name_output_pairs: | ||
| output_index_map[output_name] = len(outputs) | ||
| outputs.append(output) | ||
| def _update_inputs_from_pairs(name_input_pairs, input_vars): |
There was a problem hiding this comment.
I think we don't need this function anymore. Dict's update method can be used.
|
please update the doc here https://github.com/apache/incubator-tvm/blob/e722301a1c8be3c7052273961b8a408ca5524c76/python/tvm/relay/frontend/pytorch.py#L1434-L1436 We should warn that this names need be around until deployment time. Our suggestion is to choose something obvious, that doesn't require remembering. Something like "input0", "input1" etc |
|
I still like to retain the original meaning of I think you can simply replace |
| """ | ||
| input_vars = {} | ||
| ir_inputs = _get_graph_input_names(graph) | ||
| for idx, ir_input in enumerate(ir_inputs): |
There was a problem hiding this comment.
How about
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
...
|
|
||
| params = script_module.state_dict() | ||
| input_vars = _get_relay_input_vars(input_shapes) | ||
| input_vars = _get_relay_input_vars(graph, input_shapes) |
| """ | ||
| Add quant params to outputs so that they can be referenced by other | ||
| Add quant params to inputs so that they can be referenced by other | ||
| ops later. Weights are quantized here. |
There was a problem hiding this comment.
For L104 and L107, please keep outputs
| input_names = [output_index_map[name] | ||
| for name in _get_input_names(op_node)] | ||
| return [outputs[name] for name in input_names] | ||
| def _get_op_inputs(op_node, input_vars): |
There was a problem hiding this comment.
input_vars -> outputs
Because inputs are not relay.Var.
| input_shapes = list(zip(input_names, ishapes)) | ||
|
|
||
| inputs = [torch.randn(shape, dtype=torch.float) | ||
| for name, shape in input_shapes] |
Ok - will do. I wasn't sure about the inputs rename, happy to change it back. |
Will look at this now. |
|
@jjohnson-arm Unfortunately, you've just hit a known flaky test failure. Please comment out the get_valid_count test. See #4901 (comment) Also have you verified that torch frontend tests work with this PR? I'm not sure some of the usage of |
Will comment out the test. |
|
ok good to know. I was thinking the arg of update should be dict. |
|
Thanks @jjohnson-arm this is merged! |
* [Frontend][Torch] Simplify operator input handling * [Frontend][Torch] Allow user supplied input names to override graph inputs * Fix pylint issues * Updates from code review feedback * Fix tutorial to use shape list input * Disable intermittent test failure in topi vision test
* [Frontend][Torch] Simplify operator input handling * [Frontend][Torch] Allow user supplied input names to override graph inputs * Fix pylint issues * Updates from code review feedback * Fix tutorial to use shape list input * Disable intermittent test failure in topi vision test
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.
From: https://discuss.tvm.ai/t/pytorch-frontend-graph-input-names-can-change-using-loaded-torchscript/6055
Split as two commits to make it easier to review:
Review request: @masahi