diff --git a/exir/backend/test/demos/rpc/test_rpc.py b/exir/backend/test/demos/rpc/test_rpc.py index 0c0e72862fd..63feb954fee 100644 --- a/exir/backend/test/demos/rpc/test_rpc.py +++ b/exir/backend/test/demos/rpc/test_rpc.py @@ -8,6 +8,7 @@ import torch from executorch import exir +from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import ( ExecutorBackendPartitioner, @@ -20,6 +21,7 @@ from executorch.extension.pybindings.portable_lib import ( # @manual _load_for_executorch_from_buffer, ) +from torch.export import export from torch.utils._pytree import tree_flatten """ @@ -101,16 +103,15 @@ def test_delegate_whole_program(self): simple_net = self.get_a_simple_net() simple_net_input = simple_net.get_example_inputs() - exported_program = exir.capture( - simple_net, simple_net_input, exir.CaptureConfig() - ).to_edge( - exir.EdgeCompileConfig( + exported_program = to_edge( + export(simple_net, simple_net_input), + compile_config=exir.EdgeCompileConfig( _check_ir_validity=False, - ) + ), ) # delegate the whole graph to the client executor lowered_module = to_backend( - ExecutorBackend.__name__, exported_program.exported_program, [] + ExecutorBackend.__name__, exported_program.exported_program(), [] ) class CompositeModule(torch.nn.Module): @@ -123,11 +124,7 @@ def forward(self, *args): composite_model = CompositeModule() - exec_prog = ( - exir.capture(composite_model, simple_net_input, exir.CaptureConfig()) - .to_edge() - .to_executorch() - ) + exec_prog = to_edge(export(composite_model, simple_net_input)).to_executorch() executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer) @@ -162,18 +159,14 @@ def forward(self, a, x, b): model = Model() inputs = (torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2)) - exported_program = exir.capture(model, inputs, exir.CaptureConfig()).to_edge() + exported_program = to_edge(export(model, inputs)) # First lower to demo backend - demo_backend_lowered = exported_program - demo_backend_lowered.exported_program = to_backend( - exported_program.exported_program, AddMulPartitionerDemo() - ) + demo_backend_lowered = exported_program.to_backend(AddMulPartitionerDemo()) # Then lower to executor backend - executor_backend_lowered = demo_backend_lowered - executor_backend_lowered.exported_program = to_backend( - demo_backend_lowered.exported_program, ExecutorBackendPartitioner() + executor_backend_lowered = demo_backend_lowered.to_backend( + ExecutorBackendPartitioner() ) prog_buffer = executor_backend_lowered.to_executorch()