diff --git a/backends/openvino/preprocess.py b/backends/openvino/preprocess.py index 96df9faba85..8f6991afdd3 100644 --- a/backends/openvino/preprocess.py +++ b/backends/openvino/preprocess.py @@ -38,7 +38,7 @@ def preprocess( for spec in module_compile_spec: compile_options[spec.key] = spec.value.decode() - compiled = openvino_compile(edge_program.module(), *args, options=compile_options) + compiled = openvino_compile(edge_program.module(), *args, options=compile_options, executorch=True) model_bytes = compiled.export_model() return PreprocessResult(processed_bytes=model_bytes) diff --git a/backends/openvino/tests/ops/base_openvino_op_test.py b/backends/openvino/tests/ops/base_openvino_op_test.py new file mode 100644 index 00000000000..a51b99e8eca --- /dev/null +++ b/backends/openvino/tests/ops/base_openvino_op_test.py @@ -0,0 +1,154 @@ +import os +import subprocess +import tempfile +import unittest + +import numpy as np +import torch +import executorch +from executorch.backends.openvino.partitioner import OpenvinoPartitioner +from executorch.exir.backend.backend_details import CompileSpec +from torch.export import export, ExportedProgram +from executorch.exir import EdgeProgramManager, to_edge +from executorch.backends.openvino.preprocess import OpenvinoBackend + + +class BaseOpenvinoOpTest(unittest.TestCase): + device = "CPU" + build_folder = "" + + atol = 1e-1 + rtol = 1e-1 + + def execute_layer_test( + self, + module: torch.nn.Module, + sample_inputs: tuple[torch.Tensor], + expected_partitions: int = 1, + assert_output_equal: bool = True, + ): + + module = module.eval() + # Export to aten dialect using torch.export + aten_dialect: ExportedProgram = export(module, sample_inputs) + + # Convert to edge dialect + edge_program: EdgeProgramManager = to_edge(aten_dialect) + to_be_lowered_module = edge_program.exported_program() + + # Lower the module to the backend with a custom partitioner + compile_spec = [CompileSpec("device", self.device.encode())] + lowered_module = edge_program.to_backend(OpenvinoPartitioner(compile_spec)) + + # Apply backend-specific passes + exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig()) + + # Check if the number of partitions created matches the expected number of partitions + self.assertEqual( + len(exec_prog.executorch_program.execution_plan[0].delegates), + expected_partitions, + ) + # Check if the individual partitions are assigned to Openvino backend + for i in range(expected_partitions): + self.assertEqual( + exec_prog.executorch_program.execution_plan[0].delegates[i].id, + OpenvinoBackend.__name__, + ) + + # Execute the model and compare the outputs with the reference outputs + if (assert_output_equal): + with tempfile.TemporaryDirectory() as tmp_dir: + input_list = "" + for idx, _ in enumerate(sample_inputs): + input_name = f"input_0_{idx}.raw" + input_list += input_name + " " + input_list = input_list.strip() + "\n" + + output_dir = f"{tmp_dir}/outputs" + + # Execute the module in eager mode to calculate the reference outputs + ref_output = module(*sample_inputs) + if isinstance(ref_output, torch.Tensor): + ref_output = [ref_output,] + + # Serialize the executorch model and save into a temporary file + pte_fname = f"{tmp_dir}/openvino_executorch_test.pte" + with open(pte_fname, "wb") as file: + exec_prog.write_to_file(file) + + # Save inputs into a temporary file + self.generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list) + self.make_output_dir(output_dir) + + # Start a subprocess to execute model with openvino_executor_runner + cmd = [ + f"{self.build_folder}/examples/openvino/openvino_executor_runner", + "--model_path", + pte_fname, + "--input_list_path", + f"{tmp_dir}/input_list.txt", + "--output_folder_path", + output_dir, + ] + + env = dict(os.environ) + proc = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + cwd=tmp_dir, + ) + + stdout_str = proc.stdout.decode('utf-8') + + # Check if execution completed successfully + self.assertIn("Model executed successfully.", stdout_str) + + # Read the outputs from the temporary files + output_dir = f"{tmp_dir}/outputs" + outputs = [] + + for i, f in enumerate(sorted(os.listdir(output_dir))): + filename = os.path.join(output_dir, f) + output = np.fromfile(filename, dtype=ref_output[i].detach().numpy().dtype) + output = torch.from_numpy(output).reshape(ref_output[i].shape) + outputs.append(output) + + # Compare the outputs with the reference outputs + self.assertTrue(len(ref_output) == len(outputs)) + for i in range(len(ref_output)): + self.assertTrue( + torch.allclose( + outputs[i], ref_output[i], atol=self.atol, rtol=self.rtol, equal_nan=True + ), + msg=f"ref_output:\n{ref_output[i]}\n\ntest_output:\n{outputs[i]}", + ) + + def generate_inputs(self, dest_path: str, file_name: str, inputs=None, input_list=None): + input_list_file = None + input_files = [] + + # Prepare input list + if input_list is not None: + input_list_file = f"{dest_path}/{file_name}" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + + # Prepare input data + if inputs is not None: + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{dest_path}/input_{idx}_{i}.raw" + d.detach().numpy().tofile(file_name) + input_files.append(file_name) + + return input_list_file, input_files + + def make_output_dir(self, path: str): + if os.path.exists(path): + for f in os.listdir(path): + os.remove(os.path.join(path, f)) + os.removedirs(path) + os.makedirs(path) diff --git a/backends/openvino/tests/ops/test_add.py b/backends/openvino/tests/ops/test_add.py new file mode 100644 index 00000000000..d298f77e792 --- /dev/null +++ b/backends/openvino/tests/ops/test_add.py @@ -0,0 +1,19 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +class TestAddOperator(BaseOpenvinoOpTest): + + def create_model(self): + class Add(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.add(x, y) + + return Add() + + def test_add(self): + module = self.create_model() + sample_input = (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3)) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_addmm.py b/backends/openvino/tests/ops/test_addmm.py new file mode 100644 index 00000000000..32f09ebdc29 --- /dev/null +++ b/backends/openvino/tests/ops/test_addmm.py @@ -0,0 +1,25 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +class TestAddMMOperator(BaseOpenvinoOpTest): + + def create_model(self): + class AddMM(torch.nn.Module): + def __init__(self): + super().__init__() + self.alpha = 1. + self.beta = 1. + + def forward(self, x, y, z): + #return torch.add(x, y) + return torch.addmm(x, y, z, alpha=self.alpha, beta=self.beta) + + return AddMM() + + def test_addmm(self): + module = self.create_model() + input_x = torch.randn(4,4, dtype=torch.float32) + input_y = torch.randn(4,4, dtype=torch.float32) + input_z = torch.randn(4,4, dtype=torch.float32) + sample_input = (input_x, input_y, input_z) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_arange.py b/backends/openvino/tests/ops/test_arange.py new file mode 100644 index 00000000000..0dd739a2585 --- /dev/null +++ b/backends/openvino/tests/ops/test_arange.py @@ -0,0 +1,20 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +class TestArangeOperator(BaseOpenvinoOpTest): + + def create_model(self, x): + class Arange(torch.nn.Module): + def __init__(self, x): + super().__init__() + self.x = x + + def forward(self, y): + return torch.arange(self.x, dtype=torch.float32) + y + + return Arange(5) + + def test_arange(self): + module = self.create_model(5) + sample_input = (torch.randn(5),) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_batch_norm.py b/backends/openvino/tests/ops/test_batch_norm.py new file mode 100644 index 00000000000..ecb76860434 --- /dev/null +++ b/backends/openvino/tests/ops/test_batch_norm.py @@ -0,0 +1,51 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +op_params = [{'weights': True, 'bias': True, 'eps': 1.0 }, + {'weights': True, 'bias': True, 'eps': 0.00005 }, + {'weights': True, 'bias': True, 'eps': 0.5 }, + {'weights': True, 'bias': True, 'eps': 0.042 }, + {'weights': True, 'bias': False, 'eps': 1.0 }, + {'weights': True, 'bias': False, 'eps': 0.00005 }, + {'weights': True, 'bias': False, 'eps': 0.5 }, + {'weights': True, 'bias': False, 'eps': 0.042 }, + {'weights': False, 'bias': True, 'eps': 1.0 }, + {'weights': False, 'bias': True, 'eps': 0.00005 }, + {'weights': False, 'bias': True, 'eps': 0.5 }, + {'weights': False, 'bias': True, 'eps': 0.042 }, + {'weights': False, 'bias': False, 'eps': 1.0 }, + {'weights': False, 'bias': False, 'eps': 0.00005 }, + {'weights': False, 'bias': False, 'eps': 0.5 }, + {'weights': False, 'bias': False, 'eps': 0.042 }, + ] + + +class TestBatchNormOperator(BaseOpenvinoOpTest): + + def create_model(self, weights, bias, eps): + + class BatchNorm(torch.nn.Module): + def __init__(self, weights=True, bias=True, eps=1e-05): + super(BatchNorm, self).__init__() + self.weight = torch.nn.Parameter(torch.randn(6)) if weights else None + self.bias = torch.nn.Parameter(torch.randn(6)) if bias else None + self.running_mean = torch.randn(6) + self.running_var = torch.randn(6) + self.eps = eps + + def forward(self, x): + return torch.nn.functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, eps=self.eps, training=False) + + return BatchNorm(weights, bias, eps) + + + def test_batch_norm(self): + for params in op_params: + with self.subTest(params=params): + module = self.create_model(weights=params['weights'], + bias=params['bias'], + eps=params['eps']) + + sample_input = (torch.randn(20, 6, 10),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_convolution.py b/backends/openvino/tests/ops/test_convolution.py new file mode 100644 index 00000000000..83a80282089 --- /dev/null +++ b/backends/openvino/tests/ops/test_convolution.py @@ -0,0 +1,105 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +d2_params = [{'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [1, 1], 'groups': 1, + 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [ + 1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [ + 1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0], 'dilations': [ + 1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [ + 1, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [ + 1, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [ + 3, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [ + 3, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [ + 1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [ + 0, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ + 1, 0], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ + 0, 1], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ + 1, 0], 'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [ + 0, 0], 'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [2, 1], 'bias_shape': [1], 'pads': [ + 1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [ + 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [ + 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 3, 1, 1], 'strides': [2, 2], 'pads': [ + 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [ + 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False}, + {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [ + 0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True}, + {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [ + 1, 1], 'dilations': [2, 2], 'groups': 1, 'output_padding': [1, 1], 'transposed': True}, + ] + +class TestConvolutionOperator(BaseOpenvinoOpTest): + + def create_model(self, weights_shape, strides, pads, dilations, groups, bias, transposed, output_padding=0, + bias_shape=None, underscore=False): + + bias_dim = 0 + + class Convolution(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(weights_shape)) + self.bias_shape = bias_shape + if self.bias_shape is None: + self.bias_shape = weights_shape[bias_dim] + self.bias = torch.nn.Parameter(torch.randn(self.bias_shape)) if bias else None + self.strides = strides + self.pads = pads + self.dilations = dilations + self.groups = groups + self.transposed = transposed + self.output_padding = output_padding + if underscore: + self.forward = self.forward_ + + def forward(self, x): + return torch.convolution( + x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.transposed, + self.output_padding, self.groups + ) + + def forward_(self, x): + return torch._convolution( + x, self.weight, self.bias, self.strides, self.pads, self.dilations, self.transposed, + self.output_padding, self.groups, False, False, False, False + ) + + return Convolution() + + def test_convolution(self): + bias_underscore_config = [(False, False), (True, False)] + for bias, underscore in bias_underscore_config: + for params in d2_params: + with self.subTest(params=params, bias=bias, underscore=underscore): + bias_shape = None + if 'bias_shape' in params: + bias_shape = params['bias_shape'] + module = self.create_model(weights_shape=params['weights_shape'], + strides=params['strides'], + pads=params['pads'], + dilations=params['dilations'], + groups=params['groups'], + output_padding=params['output_padding'], + transposed=params['transposed'], + bias_shape=bias_shape, + bias=bias, + underscore=underscore) + sample_input = (torch.randn(1, 3, 10, 10),) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_mean.py b/backends/openvino/tests/ops/test_mean.py new file mode 100644 index 00000000000..3315fd1e61d --- /dev/null +++ b/backends/openvino/tests/ops/test_mean.py @@ -0,0 +1,59 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +op_params = [{'axes': None, 'keep_dim': None, 'dtype': None, }, + {'axes': None, 'keep_dim': None, 'dtype': "float64",}, + {'axes': None, 'keep_dim': None, 'dtype': "float32",}, + {'axes': None, 'keep_dim': None, 'dtype': "int32", }, + {'axes': 0, 'keep_dim': False, 'dtype': None, }, + {'axes': 0, 'keep_dim': False, 'dtype': None, }, + ] + +dtypes = { + "float32": torch.float32, + "float64": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "int8": torch.int8, + "uint8": torch.uint8 +} + +class TestMeanOperator(BaseOpenvinoOpTest): + + def create_model(self, axes, keep_dims, dtype): + + pt_dtype = dtypes.get(dtype) + + class Mean(torch.nn.Module): + def __init__(self, axes=None, keep_dims=None, dtype=None): + super(Mean, self).__init__() + self.axes = axes + self.keep_dims = keep_dims + self.dtype = dtype + + def forward(self, x): + if self.axes is None and self.keep_dims is None: + if self.dtype is None: + return torch.mean(x, dtype=self.dtype) + return torch.mean(x) + if self.axes is not None and self.keep_dims is None: + if self.dtype is None: + return torch.mean(x, self.axes) + return torch.mean(x, self.axes, dtype=self.dtype) + if self.dtype is None: + return torch.mean(x, self.axes, self.keep_dims) + return torch.mean(x, self.axes, self.keep_dims, dtype=self.dtype) + + return Mean(axes, keep_dims, pt_dtype) + + + def test_mean(self): + for params in op_params: + with self.subTest(params=params): + module = self.create_model(axes=params['axes'], + keep_dims=params['keep_dim'], + dtype=params['dtype']) + + sample_input = (torch.randint(-10, 10, (1, 3, 224, 224)).to(dtype=torch.float32),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_permute.py b/backends/openvino/tests/ops/test_permute.py new file mode 100644 index 00000000000..1de60db3965 --- /dev/null +++ b/backends/openvino/tests/ops/test_permute.py @@ -0,0 +1,30 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +op_params = [{'order': [0, 2, 3, 1] }, + {'order': [0, 3, 1, 2] }, + ] + +class TestPermuteOperator(BaseOpenvinoOpTest): + + def create_model(self, order): + + class Permute(torch.nn.Module): + def __init__(self, order): + super(Permute, self).__init__() + self.order = order + + def forward(self, x): + return torch.permute(x, self.order) + + return Permute(order) + + + def test_permute(self): + for params in op_params: + with self.subTest(params=params): + module = self.create_model(order=params['order']) + + sample_input = (torch.randn(1, 3, 224, 224),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_pooling.py b/backends/openvino/tests/ops/test_pooling.py new file mode 100644 index 00000000000..60ab2f9edfa --- /dev/null +++ b/backends/openvino/tests/ops/test_pooling.py @@ -0,0 +1,65 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0}, + {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1}, + {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [0, 1]}, + {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [1, 0]}, + {'kernel_size': [3, 3], 'stride': [2, 1], 'padding': 0}, + {'kernel_size': [2, 1], 'stride': [2, 1], 'padding': 0}, + {'kernel_size': [2, 1], 'stride': None, 'padding': 0}, + {'kernel_size': [2, 1], 'stride': [], 'padding': 0}, + {'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1}, + ] + +class TestPoolingOperator(BaseOpenvinoOpTest): + + def create_model(self, op_type, kernel_size, stride, padding, dilation=1, ceil_mode=True, count_include_pad=True, dtype=torch.float32): + + class MaxPoolingBase(torch.nn.Module): + def __init__(self): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.dtype = dtype + + def forward(self, x): + pass + + class MaxPool2D(MaxPoolingBase): + def forward(self, x): + return torch.nn.functional.max_pool2d(x.to(self.dtype), self.kernel_size, self.stride, self.padding, self.dilation, + self.ceil_mode) + + class MaxPool2DIndices(MaxPoolingBase): + def forward(self, x): + return torch.nn.functional.max_pool2d(x, self.kernel_size, self.stride, self.padding, self.dilation, + self.ceil_mode, return_indices=True) + + ops = { + "MaxPool2D": MaxPool2D, + "MaxPool2DIndices": MaxPool2DIndices, + } + + aten_pooling = ops[op_type] + + return aten_pooling() + + def test_pooling2d(self): + for params in d2_params: + with self.subTest(params=params): + bias_shape = None + if 'bias_shape' in params: + bias_shape = params['bias_shape'] + module = self.create_model(op_type='MaxPool2D', + kernel_size=params['kernel_size'], + stride=params['stride'], + padding=params['padding'], + dilation=1, + ceil_mode=True, + count_include_pad=True) + sample_input = (torch.randn(1, 3, 15, 15),) + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_unary_ops.py b/backends/openvino/tests/ops/test_unary_ops.py new file mode 100644 index 00000000000..9a5866d6e65 --- /dev/null +++ b/backends/openvino/tests/ops/test_unary_ops.py @@ -0,0 +1,36 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + + +OPS = [ + torch.relu, +] + + +class TestUnaryOperator(BaseOpenvinoOpTest): + + def create_model(self, op, dtype): + + class UnaryOp(torch.nn.Module): + def __init__(self, op, dtype): + super().__init__() + self.dtype = dtype + self.op = op + + def forward(self, x): + x1 = x.to(self.dtype) + y = self.op(x1) + return y, x1 + + return UnaryOp(op, dtype) + + + def test_unary_op(self): + for op in OPS: + with self.subTest(op=OPS): + + module = self.create_model(op, dtype=torch.float32) + + sample_input = (torch.rand(2, 10) * 10 + 1,) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/ops/test_view.py b/backends/openvino/tests/ops/test_view.py new file mode 100644 index 00000000000..f5450a10af9 --- /dev/null +++ b/backends/openvino/tests/ops/test_view.py @@ -0,0 +1,32 @@ +from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest +import torch + +op_params = [{'input_shape': [2, 3, 2], 'target_shape': [2, 6] }, + {'input_shape': [4], 'target_shape': [2, 2] }, + ] + +class TestViewOperator(BaseOpenvinoOpTest): + + def create_model(self, target_shape): + + class View(torch.nn.Module): + + def __init__(self, target_shape) -> None: + super().__init__() + self.target_shape = target_shape + + def forward(self, input_tensor): + return input_tensor.view(self.target_shape) + + return View(target_shape) + + + def test_view(self): + for params in op_params: + with self.subTest(params=params): + + module = self.create_model(params['target_shape']) + + sample_input = (torch.randn(params['input_shape']),) + + self.execute_layer_test(module, sample_input) diff --git a/backends/openvino/tests/test_openvino_delegate.py b/backends/openvino/tests/test_openvino_delegate.py new file mode 100644 index 00000000000..bbf61d1ea09 --- /dev/null +++ b/backends/openvino/tests/test_openvino_delegate.py @@ -0,0 +1,65 @@ +import unittest +import argparse + +class OpenvinoTestSuite(unittest.TestSuite): + + test_params = {} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def addTest(self, test): + # Set test parameters if this is an instance of TestOpenvino + from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest + if isinstance(test, BaseOpenvinoOpTest): + if "device" in self.test_params: + test.device = self.test_params["device"] + if "build_folder" in self.test_params: + test.build_folder = self.test_params["build_folder"] + # Call the original addTest method to actually add the test to the suite + super().addTest(test) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-b", + "--build_folder", + help="path to cmake binary directory", + type=str, + required=True, + ) + parser.add_argument( + "-s", + "--device", + help="OpenVINO device to execute the model on", + type=str, + default="CPU", + ) + parser.add_argument( + "-p", + "--pattern", + help="Pattern to match test files. Provide complete file name to run individual op tests", + type=str, + default="test_*.py", + ) + + args, ns_args = parser.parse_known_args(namespace=unittest) + test_params = {} + test_params["device"] = args.device + test_params["build_folder"] = args.build_folder + test_params["pattern"] = args.pattern + return test_params + +if __name__ == "__main__": + loader = unittest.TestLoader() + # Replace the default test suite with a custom test suite to be able to + # pass test parameter to the test cases + loader.suiteClass = OpenvinoTestSuite + test_params = parse_arguments() + loader.suiteClass.test_params = test_params + # Discover all existing op tests in "ops" folder + suite = loader.discover("ops", pattern=test_params['pattern']) + # Start running tests + unittest.TextTestRunner().run(suite) diff --git a/examples/openvino/executor_runner/openvino_executor_runner.cpp b/examples/openvino/executor_runner/openvino_executor_runner.cpp index 67bb35d9701..b6e13218773 100644 --- a/examples/openvino/executor_runner/openvino_executor_runner.cpp +++ b/examples/openvino/executor_runner/openvino_executor_runner.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -16,7 +17,7 @@ // Define a fixed-size memory pool for the method allocator (4 MB) static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB -// Define command-line flags for model path and the number of iterations +// Define command-line flags for model path, the number of iterations, input list path, and output folder path DEFINE_string( model_path, "", @@ -25,6 +26,14 @@ DEFINE_int32( num_iter, 1, "Number of inference iterations (default is 1)."); +DEFINE_string( + input_list_path, + "", + "Path to the input list file which includes the list of raw input tensor files (optional)."); +DEFINE_string( + output_folder_path, + "", + "Path to the output folder to save raw output tensor files (optional)."); using executorch::extension::FileDataLoader; using executorch::extension::prepare_input_tensors; @@ -38,6 +47,7 @@ using executorch::runtime::MethodMeta; using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; +using executorch::runtime::TensorInfo; int main(int argc, char** argv) { // Initialize the runtime environment @@ -128,6 +138,72 @@ int main(int argc, char** argv) { inputs.ok(), "Could not prepare inputs: 0x%" PRIx32, static_cast(inputs.error())); + + // If the input path list is provided, read input tensors from the files + if (!(FLAGS_input_list_path.empty())) { + const char* input_list_path = FLAGS_input_list_path.c_str(); + ET_LOG(Info, "Loading input tensors from the list provided in %s.", input_list_path); + Error status = Error::Ok; + std::vector inputs(method->inputs_size()); + ET_LOG(Info, "%zu inputs: ", inputs.size()); + status = method->get_inputs(inputs.data(), inputs.size()); + ET_CHECK(status == Error::Ok); + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + // Read raw input tensor file names from input list file and + // iterate each raw input tensor file to read values + std::ifstream input_list(input_list_path); + if (input_list.is_open()) { + size_t num_inputs = method->inputs_size(); + std::string file_path; + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + if (input_files.size() == 0) { + break; + } + for (int input_index = 0; input_index < num_inputs; ++input_index) { + MethodMeta method_meta = method->method_meta(); + Result tensor_meta = + method_meta.input_tensor_meta(input_index); + auto input_data_ptr = inputs[input_index].toTensor().data_ptr(); + + std::ifstream fin(input_files[input_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. file bytes: %zu, tensor bytes: %zu", + input_index, + file_size, + tensor_meta->nbytes()); + + fin.seekg(0, fin.beg); + fin.read( + static_cast(input_data_ptr), + file_size); + fin.close(); + } + } + } else { + ET_CHECK_MSG(false, + "Failed to read input list file: %s", + input_list_path); + } + } ET_LOG(Info, "Inputs prepared."); // Measure execution time for inference @@ -161,6 +237,23 @@ int main(int argc, char** argv) { status = method->get_outputs(outputs.data(), outputs.size()); ET_CHECK(status == Error::Ok); + // If output folder path is provided, save output tensors + // into raw tensor files. + if (!(FLAGS_output_folder_path.empty())) { + const char* output_folder_path = FLAGS_output_folder_path.c_str(); + ET_LOG(Info, "Saving output tensors into the output folder: %s.", output_folder_path); + for (size_t output_index = 0; output_index < method->outputs_size(); + output_index++) { + auto output_tensor = outputs[output_index].toTensor(); + auto output_file_name = std::string(output_folder_path) + "/output_" + + std::to_string(output_index) + ".raw"; + std::ofstream fout(output_file_name.c_str(), std::ios::binary); + fout.write( + output_tensor.const_data_ptr(), output_tensor.nbytes()); + fout.close(); + } + } + return 0; }