Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions exir/backend/test/demos/rpc/test_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
ExecutorBackendPartitioner,
Expand All @@ -20,6 +21,7 @@
from executorch.extension.pybindings.portable_lib import ( # @manual
_load_for_executorch_from_buffer,
)
from torch.export import export
from torch.utils._pytree import tree_flatten

"""
Expand Down Expand Up @@ -101,16 +103,15 @@ def test_delegate_whole_program(self):

simple_net = self.get_a_simple_net()
simple_net_input = simple_net.get_example_inputs()
exported_program = exir.capture(
simple_net, simple_net_input, exir.CaptureConfig()
).to_edge(
exir.EdgeCompileConfig(
exported_program = to_edge(
export(simple_net, simple_net_input),
compile_config=exir.EdgeCompileConfig(
_check_ir_validity=False,
)
),
)
# delegate the whole graph to the client executor
lowered_module = to_backend(
ExecutorBackend.__name__, exported_program.exported_program, []
ExecutorBackend.__name__, exported_program.exported_program(), []
)

class CompositeModule(torch.nn.Module):
Expand All @@ -123,11 +124,7 @@ def forward(self, *args):

composite_model = CompositeModule()

exec_prog = (
exir.capture(composite_model, simple_net_input, exir.CaptureConfig())
.to_edge()
.to_executorch()
)
exec_prog = to_edge(export(composite_model, simple_net_input)).to_executorch()

executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)

Expand Down Expand Up @@ -162,18 +159,14 @@ def forward(self, a, x, b):
model = Model()
inputs = (torch.ones(2, 2), torch.ones(2, 2), torch.ones(2, 2))

exported_program = exir.capture(model, inputs, exir.CaptureConfig()).to_edge()
exported_program = to_edge(export(model, inputs))

# First lower to demo backend
demo_backend_lowered = exported_program
demo_backend_lowered.exported_program = to_backend(
exported_program.exported_program, AddMulPartitionerDemo()
)
demo_backend_lowered = exported_program.to_backend(AddMulPartitionerDemo())

# Then lower to executor backend
executor_backend_lowered = demo_backend_lowered
executor_backend_lowered.exported_program = to_backend(
demo_backend_lowered.exported_program, ExecutorBackendPartitioner()
executor_backend_lowered = demo_backend_lowered.to_backend(
ExecutorBackendPartitioner()
)

prog_buffer = executor_backend_lowered.to_executorch()
Expand Down