forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Initial unit tests for OpenVINO backend #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ynimmaga
merged 5 commits into
ynimmaga:openvino_backend
from
cavusmustafa:openvino_backend_unit_tests
Feb 1, 2025
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5806788
Initial unit tests for OpenVINO backend
cavusmustafa 916ba64
Unit test update and cleanup
cavusmustafa e0b1bb7
Input/Output processing for example and unit tests
cavusmustafa 9108770
Added executorch parameter to openvino_compile call
cavusmustafa ecbe5e2
New op unit tests added
cavusmustafa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| import os | ||
| import subprocess | ||
| import tempfile | ||
| import unittest | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| import executorch | ||
| from executorch.backends.openvino.partitioner import OpenvinoPartitioner | ||
| from executorch.exir.backend.backend_details import CompileSpec | ||
| from torch.export import export, ExportedProgram | ||
| from executorch.exir import EdgeProgramManager, to_edge | ||
| from executorch.backends.openvino.preprocess import OpenvinoBackend | ||
|
|
||
|
|
||
| class BaseOpenvinoOpTest(unittest.TestCase): | ||
| device = "CPU" | ||
| build_folder = "" | ||
|
|
||
| atol = 1e-1 | ||
| rtol = 1e-1 | ||
|
|
||
| def execute_layer_test( | ||
| self, | ||
| module: torch.nn.Module, | ||
| sample_inputs: tuple[torch.Tensor], | ||
| expected_partitions: int = 1, | ||
| assert_output_equal: bool = True, | ||
| ): | ||
|
|
||
| module = module.eval() | ||
| # Export to aten dialect using torch.export | ||
| aten_dialect: ExportedProgram = export(module, sample_inputs) | ||
|
|
||
| # Convert to edge dialect | ||
| edge_program: EdgeProgramManager = to_edge(aten_dialect) | ||
| to_be_lowered_module = edge_program.exported_program() | ||
|
|
||
| # Lower the module to the backend with a custom partitioner | ||
| compile_spec = [CompileSpec("device", self.device.encode())] | ||
| lowered_module = edge_program.to_backend(OpenvinoPartitioner(compile_spec)) | ||
|
|
||
| # Apply backend-specific passes | ||
| exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig()) | ||
|
|
||
| # Check if the number of partitions created matches the expected number of partitions | ||
| self.assertEqual( | ||
| len(exec_prog.executorch_program.execution_plan[0].delegates), | ||
| expected_partitions, | ||
| ) | ||
| # Check if the individual partitions are assigned to Openvino backend | ||
| for i in range(expected_partitions): | ||
| self.assertEqual( | ||
| exec_prog.executorch_program.execution_plan[0].delegates[i].id, | ||
| OpenvinoBackend.__name__, | ||
| ) | ||
|
|
||
| # Execute the model and compare the outputs with the reference outputs | ||
| if (assert_output_equal): | ||
| with tempfile.TemporaryDirectory() as tmp_dir: | ||
| input_list = "" | ||
| for idx, _ in enumerate(sample_inputs): | ||
| input_name = f"input_0_{idx}.raw" | ||
| input_list += input_name + " " | ||
| input_list = input_list.strip() + "\n" | ||
|
|
||
| output_dir = f"{tmp_dir}/outputs" | ||
|
|
||
| # Execute the module in eager mode to calculate the reference outputs | ||
| ref_output = module(*sample_inputs) | ||
| if isinstance(ref_output, torch.Tensor): | ||
| ref_output = [ref_output,] | ||
|
|
||
| # Serialize the executorch model and save into a temporary file | ||
| pte_fname = f"{tmp_dir}/openvino_executorch_test.pte" | ||
| with open(pte_fname, "wb") as file: | ||
| exec_prog.write_to_file(file) | ||
|
|
||
| # Save inputs into a temporary file | ||
| self.generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list) | ||
| self.make_output_dir(output_dir) | ||
|
|
||
| # Start a subprocess to execute model with openvino_executor_runner | ||
| cmd = [ | ||
| f"{self.build_folder}/examples/openvino/openvino_executor_runner", | ||
| "--model_path", | ||
| pte_fname, | ||
| "--input_list_path", | ||
| f"{tmp_dir}/input_list.txt", | ||
| "--output_folder_path", | ||
| output_dir, | ||
| ] | ||
|
|
||
| env = dict(os.environ) | ||
| proc = subprocess.run( | ||
| cmd, | ||
| stdout=subprocess.PIPE, | ||
| stderr=subprocess.STDOUT, | ||
| env=env, | ||
| cwd=tmp_dir, | ||
| ) | ||
|
|
||
| stdout_str = proc.stdout.decode('utf-8') | ||
|
|
||
| # Check if execution completed successfully | ||
| self.assertIn("Model executed successfully.", stdout_str) | ||
|
|
||
| # Read the outputs from the temporary files | ||
| output_dir = f"{tmp_dir}/outputs" | ||
| outputs = [] | ||
|
|
||
| for i, f in enumerate(sorted(os.listdir(output_dir))): | ||
| filename = os.path.join(output_dir, f) | ||
| output = np.fromfile(filename, dtype=ref_output[i].detach().numpy().dtype) | ||
| output = torch.from_numpy(output).reshape(ref_output[i].shape) | ||
| outputs.append(output) | ||
|
|
||
| # Compare the outputs with the reference outputs | ||
| self.assertTrue(len(ref_output) == len(outputs)) | ||
| for i in range(len(ref_output)): | ||
| self.assertTrue( | ||
| torch.allclose( | ||
| outputs[i], ref_output[i], atol=self.atol, rtol=self.rtol, equal_nan=True | ||
| ), | ||
| msg=f"ref_output:\n{ref_output[i]}\n\ntest_output:\n{outputs[i]}", | ||
| ) | ||
|
|
||
| def generate_inputs(self, dest_path: str, file_name: str, inputs=None, input_list=None): | ||
| input_list_file = None | ||
| input_files = [] | ||
|
|
||
| # Prepare input list | ||
| if input_list is not None: | ||
| input_list_file = f"{dest_path}/{file_name}" | ||
| with open(input_list_file, "w") as f: | ||
| f.write(input_list) | ||
| f.flush() | ||
|
|
||
| # Prepare input data | ||
| if inputs is not None: | ||
| for idx, data in enumerate(inputs): | ||
| for i, d in enumerate(data): | ||
| file_name = f"{dest_path}/input_{idx}_{i}.raw" | ||
| d.detach().numpy().tofile(file_name) | ||
| input_files.append(file_name) | ||
|
|
||
| return input_list_file, input_files | ||
|
|
||
| def make_output_dir(self, path: str): | ||
| if os.path.exists(path): | ||
| for f in os.listdir(path): | ||
| os.remove(os.path.join(path, f)) | ||
| os.removedirs(path) | ||
| os.makedirs(path) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest | ||
| import torch | ||
|
|
||
| class TestAddOperator(BaseOpenvinoOpTest): | ||
|
|
||
| def create_model(self): | ||
| class Add(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def forward(self, x, y): | ||
| return torch.add(x, y) | ||
|
|
||
| return Add() | ||
|
|
||
| def test_add(self): | ||
| module = self.create_model() | ||
| sample_input = (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)) | ||
| self.execute_layer_test(module, sample_input) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest | ||
| import torch | ||
|
|
||
| class TestAddMMOperator(BaseOpenvinoOpTest): | ||
|
|
||
| def create_model(self): | ||
| class AddMM(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.alpha = 1. | ||
| self.beta = 1. | ||
|
|
||
| def forward(self, x, y, z): | ||
| #return torch.add(x, y) | ||
| return torch.addmm(x, y, z, alpha=self.alpha, beta=self.beta) | ||
|
|
||
| return AddMM() | ||
|
|
||
| def test_addmm(self): | ||
| module = self.create_model() | ||
| input_x = torch.randn(4,4, dtype=torch.float32) | ||
| input_y = torch.randn(4,4, dtype=torch.float32) | ||
| input_z = torch.randn(4,4, dtype=torch.float32) | ||
| sample_input = (input_x, input_y, input_z) | ||
| self.execute_layer_test(module, sample_input) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest | ||
| import torch | ||
|
|
||
| class TestArangeOperator(BaseOpenvinoOpTest): | ||
|
|
||
| def create_model(self, x): | ||
| class Arange(torch.nn.Module): | ||
| def __init__(self, x): | ||
| super().__init__() | ||
| self.x = x | ||
|
|
||
| def forward(self, y): | ||
| return torch.arange(self.x, dtype=torch.float32) + y | ||
|
|
||
| return Arange(5) | ||
|
|
||
| def test_arange(self): | ||
| module = self.create_model(5) | ||
| sample_input = (torch.randn(5),) | ||
| self.execute_layer_test(module, sample_input) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest | ||
| import torch | ||
|
|
||
| op_params = [{'weights': True, 'bias': True, 'eps': 1.0 }, | ||
| {'weights': True, 'bias': True, 'eps': 0.00005 }, | ||
| {'weights': True, 'bias': True, 'eps': 0.5 }, | ||
| {'weights': True, 'bias': True, 'eps': 0.042 }, | ||
| {'weights': True, 'bias': False, 'eps': 1.0 }, | ||
| {'weights': True, 'bias': False, 'eps': 0.00005 }, | ||
| {'weights': True, 'bias': False, 'eps': 0.5 }, | ||
| {'weights': True, 'bias': False, 'eps': 0.042 }, | ||
| {'weights': False, 'bias': True, 'eps': 1.0 }, | ||
| {'weights': False, 'bias': True, 'eps': 0.00005 }, | ||
| {'weights': False, 'bias': True, 'eps': 0.5 }, | ||
| {'weights': False, 'bias': True, 'eps': 0.042 }, | ||
| {'weights': False, 'bias': False, 'eps': 1.0 }, | ||
| {'weights': False, 'bias': False, 'eps': 0.00005 }, | ||
| {'weights': False, 'bias': False, 'eps': 0.5 }, | ||
| {'weights': False, 'bias': False, 'eps': 0.042 }, | ||
| ] | ||
|
|
||
|
|
||
| class TestBatchNormOperator(BaseOpenvinoOpTest): | ||
|
|
||
| def create_model(self, weights, bias, eps): | ||
|
|
||
| class BatchNorm(torch.nn.Module): | ||
| def __init__(self, weights=True, bias=True, eps=1e-05): | ||
| super(BatchNorm, self).__init__() | ||
| self.weight = torch.nn.Parameter(torch.randn(6)) if weights else None | ||
| self.bias = torch.nn.Parameter(torch.randn(6)) if bias else None | ||
| self.running_mean = torch.randn(6) | ||
| self.running_var = torch.randn(6) | ||
| self.eps = eps | ||
|
|
||
| def forward(self, x): | ||
| return torch.nn.functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=False) | ||
|
|
||
| return BatchNorm(weights, bias, eps) | ||
|
|
||
|
|
||
| def test_batch_norm(self): | ||
| for params in op_params: | ||
| with self.subTest(params=params): | ||
| module = self.create_model(weights=params['weights'], | ||
| bias=params['bias'], | ||
| eps=params['eps']) | ||
|
|
||
| sample_input = (torch.randn(20, 6, 10),) | ||
|
|
||
| self.execute_layer_test(module, sample_input) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest | ||
| import torch | ||
|
|
||
| d2_params = [{'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [1, 1], 'groups': 1, | ||
| 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [ | ||
| 1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [ | ||
| 1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [ | ||
| 1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [ | ||
| 1, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 1, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [ | ||
| 3, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 3, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [ | ||
| 1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 0, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 1, 0], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 0, 1], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 1, 0], 'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ | ||
| 0, 0], 'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [2, 1], 'bias_shape': [1], 'pads': [ | ||
| 1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [ | ||
| 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [ | ||
| 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 3, 1, 1], 'strides': [2, 2], 'pads': [ | ||
| 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [ | ||
| 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [ | ||
| 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, | ||
| {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [ | ||
| 1, 1], 'dilations': [2, 2], 'groups': 1, 'output_padding': [1, 1], 'transposed': True}, | ||
| ] | ||
|
|
||
| class TestConvolutionOperator(BaseOpenvinoOpTest): | ||
|
|
||
| def create_model(self, weights_shape, strides, pads, dilations, groups, bias, transposed, output_padding=0, | ||
| bias_shape=None, underscore=False): | ||
|
|
||
| bias_dim = 0 | ||
|
|
||
| class Convolution(torch.nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.weight = torch.nn.Parameter(torch.randn(weights_shape)) | ||
| self.bias_shape = bias_shape | ||
| if self.bias_shape is None: | ||
| self.bias_shape = weights_shape[bias_dim] | ||
| self.bias = torch.nn.Parameter(torch.randn(self.bias_shape)) if bias else None | ||
| self.strides = strides | ||
| self.pads = pads | ||
| self.dilations = dilations | ||
| self.groups = groups | ||
| self.transposed = transposed | ||
| self.output_padding = output_padding | ||
| if underscore: | ||
| self.forward = self.forward_ | ||
|
|
||
| def forward(self, x): | ||
| return torch.convolution( | ||
| x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.transposed, | ||
| self.output_padding, self.groups | ||
| ) | ||
|
|
||
| def forward_(self, x): | ||
| return torch._convolution( | ||
| x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.transposed, | ||
| self.output_padding, self.groups, False, False, False, False | ||
| ) | ||
|
|
||
| return Convolution() | ||
|
|
||
| def test_convolution(self): | ||
| bias_underscore_config = [(False, False), (True, False)] | ||
| for bias, underscore in bias_underscore_config: | ||
| for params in d2_params: | ||
| with self.subTest(params=params, bias=bias, underscore=underscore): | ||
| bias_shape = None | ||
| if 'bias_shape' in params: | ||
| bias_shape = params['bias_shape'] | ||
| module = self.create_model(weights_shape=params['weights_shape'], | ||
| strides=params['strides'], | ||
| pads=params['pads'], | ||
| dilations=params['dilations'], | ||
| groups=params['groups'], | ||
| output_padding=params['output_padding'], | ||
| transposed=params['transposed'], | ||
| bias_shape=bias_shape, | ||
| bias=bias, | ||
| underscore=underscore) | ||
| sample_input = (torch.randn(1, 3, 10, 10),) | ||
| self.execute_layer_test(module, sample_input) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instructions need to be updated to use the changes in OpenVINO