diff --git a/CMakeLists.txt b/CMakeLists.txt
index 50575d26a19..de452bde3ab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -196,6 +196,8 @@
 option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)
 
 option(EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" OFF)
 
+option(EXECUTORCH_BUILD_OPENVINO "Build the OpenVINO backend" OFF)
+
 option(EXECUTORCH_BUILD_PYBIND "Build the Python Bindings" OFF)
 
 option(EXECUTORCH_BUILD_QNN "Build the Qualcomm backend" OFF)
@@ -656,6 +658,10 @@
 if(EXECUTORCH_BUILD_NEURON)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek)
 endif()
 
+if(EXECUTORCH_BUILD_OPENVINO)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/openvino)
+endif()
+
 if(EXECUTORCH_BUILD_QNN)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/qualcomm)
 endif()
diff --git a/backends/openvino/CMakeLists.txt b/backends/openvino/CMakeLists.txt
new file mode 100644
index 00000000000..4df2015a8d7
--- /dev/null
+++ b/backends/openvino/CMakeLists.txt
@@ -0,0 +1,77 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file in the root
+# directory of this source tree for more details.
+
+# Set C++ standard
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Ensure compile_commands are generated
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+# Define common include directories
+set(COMMON_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+
+# Include common directories before others to ensure proper order
+include_directories(BEFORE ${COMMON_INCLUDE_DIRS})
+
+# Set up EXECUTORCH_ROOT if not already set
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+# Include utility cmake script from the executorch repository
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+
+# Update the common include directory for ExecuTorch
+set(COMMON_INCLUDE_DIRS ${EXECUTORCH_ROOT}/..)
+
+# Set the OpenVINO directory and include directories from the environment variable
+set(OPENVINO_DIR "$ENV{INTEL_OPENVINO_DIR}")
+if(NOT OPENVINO_DIR)
+  message(FATAL_ERROR "INTEL_OPENVINO_DIR environment variable is not set.")
+endif()
+
+set(OPENVINO_INCLUDE_DIRS
+    ${OPENVINO_DIR}/deployment_tools/inference_engine/include
+    ${OPENVINO_DIR}/runtime/include
+)
+
+# Define OpenVINO library path
+set(OPENVINO_LIB_PATH ${OPENVINO_DIR}/runtime/lib/intel64)
+
+# Define OpenVINO libraries
+set(OPENVINO_LIB ${OPENVINO_LIB_PATH}/libopenvino.so)
+
+# Add the OpenVINO backend library as a shared library
+add_library(openvino_backend SHARED)
+
+# Enable exceptions and RTTI for the OpenVINO backend
+target_compile_options(openvino_backend PRIVATE "-frtti" "-fexceptions")
+
+# Include directories for ExecuTorch and OpenVINO
+target_include_directories(
+  openvino_backend PUBLIC
+  ${COMMON_INCLUDE_DIRS}
+  ${OPENVINO_INCLUDE_DIRS}
+)
+
+# Link the OpenVINO libraries and the executorch core to the backend
+target_link_libraries(openvino_backend PRIVATE
+  ${OPENVINO_LIB}
+  executorch_core
+)
+
+# Add source files to the OpenVINO backend library
+target_sources(openvino_backend PRIVATE
+  ${CMAKE_CURRENT_LIST_DIR}/runtime/OpenvinoBackend.cpp
+)
+
+# Set additional link options for the shared library
+target_link_options(openvino_backend PRIVATE -Wl,-rpath=${OPENVINO_LIB_PATH})
+
+# Install the OpenVINO backend library to the lib directory
+install(TARGETS openvino_backend DESTINATION lib)
+
diff --git a/backends/openvino/__init__.py b/backends/openvino/__init__.py
new file mode 100644
index 00000000000..dac275d3f12
--- /dev/null
+++ b/backends/openvino/__init__.py
@@ -0,0 +1,4 @@
+from .partitioner import OpenvinoPartitioner
+from .preprocess import OpenvinoBackend
+
+__all__ = ["OpenvinoBackend", "OpenvinoPartitioner"]
diff --git a/backends/openvino/openvino_functions.yaml b/backends/openvino/openvino_functions.yaml
new file mode 100644
index 00000000000..296d57d7320
--- /dev/null
+++ b/backends/openvino/openvino_functions.yaml
@@ -0,0 +1,242 @@
+# This yaml file contains operators that are unsupported by the OpenVINO backend
+# and will fall back to portable kernels.
+
+- op: _cdist_forward.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_cdist_forward_out
+
+- op: _pdist_forward.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_pdist_forward_out
+
+- op: alias_copy.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::alias_copy_out
+
+- op: any.all_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::any_all_out
+
+- op: any.dims_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::any_dims_out
+
+- op: atan.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::atan_out
+
+- op: atan2.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::atan2_out
+
+- op: bitwise_or.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::bitwise_or_Scalar_out
+
+- op: bitwise_xor.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::bitwise_xor_Scalar_out
+
+- op: clamp.Tensor_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::clamp_tensor_out
+
+- op: convolution_backward.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::convolution_backward_out
+
+- op: detach_copy.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::detach_copy_out
+
+- op: diagonal_copy.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::diagonal_copy_out
+
+- op: expm1.out
kernels: + - arg_meta: null + kernel_name: torch::executor::expm1_out + +- op: floor_divide.out + kernels: + - arg_meta: null + kernel_name: torch::executor::floor_divide_out + +- op: index_put.out + kernels: + - arg_meta: null + kernel_name: torch::executor::index_put_out + +- op: logical_and.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logical_and_out + +- op: logical_or.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logical_or_out + +- op: logical_xor.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logical_xor_out + +- op: logit.out + kernels: + - arg_meta: null + kernel_name: torch::executor::logit_out + +- op: masked_scatter.out + kernels: + - arg_meta: null + kernel_name: torch::executor::masked_scatter_out + +- op: masked_select.out + kernels: + - arg_meta: null + kernel_name: torch::executor::masked_select_out + +- op: narrow_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::narrow_copy_out + +- op: nonzero.out + kernels: + - arg_meta: null + kernel_name: torch::executor::nonzero_out + +- op: pixel_shuffle.out + kernels: + - arg_meta: null + kernel_name: torch::executor::pixel_shuffle_out + +- op: pixel_unshuffle.out + kernels: + - arg_meta: null + kernel_name: torch::executor::pixel_unshuffle_out + +- op: prod.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::prod_int_out + +- op: prod.out + kernels: + - arg_meta: null + kernel_name: torch::executor::prod_out + +- op: remainder.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::remainder_Tensor_out + +- op: remainder.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::remainder_Scalar_out + +- op: repeat_interleave.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::repeat_interleave_Tensor_out + +- op: reflection_pad1d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::reflection_pad1d_out + +- op: reflection_pad3d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::reflection_pad3d_out + +- op: replication_pad1d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::replication_pad1d_out + +- op: replication_pad2d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::replication_pad2d_out + +- op: replication_pad3d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::replication_pad3d_out + +- op: round.out + kernels: + - arg_meta: null + kernel_name: torch::executor::round_out + +- op: scatter_add.out + kernels: + - arg_meta: null + kernel_name: torch::executor::scatter_add_out + +- op: split_copy.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::split_copy_Tensor_out + +- op: squeeze_copy.dim_out + kernels: + - arg_meta: null + kernel_name: torch::executor::squeeze_copy_dim_out + +- op: sub.Scalar_out + kernels: + - arg_meta: null + kernel_name: torch::executor::sub_scalar_out + +- op: t_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::t_copy_out + +- op: transpose_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::transpose_copy_int_out + +- op: trunc.out + kernels: + - arg_meta: null + kernel_name: torch::executor::trunc_out + +- op: unbind_copy.int_out + kernels: + - arg_meta: null + kernel_name: torch::executor::unbind_copy_int_out + +- op: upsample_bilinear2d.vec_out + kernels: + - arg_meta: null + kernel_name: torch::executor::upsample_bilinear2d_vec_out + +- func: dim_order_ops::_empty_dim_order.out(int[] size, *, int[]? 
dim_order=None, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_empty_dim_order_out
+
+- func: dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_to_dim_order_copy_out
diff --git a/backends/openvino/partitioner.py b/backends/openvino/partitioner.py
new file mode 100644
index 00000000000..8e621f56508
--- /dev/null
+++ b/backends/openvino/partitioner.py
@@ -0,0 +1,112 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file in the root
+# directory of this source tree for more details.
+
+from typing import Callable, final, List, Optional, Tuple
+
+import torch
+from executorch.backends.openvino.preprocess import OpenvinoBackend
+from executorch.exir.backend.backend_details import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.backend.utils import tag_constant_data
+
+from torch.export.exported_program import ExportedProgram
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+from torch.fx.passes.operator_support import OperatorSupportBase
+from openvino.frontend.pytorch.torchdynamo.op_support import OperatorSupport
+
+
+class OpenvinoOperatorsSupport(OperatorSupportBase):
+
+    def __init__(
+        self,
+        op_types_to_skip: Optional[set] = None,
+        op_names_to_skip: Optional[set] = None,
+    ) -> None:
+        if op_types_to_skip is None:
+            op_types_to_skip = set()
+        if op_names_to_skip is None:
+            op_names_to_skip = set()
+
+        self._op_types_to_skip = op_types_to_skip
+        self._op_names_to_skip = op_names_to_skip
+
+    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
+        if node.op != "call_function":
+            return False
+
+        op_type = node.target.__name__
+        # Consult the skip lists first so that explicitly skipped operators are
+        # never claimed by the partitioner, even when OpenVINO supports them.
+        if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
+            print(
+                f"[OpenVINO Backend] The {op_type} operator with name '{node.name}' is skipped."
+            )
+            return False
+
+        if op_type == "getitem":
+            return True
+
+        options = []
+        supported_ops = OperatorSupport(options)._support_dict
+        if "torch.ops." + str(op_type) in supported_ops:
+            return True
+
+        print("Op not supported: ", "torch.ops." + str(op_type))
+        return False
+
+
+@final
+class OpenvinoPartitioner(Partitioner):
+
+    def __init__(
+        self,
+        compile_spec: List[CompileSpec],
+        op_types_to_skip: Optional[set] = None,
+        op_names_to_skip: Optional[set] = None,
+    ) -> None:
+        self.delegation_spec = DelegationSpec(OpenvinoBackend.__name__, compile_spec)
+        self._op_types_to_skip = op_types_to_skip
+        self._op_names_to_skip = op_names_to_skip
+
+    def ops_to_not_decompose(
+        self,
+        ep: ExportedProgram,
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        ops_not_decompose = [
+            torch.ops.aten.pixel_shuffle.default,
+            torch.ops.aten.upsample_bilinear2d.default,
+            torch.ops.aten.upsample_bilinear2d.vec,
+            torch.ops.aten.upsample_nearest2d.default,
+            torch.ops.aten.upsample_nearest2d.vec,
+        ]
+        return (ops_not_decompose, None)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        partitioner = CapabilityBasedPartitioner(
+            exported_program.graph_module,
+            OpenvinoOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
+            allows_single_node_partition=True,
+        )
+        partition_list = partitioner.propose_partitions()
+
+        partition_tags = {}
+        for partition in partition_list:
+            for node in partition.nodes:
+                tag = f"tag{partition.id}"
+                node.meta["delegation_tag"] = tag
+                partition_tags[tag] = self.delegation_spec
+
+        tag_constant_data(exported_program)
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
diff --git a/backends/openvino/preprocess.py b/backends/openvino/preprocess.py
new file mode 100644
index 00000000000..6af45ff63f9
--- /dev/null
+++ b/backends/openvino/preprocess.py
@@ -0,0 +1,50 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file in the root
+# directory of this source tree for more details.
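+
+# Note: preprocess() below follows the standard ExecuTorch BackendDetails
+# flow: it collects the example input values recorded on the exported graph,
+# compiles the module with OpenVINO's openvino_compile, and returns the
+# serialized compiled model as the delegate payload that
+# runtime/OpenvinoBackend.cpp later imports via ov::Core::import_model.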
+
+from typing import final, List
+
+from executorch.exir.backend.backend_details import (
+    BackendDetails,
+    ExportedProgram,
+    PreprocessResult,
+)
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from openvino.frontend.pytorch.torchdynamo.compile import openvino_compile
+
+SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}
+
+
+@final
+class OpenvinoBackend(BackendDetails):
+
+    @classmethod
+    def preprocess(
+        cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
+    ) -> PreprocessResult:
+        # Collect the example input values recorded on the user-input nodes
+        input_names = edge_program.graph_signature.user_inputs
+        args = []
+        for node in edge_program.graph.nodes:
+            if node.target in input_names:
+                args.append(node.meta["val"])
+
+        # Decode the compile specs (e.g. the target device) into options
+        compile_options = {}
+        for spec in module_compile_spec:
+            compile_options[spec.key] = spec.value.decode()
+
+        # Compile with OpenVINO and serialize the compiled model
+        compiled = openvino_compile(
+            edge_program.module(), *args, options=compile_options, executorch=True
+        )
+        model_bytes = compiled.export_model()
+
+        return PreprocessResult(processed_bytes=model_bytes)
diff --git a/backends/openvino/requirements.txt b/backends/openvino/requirements.txt
new file mode 100644
index 00000000000..7c3de886e27
--- /dev/null
+++ b/backends/openvino/requirements.txt
@@ -0,0 +1,8 @@
+datasets
+huggingface-hub
+safetensors
+sentencepiece
+tokenizers
+transformers
+piq
+pillow
diff --git a/backends/openvino/runtime/OpenvinoBackend.cpp b/backends/openvino/runtime/OpenvinoBackend.cpp
new file mode 100644
index 00000000000..95d85445f7e
--- /dev/null
+++ b/backends/openvino/runtime/OpenvinoBackend.cpp
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) Intel Corporation
+ *
+ * Licensed under the BSD License (the "License"); you may not use this file
+ * except in compliance with the License. See the license file in the root
+ * directory of this source tree for more details.
+ */
+
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+
+#include <openvino/openvino.hpp>
+
+#include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/freeable_buffer.h>
+#include <executorch/runtime/platform/log.h>
+
+#include "OpenvinoBackend.hpp"
+
+using executorch::aten::ScalarType;
+using executorch::runtime::ArrayRef;
+using executorch::runtime::Backend;
+using executorch::runtime::BackendExecutionContext;
+using executorch::runtime::BackendInitContext;
+using executorch::runtime::CompileSpec;
+using executorch::runtime::DelegateHandle;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::FreeableBuffer;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::Result;
+
+namespace executorch {
+namespace backends {
+namespace openvino {
+
+OpenvinoBackend::OpenvinoBackend() {
+  if (!is_available()) {
+    throw std::runtime_error("OpenVINO runtime not available");
+  }
+}
+
+bool OpenvinoBackend::is_available() const {
+  try {
+    // Create an OpenVINO Core object to verify runtime availability
+    ov::Core core;
+
+    // Check if at least one device is available
+    auto devices = core.get_available_devices();
+    if (!devices.empty()) {
+      return true; // OpenVINO is available
+    }
+  } catch (const std::exception& e) {
+    // Log the exception if the OpenVINO runtime is not available
+    ET_LOG(Error, "OpenVINO is not available: %s", e.what());
+  } catch (...) {
+    // Handle any unexpected errors
+    ET_LOG(Error, "OpenVINO availability check failed due to an unknown error.");
+  }
+
+  return false; // OpenVINO is not available
+}
+
+Result<DelegateHandle*> OpenvinoBackend::init(
+    BackendInitContext& context,
+    FreeableBuffer* processed,
+    ArrayRef<CompileSpec> compile_specs) const {
+  ET_LOG(Info, "OpenvinoBackend::init %p", processed->data());
+
+  ov::Core core;
+  const char* data_ptr = static_cast<const char*>(processed->data());
+  size_t data_size = processed->size();
+
+  // Copy the payload to a string
+  std::string data_string(data_ptr, data_size);
+
+  // Wrap the data in a stream
+  std::istringstream compiled_stream(data_string);
+
+  // Import the model
+  auto compiled_model = core.import_model(compiled_stream, "CPU");
+
+  // Allocate an infer request
+  auto infer_request =
+      std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
+
+  // Allocate the execution handle from the runtime allocator
+  MemoryAllocator* allocator = context.get_runtime_allocator();
+  ExecutionHandle* handle =
+      ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
+  handle->compiled_model = std::make_shared<ov::CompiledModel>(compiled_model);
+  handle->infer_request = infer_request;
+
+  return handle;
+}
+
+Error OpenvinoBackend::execute(
+    BackendExecutionContext& context,
+    DelegateHandle* input_handle,
+    EValue** args) const {
+  ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
+
+  auto infer_request = execution_handle->infer_request;
+
+  size_t num_inputs = infer_request->get_compiled_model().inputs().size();
+  size_t num_outputs = infer_request->get_compiled_model().outputs().size();
+
+  // Set inputs
+  for (size_t i = 0; i < num_inputs; i++) {
+    auto input_tensor = args[i]->toTensor();
+    ov::Shape input_shape(input_tensor.sizes().begin(), input_tensor.sizes().end());
+
+    // Convert the input tensor to an OpenVINO tensor
+    ov::element::Type ov_type = convert_to_openvino_type(input_tensor.scalar_type());
+    ov::Tensor ov_input_tensor(ov_type, input_shape, input_tensor.mutable_data_ptr());
+
+    infer_request->set_input_tensor(i, ov_input_tensor);
+  }
+
+  // Set outputs
+  for (size_t i = 0; i < num_outputs; i++) {
+    auto output_tensor = args[num_inputs + i]->toTensor();
+    ov::Shape output_shape(output_tensor.sizes().begin(), output_tensor.sizes().end());
+
+    // Convert the output tensor to an OpenVINO tensor
+    ov::element::Type ov_type = convert_to_openvino_type(output_tensor.scalar_type());
+    ov::Tensor ov_output_tensor(ov_type, output_shape, output_tensor.mutable_data_ptr());
+
+    infer_request->set_output_tensor(i, ov_output_tensor);
+  }
+
+  // Execute the inference
+  infer_request->infer();
+
+  return Error::Ok;
+}
+
+void OpenvinoBackend::destroy(DelegateHandle* handle) const {
+  if (!handle) {
+    ET_LOG(Info, "Attempted to destroy a null handle.");
+    return;
+  }
+
+  // Cast the handle to the appropriate type
+  ExecutionHandle* execution_handle = static_cast<ExecutionHandle*>(handle);
+
+  // Clean up resources
+  if (execution_handle->infer_request) {
+    execution_handle->infer_request.reset(); // Release the infer request
+    ET_LOG(Info, "Infer request successfully destroyed.");
+  }
+
+  if (execution_handle->compiled_model) {
+    execution_handle->compiled_model.reset(); // Release the compiled model
+    ET_LOG(Info, "Compiled model successfully destroyed.");
+  }
+
+  ET_LOG(Info, "Delegate handle destroyed successfully.");
+}
+
+ov::element::Type OpenvinoBackend::convert_to_openvino_type(
+    ScalarType scalar_type) const {
+  switch (scalar_type) {
+    case ScalarType::Float:
+      return ov::element::f32;
+    case ScalarType::Int:
+      return ov::element::i32;
+    case ScalarType::Char:
+      return ov::element::i8;
+    default:
+      throw std::runtime_error("Unsupported scalar type");
+  }
+}
+
+} // namespace openvino
+} // namespace backends
+} // namespace executorch
+
+namespace {
+auto backend = executorch::backends::openvino::OpenvinoBackend();
+executorch::runtime::Backend backend_id{"OpenvinoBackend", &backend};
+static auto registered = executorch::runtime::register_backend(backend_id);
+} // namespace
diff --git a/backends/openvino/runtime/OpenvinoBackend.hpp b/backends/openvino/runtime/OpenvinoBackend.hpp
new file mode 100644
index 00000000000..e6f0e8659fb
--- /dev/null
+++ b/backends/openvino/runtime/OpenvinoBackend.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Intel Corporation
+ *
+ * Licensed under the BSD License (the "License"); you may not use this file
+ * except in compliance with the License. See the license file in the root
+ * directory of this source tree for more details.
+ */
+
+#ifndef OPENVINO_BACKEND_HPP
+#define OPENVINO_BACKEND_HPP
+
+#include <memory>
+
+#include <openvino/openvino.hpp>
+
+#include <executorch/runtime/backend/interface.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/freeable_buffer.h>
+#include <executorch/runtime/core/result.h>
+
+using executorch::aten::ScalarType;
+using executorch::runtime::ArrayRef;
+using executorch::runtime::Backend;
+using executorch::runtime::BackendExecutionContext;
+using executorch::runtime::BackendInitContext;
+using executorch::runtime::CompileSpec;
+using executorch::runtime::DelegateHandle;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::FreeableBuffer;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::Result;
+
+namespace executorch {
+namespace backends {
+namespace openvino {
+
+typedef struct {
+  std::shared_ptr<ov::CompiledModel> compiled_model;
+  std::shared_ptr<ov::InferRequest> infer_request;
+} ExecutionHandle;
+
+class OpenvinoBackend final : public ::executorch::runtime::BackendInterface {
+ public:
+  OpenvinoBackend();
+  ~OpenvinoBackend() = default;
+
+  bool is_available() const override;
+  Result<DelegateHandle*> init(
+      BackendInitContext& context,
+      FreeableBuffer* processed,
+      ArrayRef<CompileSpec> compile_specs) const override;
+  Error execute(
+      BackendExecutionContext& context,
+      DelegateHandle* input_handle,
+      EValue** args) const override;
+  void destroy(DelegateHandle* handle) const override;
+
+ private:
+  ov::element::Type convert_to_openvino_type(ScalarType scalar_type) const;
+};
+
+} // namespace openvino
+} // namespace backends
+} // namespace executorch
+
+#endif // OPENVINO_BACKEND_HPP
diff --git a/backends/openvino/scripts/build.sh b/backends/openvino/scripts/build.sh
new file mode 100755
index 00000000000..0c07a5bb729
--- /dev/null
+++ b/backends/openvino/scripts/build.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
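+# Usage sketch (assumed workflow, not stated in the original script): source
+# the OpenVINO environment first so that INTEL_OPENVINO_DIR is set, as
+# backends/openvino/CMakeLists.txt requires, then run the script, e.g.:
+#   source /opt/intel/openvino/setupvars.sh
+#   ./backends/openvino/scripts/build.sh
+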
+# Exit immediately if a command exits with a non-zero status.
+set -e
+
+# Define the directory where CMakeLists.txt is located
+EXECUTORCH_ROOT=$(realpath "$(dirname "$0")/../../..")
+echo "EXECUTORCH_ROOT=${EXECUTORCH_ROOT}"
+
+main() {
+  # Set build directory
+  local build_dir="cmake-openvino-out"
+
+  # Create and enter the build directory
+  cd "$EXECUTORCH_ROOT"
+  rm -rf "${build_dir}"
+
+  # Configure the project with CMake
+  # Note: Add any additional configuration options you need here
+  cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \
+        -DEXECUTORCH_BUILD_OPENVINO=ON \
+        -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+        -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+        -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+        -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+        -B"${build_dir}"
+
+  # Build the project
+  cmake --build "${build_dir}" --target install --config Release -j5
+
+  # Switch back to the original directory
+  cd - > /dev/null
+
+  # Print a success message
+  echo "Build successfully completed."
+}
+
+main "$@"
diff --git a/backends/openvino/tests/ops/base_openvino_op_test.py b/backends/openvino/tests/ops/base_openvino_op_test.py
new file mode 100644
index 00000000000..a51b99e8eca
--- /dev/null
+++ b/backends/openvino/tests/ops/base_openvino_op_test.py
@@ -0,0 +1,154 @@
+import os
+import subprocess
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+import executorch
+from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.backends.openvino.preprocess import OpenvinoBackend
+from executorch.exir import EdgeProgramManager, to_edge
+from executorch.exir.backend.backend_details import CompileSpec
+from torch.export import export, ExportedProgram
+
+
+class BaseOpenvinoOpTest(unittest.TestCase):
+    device = "CPU"
+    build_folder = ""
+
+    atol = 1e-1
+    rtol = 1e-1
+
+    def execute_layer_test(
+        self,
+        module: torch.nn.Module,
+        sample_inputs: tuple[torch.Tensor, ...],
+        expected_partitions: int = 1,
+        assert_output_equal: bool = True,
+    ):
+        module = module.eval()
+        # Export to aten dialect using torch.export
+        aten_dialect: ExportedProgram = export(module, sample_inputs)
+
+        # Convert to edge dialect
+        edge_program: EdgeProgramManager = to_edge(aten_dialect)
+
+        # Lower the module to the backend with a custom partitioner
+        compile_spec = [CompileSpec("device", self.device.encode())]
+        lowered_module = edge_program.to_backend(OpenvinoPartitioner(compile_spec))
+
+        # Apply backend-specific passes
+        exec_prog = lowered_module.to_executorch(
+            config=executorch.exir.ExecutorchBackendConfig()
+        )
+
+        # Check that the number of partitions created matches the expected number
+        self.assertEqual(
+            len(exec_prog.executorch_program.execution_plan[0].delegates),
+            expected_partitions,
+        )
+        # Check that the individual partitions are assigned to the OpenVINO backend
+        for i in range(expected_partitions):
+            self.assertEqual(
+                exec_prog.executorch_program.execution_plan[0].delegates[i].id,
+                OpenvinoBackend.__name__,
+            )
+
+        # Execute the model and compare the outputs with the reference outputs
+        if assert_output_equal:
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                input_list = ""
+                for idx, _ in enumerate(sample_inputs):
+                    input_name = f"input_0_{idx}.raw"
+                    input_list += input_name + " "
+                input_list = input_list.strip() + "\n"
+
+                output_dir = f"{tmp_dir}/outputs"
+
+                # Execute the module in eager mode to calculate the reference outputs
+                ref_output = module(*sample_inputs)
+                if isinstance(ref_output, torch.Tensor):
+                    ref_output = [ref_output]
+
+                # Serialize the executorch model and save it to a temporary file
+                pte_fname = f"{tmp_dir}/openvino_executorch_test.pte"
+                with open(pte_fname, "wb") as file:
+                    exec_prog.write_to_file(file)
+
+                # Save inputs into a temporary file
+                self.generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list)
+                self.make_output_dir(output_dir)
+
+                # Start a subprocess to execute the model with openvino_executor_runner
+                cmd = [
+                    f"{self.build_folder}/examples/openvino/openvino_executor_runner",
+                    "--model_path",
+                    pte_fname,
+                    "--input_list_path",
+                    f"{tmp_dir}/input_list.txt",
+                    "--output_folder_path",
+                    output_dir,
+                ]
+
+                env = dict(os.environ)
+                proc = subprocess.run(
+                    cmd,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.STDOUT,
+                    env=env,
+                    cwd=tmp_dir,
+                )
+
+                stdout_str = proc.stdout.decode("utf-8")
+
+                # Check if execution completed successfully
+                self.assertIn("Model executed successfully.", stdout_str)
+
+                # Read the outputs from the temporary files
+                outputs = []
+                for i, f in enumerate(sorted(os.listdir(output_dir))):
+                    filename = os.path.join(output_dir, f)
+                    output = np.fromfile(
+                        filename, dtype=ref_output[i].detach().numpy().dtype
+                    )
+                    output = torch.from_numpy(output).reshape(ref_output[i].shape)
+                    outputs.append(output)
+
+                # Compare the outputs with the reference outputs
+                self.assertTrue(len(ref_output) == len(outputs))
+                for i in range(len(ref_output)):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs[i],
+                            ref_output[i],
+                            atol=self.atol,
+                            rtol=self.rtol,
+                            equal_nan=True,
+                        ),
+                        msg=f"ref_output:\n{ref_output[i]}\n\ntest_output:\n{outputs[i]}",
+                    )
+
+    def generate_inputs(self, dest_path: str, file_name: str, inputs=None, input_list=None):
+        input_list_file = None
+        input_files = []
+
+        # Prepare input list
+        if input_list is not None:
+            input_list_file = f"{dest_path}/{file_name}"
+            with open(input_list_file, "w") as f:
+                f.write(input_list)
+                f.flush()
+
+        # Prepare input data
+        if inputs is not None:
+            for idx, data in enumerate(inputs):
+                for i, d in enumerate(data):
+                    file_name = f"{dest_path}/input_{idx}_{i}.raw"
+                    d.detach().numpy().tofile(file_name)
+                    input_files.append(file_name)
+
+        return input_list_file, input_files
+
+    def make_output_dir(self, path: str):
+        if os.path.exists(path):
+            for f in os.listdir(path):
+                os.remove(os.path.join(path, f))
+            os.removedirs(path)
+        os.makedirs(path)
diff --git a/backends/openvino/tests/ops/test_add.py b/backends/openvino/tests/ops/test_add.py
new file mode 100644
index 00000000000..d298f77e792
--- /dev/null
+++ b/backends/openvino/tests/ops/test_add.py
@@ -0,0 +1,19 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+
+class TestAddOperator(BaseOpenvinoOpTest):
+
+    def create_model(self):
+        class Add(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return torch.add(x, y)
+
+        return Add()
+
+    def test_add(self):
+        module = self.create_model()
+        sample_input = (torch.randn(2, 5, 1, 3), torch.randn(2, 5, 1, 3))
+        self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_addmm.py b/backends/openvino/tests/ops/test_addmm.py
new file mode 100644
index 00000000000..32f09ebdc29
--- /dev/null
+++ b/backends/openvino/tests/ops/test_addmm.py
@@ -0,0 +1,25 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+
+class TestAddMMOperator(BaseOpenvinoOpTest):
+
+    def create_model(self):
+        class AddMM(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.alpha = 1.0
+                self.beta = 1.0
+
+            def forward(self, x, y, z):
+                return torch.addmm(x, y, z, alpha=self.alpha, beta=self.beta)
+
+        return AddMM()
+
+    def test_addmm(self):
+        module = self.create_model()
+        input_x = torch.randn(4, 4, dtype=torch.float32)
+        input_y = torch.randn(4, 4, dtype=torch.float32)
+        input_z = torch.randn(4, 4, dtype=torch.float32)
+        sample_input = (input_x, input_y, input_z)
+        self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_arange.py b/backends/openvino/tests/ops/test_arange.py
new file mode 100644
index 00000000000..0dd739a2585
--- /dev/null
+++ b/backends/openvino/tests/ops/test_arange.py
@@ -0,0 +1,20 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+
+class TestArangeOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, x):
+        class Arange(torch.nn.Module):
+            def __init__(self, x):
+                super().__init__()
+                self.x = x
+
+            def forward(self, y):
+                return torch.arange(self.x, dtype=torch.float32) + y
+
+        return Arange(5)
+
+    def test_arange(self):
+        module = self.create_model(5)
+        sample_input = (torch.randn(5),)
+        self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_batch_norm.py b/backends/openvino/tests/ops/test_batch_norm.py
new file mode 100644
index 00000000000..ecb76860434
--- /dev/null
+++ b/backends/openvino/tests/ops/test_batch_norm.py
@@ -0,0 +1,51 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+op_params = [
+    {'weights': True, 'bias': True, 'eps': 1.0},
+    {'weights': True, 'bias': True, 'eps': 0.00005},
+    {'weights': True, 'bias': True, 'eps': 0.5},
+    {'weights': True, 'bias': True, 'eps': 0.042},
+    {'weights': True, 'bias': False, 'eps': 1.0},
+    {'weights': True, 'bias': False, 'eps': 0.00005},
+    {'weights': True, 'bias': False, 'eps': 0.5},
+    {'weights': True, 'bias': False, 'eps': 0.042},
+    {'weights': False, 'bias': True, 'eps': 1.0},
+    {'weights': False, 'bias': True, 'eps': 0.00005},
+    {'weights': False, 'bias': True, 'eps': 0.5},
+    {'weights': False, 'bias': True, 'eps': 0.042},
+    {'weights': False, 'bias': False, 'eps': 1.0},
+    {'weights': False, 'bias': False, 'eps': 0.00005},
+    {'weights': False, 'bias': False, 'eps': 0.5},
+    {'weights': False, 'bias': False, 'eps': 0.042},
+]
+
+
+class TestBatchNormOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, weights, bias, eps):
+
+        class BatchNorm(torch.nn.Module):
+            def __init__(self, weights=True, bias=True, eps=1e-05):
+                super(BatchNorm, self).__init__()
+                self.weight = torch.nn.Parameter(torch.randn(6)) if weights else None
+                self.bias = torch.nn.Parameter(torch.randn(6)) if bias else None
+                self.running_mean = torch.randn(6)
+                self.running_var = torch.randn(6)
+                self.eps = eps
+
+            def forward(self, x):
+                return torch.nn.functional.batch_norm(
+                    x,
+                    self.running_mean,
+                    self.running_var,
+                    self.weight,
+                    self.bias,
+                    eps=self.eps,
+                    training=False,
+                )
+
+        return BatchNorm(weights, bias, eps)
+
+    def test_batch_norm(self):
+        for params in op_params:
+            with self.subTest(params=params):
+                module = self.create_model(
+                    weights=params['weights'],
+                    bias=params['bias'],
+                    eps=params['eps'],
+                )
+
+                sample_input = (torch.randn(20, 6, 10),)
+
+                self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_convolution.py b/backends/openvino/tests/ops/test_convolution.py
new file mode 100644
index 00000000000..83a80282089
--- /dev/null
+++ b/backends/openvino/tests/ops/test_convolution.py
@@ -0,0 +1,105 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+d2_params = [
+    {'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 3, 2, 2], 'strides': [1, 1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [1, 1],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [1, 1],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [3, 1],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [3, 1],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'bias_shape': [1], 'pads': [1, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 3, 1, 1], 'strides': [1, 1], 'pads': [0, 1],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [1, 0],
+     'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 1],
+     'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [1, 0],
+     'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0],
+     'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 1], 'bias_shape': [1], 'pads': [1, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 3, 1, 1], 'strides': [2, 2], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 3, 1, 1], 'strides': [2, 1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': False},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [0, 0],
+     'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0], 'transposed': True},
+    {'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'bias_shape': [1], 'pads': [1, 1],
+     'dilations': [2, 2], 'groups': 1, 'output_padding': [1, 1], 'transposed': True},
+]
+
+
+class TestConvolutionOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, weights_shape, strides, pads, dilations, groups, bias,
+                     transposed, output_padding=0, bias_shape=None, underscore=False):
+
+        bias_dim = 0
+
+        class Convolution(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.nn.Parameter(torch.randn(weights_shape))
+                self.bias_shape = bias_shape
+                if self.bias_shape is None:
+                    self.bias_shape = weights_shape[bias_dim]
+                self.bias = torch.nn.Parameter(torch.randn(self.bias_shape)) if bias else None
+                self.strides = strides
+                self.pads = pads
+                self.dilations = dilations
+                self.groups = groups
+                self.transposed = transposed
+                self.output_padding = output_padding
+                if underscore:
+                    self.forward = self.forward_
+
+            def forward(self, x):
+                return torch.convolution(
+                    x, self.weight, self.bias, self.strides, self.pads,
+                    self.dilations, self.transposed, self.output_padding, self.groups
+                )
+
+            def forward_(self, x):
+                return torch._convolution(
+                    x, self.weight, self.bias, self.strides, self.pads,
+                    self.dilations, self.transposed, self.output_padding, self.groups,
+                    False, False, False, False
+                )
+
+        return Convolution()
+
+    def test_convolution(self):
+        bias_underscore_config = [(False, False), (True, False)]
+        for bias, underscore in bias_underscore_config:
+            for params in d2_params:
+                with self.subTest(params=params, bias=bias, underscore=underscore):
+                    bias_shape = None
+                    if 'bias_shape' in params:
+                        bias_shape = params['bias_shape']
+                    module = self.create_model(
+                        weights_shape=params['weights_shape'],
+                        strides=params['strides'],
+                        pads=params['pads'],
+                        dilations=params['dilations'],
+                        groups=params['groups'],
+                        output_padding=params['output_padding'],
+                        transposed=params['transposed'],
+                        bias_shape=bias_shape,
+                        bias=bias,
+                        underscore=underscore,
+                    )
+                    sample_input = (torch.randn(1, 3, 10, 10),)
+                    self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_mean.py b/backends/openvino/tests/ops/test_mean.py
new file mode 100644
index 00000000000..3315fd1e61d
--- /dev/null
+++ b/backends/openvino/tests/ops/test_mean.py
@@ -0,0 +1,59 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+op_params = [
+    {'axes': None, 'keep_dim': None, 'dtype': None},
+    {'axes': None, 'keep_dim': None, 'dtype': "float64"},
+    {'axes': None, 'keep_dim': None, 'dtype': "float32"},
+    {'axes': None, 'keep_dim': None, 'dtype': "int32"},
+    {'axes': 0, 'keep_dim': False, 'dtype': None},
+    {'axes': 0, 'keep_dim': False, 'dtype': None},
+]
+
+dtypes = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "int32": torch.int32,
+    "int64": torch.int64,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
+
+class TestMeanOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, axes, keep_dims, dtype):
+
+        pt_dtype = dtypes.get(dtype)
+
+        class Mean(torch.nn.Module):
+            def __init__(self, axes=None, keep_dims=None, dtype=None):
+                super(Mean, self).__init__()
+                self.axes = axes
+                self.keep_dims = keep_dims
+                self.dtype = dtype
+
+            def forward(self, x):
+                if self.axes is None and self.keep_dims is None:
+                    if self.dtype is None:
+                        return torch.mean(x)
+                    return torch.mean(x, dtype=self.dtype)
+                if self.axes is not None and self.keep_dims is None:
+                    if self.dtype is None:
+                        return torch.mean(x, self.axes)
+                    return torch.mean(x, self.axes, dtype=self.dtype)
+                if self.dtype is None:
+                    return torch.mean(x, self.axes, self.keep_dims)
+                return torch.mean(x, self.axes, self.keep_dims, dtype=self.dtype)
+
+        return Mean(axes, keep_dims, pt_dtype)
+
+    def test_mean(self):
+        for params in op_params:
+            with self.subTest(params=params):
+                module = self.create_model(
+                    axes=params['axes'],
+                    keep_dims=params['keep_dim'],
+                    dtype=params['dtype'],
+                )
+
+                sample_input = (
+                    torch.randint(-10, 10, (1, 3, 224, 224)).to(dtype=torch.float32),
+                )
+
+                self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_permute.py b/backends/openvino/tests/ops/test_permute.py
new file mode 100644
index 00000000000..1de60db3965
--- /dev/null
+++ b/backends/openvino/tests/ops/test_permute.py
@@ -0,0 +1,30 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+op_params = [
+    {'order': [0, 2, 3, 1]},
+    {'order': [0, 3, 1, 2]},
+]
+
+
+class TestPermuteOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, order):
+
+        class Permute(torch.nn.Module):
+            def __init__(self, order):
+                super(Permute, self).__init__()
+                self.order = order
+
+            def forward(self, x):
+                return torch.permute(x, self.order)
+
+        return Permute(order)
+
+    def test_permute(self):
+        for params in op_params:
+            with self.subTest(params=params):
+                module = self.create_model(order=params['order'])
+
+                sample_input = (torch.randn(1, 3, 224, 224),)
+
+                self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_pooling.py b/backends/openvino/tests/ops/test_pooling.py
new file mode 100644
index 00000000000..60ab2f9edfa
--- /dev/null
+++ b/backends/openvino/tests/ops/test_pooling.py
@@ -0,0 +1,65 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+d2_params = [
+    {'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
+    {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1},
+    {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [0, 1]},
+    {'kernel_size': [3, 3], 'stride': [1, 1], 'padding': [1, 0]},
+    {'kernel_size': [3, 3], 'stride': [2, 1], 'padding': 0},
+    {'kernel_size': [2, 1], 'stride': [2, 1], 'padding': 0},
+    {'kernel_size': [2, 1], 'stride': None, 'padding': 0},
+    {'kernel_size': [2, 1], 'stride': [], 'padding': 0},
+    {'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1},
+]
+
+
+class TestPoolingOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, op_type, kernel_size, stride, padding, dilation=1,
+                     ceil_mode=True, count_include_pad=True, dtype=torch.float32):
+
+        class MaxPoolingBase(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.kernel_size = kernel_size
+                self.stride = stride
+                self.padding = padding
+                self.dilation = dilation
+                self.ceil_mode = ceil_mode
+                self.dtype = dtype
+
+            def forward(self, x):
+                pass
+
+        class MaxPool2D(MaxPoolingBase):
+            def forward(self, x):
+                return torch.nn.functional.max_pool2d(
+                    x.to(self.dtype), self.kernel_size, self.stride, self.padding,
+                    self.dilation, self.ceil_mode
+                )
+
+        class MaxPool2DIndices(MaxPoolingBase):
+            def forward(self, x):
+                return torch.nn.functional.max_pool2d(
+                    x, self.kernel_size, self.stride, self.padding, self.dilation,
+                    self.ceil_mode, return_indices=True
+                )
+
+        ops = {
+            "MaxPool2D": MaxPool2D,
+            "MaxPool2DIndices": MaxPool2DIndices,
+        }
+
+        aten_pooling = ops[op_type]
+
+        return aten_pooling()
+
+    def test_pooling2d(self):
+        for params in d2_params:
+            with self.subTest(params=params):
+                module = self.create_model(
+                    op_type='MaxPool2D',
+                    kernel_size=params['kernel_size'],
+                    stride=params['stride'],
+                    padding=params['padding'],
+                    dilation=1,
+                    ceil_mode=True,
+                    count_include_pad=True,
+                )
+                sample_input = (torch.randn(1, 3, 15, 15),)
+                self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_unary_ops.py b/backends/openvino/tests/ops/test_unary_ops.py
new file mode 100644
index 00000000000..9a5866d6e65
--- /dev/null
+++ b/backends/openvino/tests/ops/test_unary_ops.py
@@ -0,0 +1,36 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+OPS = [
+    torch.relu,
+]
+
+
+class TestUnaryOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, op, dtype):
+
+        class UnaryOp(torch.nn.Module):
+            def __init__(self, op, dtype):
+                super().__init__()
+                self.dtype = dtype
+                self.op = op
+
+            def forward(self, x):
+                x1 = x.to(self.dtype)
+                y = self.op(x1)
+                return y, x1
+
+        return UnaryOp(op, dtype)
+
+    def test_unary_op(self):
+        for op in OPS:
+            with self.subTest(op=op):
+                module = self.create_model(op, dtype=torch.float32)
+
+                sample_input = (torch.rand(2, 10) * 10 + 1,)
+
+                self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/ops/test_view.py b/backends/openvino/tests/ops/test_view.py
new file mode 100644
index 00000000000..f5450a10af9
--- /dev/null
+++ b/backends/openvino/tests/ops/test_view.py
@@ -0,0 +1,32 @@
+from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+import torch
+
+op_params = [
+    {'input_shape': [2, 3, 2], 'target_shape': [2, 6]},
+    {'input_shape': [4], 'target_shape': [2, 2]},
+]
+
+
+class TestViewOperator(BaseOpenvinoOpTest):
+
+    def create_model(self, target_shape):
+
+        class View(torch.nn.Module):
+
+            def __init__(self, target_shape) -> None:
+                super().__init__()
+                self.target_shape = target_shape
+
+            def forward(self, input_tensor):
+                return input_tensor.view(self.target_shape)
+
+        return View(target_shape)
+
+    def test_view(self):
+        for params in op_params:
+            with self.subTest(params=params):
+                module = self.create_model(params['target_shape'])
+
+                sample_input = (torch.randn(params['input_shape']),)
+
+                self.execute_layer_test(module, sample_input)
diff --git a/backends/openvino/tests/test_openvino_delegate.py b/backends/openvino/tests/test_openvino_delegate.py
new file mode 100644
index 00000000000..bbf61d1ea09
--- /dev/null
+++ b/backends/openvino/tests/test_openvino_delegate.py
@@ -0,0 +1,65 @@
+import unittest
+import argparse
+
+
+class OpenvinoTestSuite(unittest.TestSuite):
+
+    test_params = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def addTest(self, test):
+        # Set test parameters if this is an instance of BaseOpenvinoOpTest
+        from executorch.backends.openvino.tests.ops.base_openvino_op_test import BaseOpenvinoOpTest
+        if isinstance(test, BaseOpenvinoOpTest):
+            if "device" in self.test_params:
+                test.device = self.test_params["device"]
+            if "build_folder" in self.test_params:
+                test.build_folder = self.test_params["build_folder"]
+        # Call the original addTest method to actually add the test to the suite
+        super().addTest(test)
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-b",
+        "--build_folder",
+        help="path to cmake binary directory",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "-s",
+        "--device",
+        help="OpenVINO device to execute the model on",
+        type=str,
+        default="CPU",
+    )
+    parser.add_argument(
+        "-p",
+        "--pattern",
+        help="Pattern to match
test files. Provide complete file name to run individual op tests", + type=str, + default="test_*.py", + ) + + args, ns_args = parser.parse_known_args(namespace=unittest) + test_params = {} + test_params["device"] = args.device + test_params["build_folder"] = args.build_folder + test_params["pattern"] = args.pattern + return test_params + +if __name__ == "__main__": + loader = unittest.TestLoader() + # Replace the default test suite with a custom test suite to be able to + # pass test parameter to the test cases + loader.suiteClass = OpenvinoTestSuite + test_params = parse_arguments() + loader.suiteClass.test_params = test_params + # Discover all existing op tests in "ops" folder + suite = loader.discover("ops", pattern=test_params['pattern']) + # Start running tests + unittest.TextTestRunner().run(suite) diff --git a/build/executorch_srcs.cmake b/build/executorch_srcs.cmake new file mode 100644 index 00000000000..a44fe650da2 --- /dev/null +++ b/build/executorch_srcs.cmake @@ -0,0 +1,448 @@ +# @generated by extract_sources.py + +set(_executorch__srcs + kernels/prim_ops/et_copy_index.cpp + kernels/prim_ops/et_view.cpp + kernels/prim_ops/register_prim_ops.cpp +) + +set(_executorch_core__srcs + runtime/backend/interface.cpp + runtime/core/evalue.cpp + runtime/core/exec_aten/util/tensor_util_portable.cpp + runtime/core/portable_type/tensor_impl.cpp + runtime/executor/method.cpp + runtime/executor/method_meta.cpp + runtime/executor/program.cpp + runtime/executor/tensor_parser_exec_aten.cpp + runtime/executor/tensor_parser_portable.cpp + runtime/kernel/operator_registry.cpp + runtime/platform/abort.cpp + runtime/platform/default/posix.cpp + runtime/platform/log.cpp + runtime/platform/profiler.cpp + runtime/platform/runtime.cpp + schema/extended_header.cpp +) + +set(_portable_kernels__srcs + kernels/portable/cpu/op__to_dim_order_copy.cpp + kernels/portable/cpu/op_abs.cpp + kernels/portable/cpu/op_acos.cpp + kernels/portable/cpu/op_acosh.cpp + kernels/portable/cpu/op_add.cpp + kernels/portable/cpu/op_addmm.cpp + kernels/portable/cpu/op_alias_copy.cpp + kernels/portable/cpu/op_allclose.cpp + kernels/portable/cpu/op_amax.cpp + kernels/portable/cpu/op_amin.cpp + kernels/portable/cpu/op_any.cpp + kernels/portable/cpu/op_arange.cpp + kernels/portable/cpu/op_argmax.cpp + kernels/portable/cpu/op_argmin.cpp + kernels/portable/cpu/op_as_strided_copy.cpp + kernels/portable/cpu/op_asin.cpp + kernels/portable/cpu/op_asinh.cpp + kernels/portable/cpu/op_atan.cpp + kernels/portable/cpu/op_atan2.cpp + kernels/portable/cpu/op_atanh.cpp + kernels/portable/cpu/op_avg_pool2d.cpp + kernels/portable/cpu/op_bitwise_and.cpp + kernels/portable/cpu/op_bitwise_not.cpp + kernels/portable/cpu/op_bitwise_or.cpp + kernels/portable/cpu/op_bitwise_xor.cpp + kernels/portable/cpu/op_bmm.cpp + kernels/portable/cpu/op_cat.cpp + kernels/portable/cpu/op_cdist_forward.cpp + kernels/portable/cpu/op_ceil.cpp + kernels/portable/cpu/op_clamp.cpp + kernels/portable/cpu/op_clone.cpp + kernels/portable/cpu/op_constant_pad_nd.cpp + kernels/portable/cpu/op_convolution.cpp + kernels/portable/cpu/op_convolution_backward.cpp + kernels/portable/cpu/op_copy.cpp + kernels/portable/cpu/op_cos.cpp + kernels/portable/cpu/op_cosh.cpp + kernels/portable/cpu/op_cumsum.cpp + kernels/portable/cpu/op_detach_copy.cpp + kernels/portable/cpu/op_diagonal_copy.cpp + kernels/portable/cpu/op_div.cpp + kernels/portable/cpu/op_embedding.cpp + kernels/portable/cpu/op_empty.cpp + kernels/portable/cpu/op_eq.cpp + kernels/portable/cpu/op_erf.cpp + 
+    kernels/portable/cpu/op_exp.cpp
+    kernels/portable/cpu/op_expand_copy.cpp
+    kernels/portable/cpu/op_expm1.cpp
+    kernels/portable/cpu/op_fill.cpp
+    kernels/portable/cpu/op_flip.cpp
+    kernels/portable/cpu/op_floor.cpp
+    kernels/portable/cpu/op_floor_divide.cpp
+    kernels/portable/cpu/op_fmod.cpp
+    kernels/portable/cpu/op_full.cpp
+    kernels/portable/cpu/op_full_like.cpp
+    kernels/portable/cpu/op_gather.cpp
+    kernels/portable/cpu/op_ge.cpp
+    kernels/portable/cpu/op_gelu.cpp
+    kernels/portable/cpu/op_glu.cpp
+    kernels/portable/cpu/op_gt.cpp
+    kernels/portable/cpu/op_hardtanh.cpp
+    kernels/portable/cpu/op_index.cpp
+    kernels/portable/cpu/op_index_put.cpp
+    kernels/portable/cpu/op_index_select.cpp
+    kernels/portable/cpu/op_isinf.cpp
+    kernels/portable/cpu/op_isnan.cpp
+    kernels/portable/cpu/op_le.cpp
+    kernels/portable/cpu/op_leaky_relu.cpp
+    kernels/portable/cpu/op_lift_fresh_copy.cpp
+    kernels/portable/cpu/op_linear_scratch_example.cpp
+    kernels/portable/cpu/op_log.cpp
+    kernels/portable/cpu/op_log10.cpp
+    kernels/portable/cpu/op_log1p.cpp
+    kernels/portable/cpu/op_log2.cpp
+    kernels/portable/cpu/op_log_softmax.cpp
+    kernels/portable/cpu/op_logical_and.cpp
+    kernels/portable/cpu/op_logical_not.cpp
+    kernels/portable/cpu/op_logical_or.cpp
+    kernels/portable/cpu/op_logical_xor.cpp
+    kernels/portable/cpu/op_logit.cpp
+    kernels/portable/cpu/op_lt.cpp
+    kernels/portable/cpu/op_masked_fill.cpp
+    kernels/portable/cpu/op_masked_scatter.cpp
+    kernels/portable/cpu/op_max.cpp
+    kernels/portable/cpu/op_max_pool2d_with_indices.cpp
+    kernels/portable/cpu/op_maximum.cpp
+    kernels/portable/cpu/op_mean.cpp
+    kernels/portable/cpu/op_min.cpp
+    kernels/portable/cpu/op_minimum.cpp
+    kernels/portable/cpu/op_mm.cpp
+    kernels/portable/cpu/op_mul.cpp
+    kernels/portable/cpu/op_narrow_copy.cpp
+    kernels/portable/cpu/op_native_batch_norm.cpp
+    kernels/portable/cpu/op_native_group_norm.cpp
+    kernels/portable/cpu/op_native_layer_norm.cpp
+    kernels/portable/cpu/op_ne.cpp
+    kernels/portable/cpu/op_neg.cpp
+    kernels/portable/cpu/op_nonzero.cpp
+    kernels/portable/cpu/op_ones.cpp
+    kernels/portable/cpu/op_pdist_forward.cpp
+    kernels/portable/cpu/op_permute_copy.cpp
+    kernels/portable/cpu/op_pixel_shuffle.cpp
+    kernels/portable/cpu/op_pixel_unshuffle.cpp
+    kernels/portable/cpu/op_pow.cpp
+    kernels/portable/cpu/op_prod.cpp
+    kernels/portable/cpu/op_reciprocal.cpp
+    kernels/portable/cpu/op_reflection_pad1d.cpp
+    kernels/portable/cpu/op_reflection_pad2d.cpp
+    kernels/portable/cpu/op_reflection_pad3d.cpp
+    kernels/portable/cpu/op_relu.cpp
+    kernels/portable/cpu/op_remainder.cpp
+    kernels/portable/cpu/op_repeat.cpp
+    kernels/portable/cpu/op_replication_pad1d.cpp
+    kernels/portable/cpu/op_replication_pad2d.cpp
+    kernels/portable/cpu/op_replication_pad3d.cpp
+    kernels/portable/cpu/op_roll.cpp
+    kernels/portable/cpu/op_round.cpp
+    kernels/portable/cpu/op_rsqrt.cpp
+    kernels/portable/cpu/op_rsub.cpp
+    kernels/portable/cpu/op_scalar_tensor.cpp
+    kernels/portable/cpu/op_scatter.cpp
+    kernels/portable/cpu/op_scatter_add.cpp
+    kernels/portable/cpu/op_select_copy.cpp
+    kernels/portable/cpu/op_select_scatter.cpp
+    kernels/portable/cpu/op_sigmoid.cpp
+    kernels/portable/cpu/op_sign.cpp
+    kernels/portable/cpu/op_sin.cpp
+    kernels/portable/cpu/op_sinh.cpp
+    kernels/portable/cpu/op_slice_copy.cpp
+    kernels/portable/cpu/op_slice_scatter.cpp
+    kernels/portable/cpu/op_softmax.cpp
+    kernels/portable/cpu/op_split_copy.cpp
+    kernels/portable/cpu/op_split_with_sizes_copy.cpp
+    kernels/portable/cpu/op_sqrt.cpp
+    kernels/portable/cpu/op_squeeze_copy.cpp
+    kernels/portable/cpu/op_stack.cpp
+    kernels/portable/cpu/op_sub.cpp
+    kernels/portable/cpu/op_sum.cpp
+    kernels/portable/cpu/op_t_copy.cpp
+    kernels/portable/cpu/op_tan.cpp
+    kernels/portable/cpu/op_tanh.cpp
+    kernels/portable/cpu/op_to_copy.cpp
+    kernels/portable/cpu/op_topk.cpp
+    kernels/portable/cpu/op_transpose_copy.cpp
+    kernels/portable/cpu/op_tril.cpp
+    kernels/portable/cpu/op_trunc.cpp
+    kernels/portable/cpu/op_unbind_copy.cpp
+    kernels/portable/cpu/op_unsqueeze_copy.cpp
+    kernels/portable/cpu/op_var.cpp
+    kernels/portable/cpu/op_view_copy.cpp
+    kernels/portable/cpu/op_where.cpp
+    kernels/portable/cpu/op_zeros.cpp
+    kernels/portable/cpu/pattern/unary_ufunc_realh.cpp
+    kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp
+    kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp
+    kernels/portable/cpu/util/activation_ops_util.cpp
+    kernels/portable/cpu/util/advanced_index_util.cpp
+    kernels/portable/cpu/util/broadcast_util.cpp
+    kernels/portable/cpu/util/copy_ops_util.cpp
+    kernels/portable/cpu/util/distance_util.cpp
+    kernels/portable/cpu/util/dtype_util.cpp
+    kernels/portable/cpu/util/index_util.cpp
+    kernels/portable/cpu/util/kernel_ops_util.cpp
+    kernels/portable/cpu/util/matmul_ops_util.cpp
+    kernels/portable/cpu/util/normalization_ops_util.cpp
+    kernels/portable/cpu/util/padding_util.cpp
+    kernels/portable/cpu/util/reduce_util.cpp
+    kernels/portable/cpu/util/repeat_util.cpp
+    kernels/portable/cpu/util/select_copy_util.cpp
+    kernels/portable/cpu/util/slice_util.cpp
+)
+
+set(_optimized_kernels__srcs
+    extension/parallel/thread_parallel.cpp
+    kernels/optimized/blas/BlasKernel.cpp
+    kernels/optimized/blas/CPUBlas.cpp
+    kernels/optimized/cpu/op_add.cpp
+    kernels/optimized/cpu/op_bmm.cpp
+    kernels/optimized/cpu/op_div.cpp
+    kernels/optimized/cpu/op_exp.cpp
+    kernels/optimized/cpu/op_le.cpp
+    kernels/optimized/cpu/op_linear.cpp
+    kernels/optimized/cpu/op_mm.cpp
+    kernels/optimized/cpu/op_mul.cpp
+    kernels/optimized/cpu/op_native_layer_norm.cpp
+    kernels/optimized/cpu/op_neg.cpp
+    kernels/optimized/cpu/op_sub.cpp
+)
+
+set(_quantized_kernels__srcs
+    kernels/quantized/cpu/embeddingxb.cpp
+    kernels/quantized/cpu/op_add.cpp
+    kernels/quantized/cpu/op_choose_qparams.cpp
+    kernels/quantized/cpu/op_dequantize.cpp
+    kernels/quantized/cpu/op_embedding.cpp
+    kernels/quantized/cpu/op_embedding2b.cpp
+    kernels/quantized/cpu/op_embedding4b.cpp
+    kernels/quantized/cpu/op_mixed_linear.cpp
+    kernels/quantized/cpu/op_mixed_mm.cpp
+    kernels/quantized/cpu/op_quantize.cpp
+)
+
+set(_program_schema__srcs
+    schema/program.fbs
+    schema/scalar_type.fbs
+)
+
+set(_optimized_cpublas__srcs
+    extension/parallel/thread_parallel.cpp
+    extension/threadpool/threadpool.cpp
+    extension/threadpool/threadpool_guard.cpp
+    kernels/optimized/blas/BlasKernel.cpp
+    kernels/optimized/blas/CPUBlas.cpp
+)
+
+set(_optimized_native_cpu_ops_oss__srcs
+    codegen/templates/RegisterCodegenUnboxedKernels.cpp
+    codegen/templates/RegisterDispatchKeyCustomOps.cpp
+    codegen/templates/RegisterKernels.cpp
+    codegen/templates/RegisterSchema.cpp
+    extension/parallel/thread_parallel.cpp
+    extension/threadpool/threadpool.cpp
+    extension/threadpool/threadpool_guard.cpp
+    kernels/optimized/blas/BlasKernel.cpp
+    kernels/optimized/blas/CPUBlas.cpp
+    kernels/optimized/cpu/op_add.cpp
+    kernels/optimized/cpu/op_bmm.cpp
+    kernels/optimized/cpu/op_div.cpp
+    kernels/optimized/cpu/op_exp.cpp
+    kernels/optimized/cpu/op_le.cpp
+    kernels/optimized/cpu/op_linear.cpp
+    kernels/optimized/cpu/op_mm.cpp
+    kernels/optimized/cpu/op_mul.cpp
+    kernels/optimized/cpu/op_native_layer_norm.cpp
+    kernels/optimized/cpu/op_neg.cpp
+    kernels/optimized/cpu/op_sub.cpp
+)
+
+set(_extension_data_loader__srcs
+    extension/data_loader/file_data_loader.cpp
+    extension/data_loader/mmap_data_loader.cpp
+)
+
+set(_extension_module__srcs
+    extension/module/module.cpp
+)
+
+set(_extension_runner_util__srcs
+    extension/runner_util/inputs.cpp
+    extension/runner_util/inputs_portable.cpp
+)
+
+set(_extension_llm_runner__srcs
+    extension/data_loader/file_data_loader.cpp
+    extension/data_loader/mmap_data_loader.cpp
+    extension/llm/runner/text_decoder_runner.cpp
+    extension/llm/runner/text_prefiller.cpp
+    extension/llm/sampler/sampler.cpp
+    extension/tensor/tensor_ptr.cpp
+    extension/tensor/tensor_ptr_maker.cpp
+)
+
+set(_extension_tensor__srcs
+    extension/tensor/tensor_ptr.cpp
+    extension/tensor/tensor_ptr_maker.cpp
+)
+
+set(_extension_threadpool__srcs
+    extension/threadpool/threadpool.cpp
+    extension/threadpool/threadpool_guard.cpp
+)
+
+set(_extension_training__srcs
+    extension/data_loader/file_data_loader.cpp
+    extension/data_loader/mmap_data_loader.cpp
+    extension/module/module.cpp
+    extension/training/module/training_module.cpp
+    extension/training/optimizer/sgd.cpp
+    kernels/prim_ops/et_copy_index.cpp
+    kernels/prim_ops/et_view.cpp
+    kernels/prim_ops/register_prim_ops.cpp
+)
+
+set(_train_xor__srcs
+    extension/data_loader/file_data_loader.cpp
+    extension/data_loader/mmap_data_loader.cpp
+    extension/module/module.cpp
+    extension/tensor/tensor_ptr.cpp
+    extension/tensor/tensor_ptr_maker.cpp
+    extension/training/examples/XOR/train.cpp
+    extension/training/module/training_module.cpp
+    extension/training/optimizer/sgd.cpp
+)
+
+set(_executor_runner__srcs
+    examples/portable/executor_runner/executor_runner.cpp
+    extension/data_loader/file_data_loader.cpp
+    extension/evalue_util/print_evalue.cpp
+    extension/runner_util/inputs.cpp
+    extension/runner_util/inputs_portable.cpp
+    runtime/executor/test/test_backend_compiler_lib.cpp
+)
+
+set(_size_test__srcs
+    extension/data_loader/file_data_loader.cpp
+    test/size_test.cpp
+)
+
+set(_mps_executor_runner__srcs
+    backends/apple/mps/runtime/MPSBackend.mm
+    backends/apple/mps/runtime/MPSCompiler.mm
+    backends/apple/mps/runtime/MPSDelegateHeader.mm
+    backends/apple/mps/runtime/MPSDevice.mm
+    backends/apple/mps/runtime/MPSExecutor.mm
+    backends/apple/mps/runtime/MPSGraphBuilder.mm
+    backends/apple/mps/runtime/MPSStream.mm
+    backends/apple/mps/runtime/operations/ActivationOps.mm
+    backends/apple/mps/runtime/operations/BinaryOps.mm
+    backends/apple/mps/runtime/operations/ClampOps.mm
+    backends/apple/mps/runtime/operations/ConstantOps.mm
+    backends/apple/mps/runtime/operations/ConvolutionOps.mm
+    backends/apple/mps/runtime/operations/IndexingOps.mm
+    backends/apple/mps/runtime/operations/LinearAlgebra.mm
+    backends/apple/mps/runtime/operations/NormalizationOps.mm
+    backends/apple/mps/runtime/operations/OperationUtils.mm
+    backends/apple/mps/runtime/operations/PadOps.mm
+    backends/apple/mps/runtime/operations/PoolingOps.mm
+    backends/apple/mps/runtime/operations/QuantDequant.mm
+    backends/apple/mps/runtime/operations/RangeOps.mm
+    backends/apple/mps/runtime/operations/ReduceOps.mm
+    backends/apple/mps/runtime/operations/ShapeOps.mm
+    backends/apple/mps/runtime/operations/UnaryOps.mm
+    devtools/bundled_program/bundled_program.cpp
+    devtools/etdump/emitter.cpp
+    devtools/etdump/etdump_flatcc.cpp
+    examples/apple/mps/executor_runner/mps_executor_runner.mm
+    extension/data_loader/file_data_loader.cpp
+    extension/evalue_util/print_evalue.cpp
+    extension/runner_util/inputs.cpp
+    extension/runner_util/inputs_portable.cpp
+)
+
+set(_mps_backend__srcs
+    backends/apple/mps/runtime/MPSBackend.mm
+    backends/apple/mps/runtime/MPSCompiler.mm
+    backends/apple/mps/runtime/MPSDelegateHeader.mm
+    backends/apple/mps/runtime/MPSDevice.mm
+    backends/apple/mps/runtime/MPSExecutor.mm
+    backends/apple/mps/runtime/MPSGraphBuilder.mm
+    backends/apple/mps/runtime/MPSStream.mm
+    backends/apple/mps/runtime/operations/ActivationOps.mm
+    backends/apple/mps/runtime/operations/BinaryOps.mm
+    backends/apple/mps/runtime/operations/ClampOps.mm
+    backends/apple/mps/runtime/operations/ConstantOps.mm
+    backends/apple/mps/runtime/operations/ConvolutionOps.mm
+    backends/apple/mps/runtime/operations/IndexingOps.mm
+    backends/apple/mps/runtime/operations/LinearAlgebra.mm
+    backends/apple/mps/runtime/operations/NormalizationOps.mm
+    backends/apple/mps/runtime/operations/OperationUtils.mm
+    backends/apple/mps/runtime/operations/PadOps.mm
+    backends/apple/mps/runtime/operations/PoolingOps.mm
+    backends/apple/mps/runtime/operations/QuantDequant.mm
+    backends/apple/mps/runtime/operations/RangeOps.mm
+    backends/apple/mps/runtime/operations/ReduceOps.mm
+    backends/apple/mps/runtime/operations/ShapeOps.mm
+    backends/apple/mps/runtime/operations/UnaryOps.mm
+)
+
+set(_mps_schema__srcs
+    backends/apple/mps/serialization/schema.fbs
+)
+
+set(_xnn_executor_runner__srcs
+    examples/portable/executor_runner/executor_runner.cpp
+    extension/data_loader/file_data_loader.cpp
+    extension/evalue_util/print_evalue.cpp
+    extension/runner_util/inputs.cpp
+    extension/runner_util/inputs_portable.cpp
+)
+
+set(_xnnpack_backend__srcs
+    backends/xnnpack/runtime/XNNCompiler.cpp
+    backends/xnnpack/runtime/XNNExecutor.cpp
+    backends/xnnpack/runtime/XNNHeader.cpp
+    backends/xnnpack/runtime/XNNPACKBackend.cpp
+    backends/xnnpack/runtime/profiling/XNNProfiler.cpp
+    extension/threadpool/threadpool.cpp
+    extension/threadpool/threadpool_guard.cpp
+)
+
+set(_xnnpack_schema__srcs
+    backends/xnnpack/serialization/runtime_schema.fbs
+)
+
+set(_vulkan_schema__srcs
+    backends/vulkan/serialization/schema.fbs
+)
+
+set(_custom_ops__srcs
+    extension/llm/custom_ops/op_fallback.cpp
+    extension/llm/custom_ops/op_fast_hadamard_transform.cpp
+    extension/llm/custom_ops/op_sdpa.cpp
+    extension/llm/custom_ops/op_update_quantized_cache.cpp
+    extension/llm/custom_ops/spinquant/fast_hadamard_transform.cpp
+    extension/llm/custom_ops/spinquant/third-party/FFHT/fht_avx.c
+    kernels/portable/cpu/util/reduce_util.cpp
+)
+
+set(_llama_runner__srcs
+    examples/models/llama/runner/runner.cpp
+    examples/models/llama/tokenizer/llama_tiktoken.cpp
+    extension/evalue_util/print_evalue.cpp
+    extension/llm/runner/text_decoder_runner.cpp
+    extension/llm/runner/text_prefiller.cpp
+    extension/llm/sampler/sampler.cpp
+    extension/llm/tokenizer/bpe_tokenizer.cpp
+    extension/llm/tokenizer/tiktoken.cpp
+    extension/tensor/tensor_ptr.cpp
+    extension/tensor/tensor_ptr_maker.cpp
+)
\ No newline at end of file
diff --git a/examples/openvino/CMakeLists.txt b/examples/openvino/CMakeLists.txt
new file mode 100644
index 00000000000..64f1e8d5463
--- /dev/null
+++ b/examples/openvino/CMakeLists.txt
@@ -0,0 +1,104 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file in the root
+# directory of this source tree for more details.
+
+cmake_minimum_required(VERSION 3.19)
+project(openvino_runner_example)
+
+set(CMAKE_CXX_STANDARD 17)
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/build/Utils.cmake)
+include(${EXECUTORCH_ROOT}/build/Codegen.cmake)
+
+if(NOT PYTHON_EXECUTABLE)
+  resolve_python_executable()
+endif()
+
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE Debug)
+endif()
+
+set(_common_compile_options -Wno-deprecated-declarations -fPIC)
+
+# Let files say "include <executorch/path/to/header.h>".
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+#
+# The `_<target>__srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
+#
+set(EXECUTORCH_SRCS_FILE
+    "${CMAKE_CURRENT_BINARY_DIR}/../../../build/executorch_srcs.cmake"
+)
+extract_sources(${EXECUTORCH_SRCS_FILE})
+include(${EXECUTORCH_SRCS_FILE})
+
+# Executor runner source files for this example.
+set(_openvino_executor_runner__srcs
+    ${CMAKE_CURRENT_LIST_DIR}/executor_runner/openvino_executor_runner.cpp
+)
+
+find_package(executorch CONFIG REQUIRED)
+target_include_directories(executorch INTERFACE ${_common_include_directories})
+target_compile_options(executorch INTERFACE ${_common_compile_options})
+
+# portable_ops_lib
+gen_selected_ops(LIB_NAME "openvino_portable_ops_lib" INCLUDE_ALL_OPS "ON")
+generate_bindings_for_kernels(
+  LIB_NAME "openvino_portable_ops_lib" FUNCTIONS_YAML
+  ${EXECUTORCH_ROOT}/backends/openvino/openvino_functions.yaml
+)
+gen_operators_lib(
+  LIB_NAME "openvino_portable_ops_lib" KERNEL_LIBS portable_kernels DEPS executorch
+)
+target_compile_options(
+  openvino_portable_ops_lib INTERFACE -DET_EVENT_TRACER_ENABLED
+)
+target_include_directories(
+  openvino_portable_ops_lib PUBLIC ${_common_include_directories}
+)
+
+# Build the executor runner.
+add_executable(openvino_executor_runner ${_openvino_executor_runner__srcs})
+target_include_directories(
+  openvino_executor_runner PUBLIC ${_common_include_directories}
+)
+
+# Path to the directory that contains the prebuilt executorch libraries.
+set(LIBRARY_DIR "${CMAKE_CURRENT_LIST_DIR}/../../cmake-openvino-out/lib/")
+
+# Libraries to link, given as full paths so no 'lib' prefix or file extension
+# handling is needed.
+set(LIBRARIES_TO_LINK
+    ${LIBRARY_DIR}/libopenvino_backend.so
+    ${LIBRARY_DIR}/libexecutorch.a
+    ${LIBRARY_DIR}/libexecutorch_core.a
+    ${EXECUTORCH_ROOT}/third-party/gflags/build/lib/libgflags_nothreads.a
+    ${LIBRARY_DIR}/libpthreadpool.a
+    ${LIBRARY_DIR}/libextension_data_loader.a
+    ${LIBRARY_DIR}/libextension_runner_util.a
+)
+
+# Add the library directory to the link search path.
+link_directories(${LIBRARY_DIR})
+
+# Link all libraries at once.
+target_link_libraries(
+  openvino_executor_runner PRIVATE ${LIBRARIES_TO_LINK} openvino_portable_ops_lib
+)
+
+set_target_properties(
+  openvino_executor_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+)
+
+get_filename_component(
+  EXECUTORCH_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE
+)
+
diff --git a/examples/openvino/ReadMe.md b/examples/openvino/ReadMe.md
new file mode 100644
index 00000000000..13196f5151c
--- /dev/null
+++ b/examples/openvino/ReadMe.md
@@ -0,0 +1,64 @@
+# TODO: Delete and reformat later
+
+## Build ExecuTorch
+
+```bash
+git clone -b openvino_backend https://github.com/ynimmaga/executorch
+cd executorch
+git submodule update --init --recursive
+./install_requirements.sh
+# If the step above fails, kill any stale buck daemons and retry:
+pkill -f buck && ./install_requirements.sh
+```
+
+## Build OpenVINO and source the environment variables
+
+```bash
+git clone -b executorch_ov_backend https://github.com/ynimmaga/openvino
+cd openvino
+git submodule update --init --recursive
+mkdir build
+cd build
+cmake .. -DCMAKE_BUILD_TYPE=Release -DENABLE_PYTHON=ON -DENABLE_WHEEL=ON
+make -j
+cd wheels
+pip install <openvino_wheel>.whl
+
+cd ../..
+cmake --install build --prefix <openvino_install_dir>
+cd <openvino_install_dir>
+source setupvars.sh
+```
+
+## Build gflags
+
+```bash
+# Run from the executorch repository root.
+cd third-party/gflags
+mkdir build
+cd build
+cmake ..
+make -j12
+```
+
+## Build the OpenVINO example
+
+```bash
+cd ../../../examples/openvino
+./openvino_build.sh
+```
+
+### AOT step
+```bash
+cd aot
+python aot_openvino_compiler.py --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device CPU
+```
+
+### Update the model.pte in the executorch example and rebuild
+```bash
+cd <executorch_root>
+cd examples/openvino/executor_runner
+# Update the path of model.pte in openvino_executor_runner.cpp at
+# https://github.com/ynimmaga/executorch/blob/openvino_backend/examples/openvino/executor_runner/openvino_executor_runner.cpp#L20
+# Rebuild the example using "./openvino_build.sh".
+# The executable is placed in <executorch_root>/cmake-openvino-out/examples/openvino:
+./openvino_executor_runner
+```
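+
+### Run the runner with flags (alternative)
+
+The runner built from this patch also reads the model path from the
+`--model_path` flag (and prints a usage message and exits if the flag is
+missing), so editing the source may not be necessary. A typical invocation,
+with placeholder file names, looks like:
+
+```bash
+cd <executorch_root>/cmake-openvino-out/examples/openvino
+./openvino_executor_runner --model_path=resnet50.pte --num_iter=10
+```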
diff --git a/examples/openvino/aot/README.md b/examples/openvino/aot/README.md
new file mode 100644
index 00000000000..6c59f1dad41
--- /dev/null
+++ b/examples/openvino/aot/README.md
@@ -0,0 +1,88 @@
+# **Model Export Script for ExecuTorch**
+
+This script exports deep learning models from several model suites (TIMM,
+Torchvision, Hugging Face) to the OpenVINO backend using **ExecuTorch**. The
+model, input shape, and target device are specified on the command line.
+
+
+## **Usage**
+
+### **Command Structure**
+```bash
+python aot_openvino_compiler.py --suite <suite> --model <model_name> --input_shape <input_shape> --device <device>
+```
+
+### **Arguments**
+- **`--suite`** (required):
+  Specifies the model suite to use.
+  Supported values:
+  - `timm` (e.g., `vgg16`, `resnet50`)
+  - `torchvision` (e.g., `resnet18`, `mobilenet_v2`)
+  - `huggingface` (e.g., `bert-base-uncased`)
+
+- **`--model`** (required):
+  Name of the model to export.
+  Examples:
+  - For `timm`: `vgg16`, `resnet50`
+  - For `torchvision`: `resnet18`, `mobilenet_v2`
+  - For `huggingface`: `bert-base-uncased`, `distilbert-base-uncased`
+
+- **`--input_shape`** (required):
+  Input shape for the model. Provide this as a **list** or **tuple**.
+  Examples:
+  - `[1, 3, 224, 224]` (wrap in quotes so the shell does not split it)
+  - `(1, 3, 224, 224)`
+
+- **`--device`** (optional):
+  Target device for the compiled model. Default is `CPU`.
+  Examples: `CPU`, `GPU`
+
+## **Examples**
+
+### Export a TIMM VGG16 model for the CPU
+```bash
+python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape "[1, 3, 224, 224]" --device CPU
+```
+
+### Export a Torchvision ResNet50 model for the GPU
+```bash
+python aot_openvino_compiler.py --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device GPU
+```
+
+### Export a Hugging Face BERT model for the CPU
+```bash
+python aot_openvino_compiler.py --suite huggingface --model bert-base-uncased --input_shape "(1, 512)" --device CPU
+```
+
+## **Notes**
+1. **Input Shape Quoting**:
+   Because the shape contains spaces and brackets, wrap `--input_shape` in
+   quotes (required in Zsh, and safest in every shell):
+   ```bash
+   --input_shape '[1, 3, 224, 224]'
+   --input_shape "(1, 3, 224, 224)"
+   ```
+
+2. **Model Compatibility**:
+   Ensure the specified model name exists in the selected `suite`. Use the
+   corresponding library's documentation to verify model availability.
+
+3. **Output File**:
+   The exported model is saved as `<model_name>.pte` in the current directory.
+
+4. **Dependencies**:
+   - Python 3.8+
+   - PyTorch
+   - ExecuTorch
+   - TIMM (`pip install timm`)
+   - Torchvision
+   - Transformers (`pip install transformers`)
+
+## **Error Handling**
+- **Model Not Found**:
+  If the script raises an error such as:
+  ```bash
+  ValueError: Model not found
+  ```
+  verify that the model name is correct for the chosen suite.
+
+- **Unsupported Input Shape**:
+  Ensure `--input_shape` is provided as a valid list or tuple.
+
+
diff --git a/examples/openvino/aot/aot_openvino_compiler.py b/examples/openvino/aot/aot_openvino_compiler.py
new file mode 100644
index 00000000000..4674fbbd755
--- /dev/null
+++ b/examples/openvino/aot/aot_openvino_compiler.py
@@ -0,0 +1,80 @@
+# Copyright (c) Intel Corporation
+#
+# Licensed under the BSD License (the "License"); you may not use this file
+# except in compliance with the License. See the license file in the root
+# directory of this source tree for more details.
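+
+"""Export a timm, torchvision, or Hugging Face model to the OpenVINO backend.
+
+The model is captured with torch.export, converted to the ExecuTorch edge
+dialect, lowered through OpenvinoPartitioner, and serialized as
+``<model_name>.pte``. See examples/openvino/aot/README.md for CLI details.
+"""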
+
+import argparse
+
+import timm
+import torch
+import torchvision.models as torchvision_models
+from executorch.backends.openvino.partitioner import OpenvinoPartitioner
+from executorch.exir import EdgeProgramManager, ExecutorchBackendConfig, to_edge
+from executorch.exir.backend.backend_details import CompileSpec
+from torch.export import export
+from torch.export.exported_program import ExportedProgram
+from transformers import AutoModel
+
+
+# Load a model from the selected suite.
+def load_model(suite: str, model_name: str):
+    if suite == "timm":
+        return timm.create_model(model_name, pretrained=True)
+    elif suite == "torchvision":
+        if not hasattr(torchvision_models, model_name):
+            raise ValueError(f"Model {model_name} not found in torchvision.")
+        return getattr(torchvision_models, model_name)(pretrained=True)
+    elif suite == "huggingface":
+        return AutoModel.from_pretrained(model_name)
+    else:
+        raise ValueError(f"Unsupported model suite: {suite}")
+
+
+def main(suite: str, model_name: str, input_shape, device: str):
+    # Ensure input_shape is a tuple.
+    if isinstance(input_shape, list):
+        input_shape = tuple(input_shape)
+    elif not isinstance(input_shape, tuple):
+        raise ValueError("Input shape must be a list or tuple.")
+
+    # Load the selected model.
+    model = load_model(suite, model_name)
+    model = model.eval()
+
+    # Build an example input.
+    example_args = (torch.randn(*input_shape),)
+
+    # Export to the ATen dialect using torch.export.
+    aten_dialect: ExportedProgram = export(model, example_args)
+
+    # Convert to the edge dialect.
+    edge_program: EdgeProgramManager = to_edge(aten_dialect)
+
+    # Lower the module to the backend with a custom partitioner.
+    compile_spec = [CompileSpec("device", device.encode())]
+    lowered_module = edge_program.to_backend(OpenvinoPartitioner(compile_spec))
+
+    # Convert to the executable ExecuTorch program format.
+    exec_prog = lowered_module.to_executorch(config=ExecutorchBackendConfig())
+
+    # Serialize the program and save it to a file.
+    with open(f"{model_name}.pte", "wb") as file:
+        exec_prog.write_to_file(file)
+    print(f"Model exported and saved as {model_name}.pte on {device}.")
+
+
+if __name__ == "__main__":
+    # Argument parser for dynamic inputs.
+    parser = argparse.ArgumentParser(description="Export models with ExecuTorch.")
+    parser.add_argument("--suite", type=str, required=True, choices=["timm", "torchvision", "huggingface"],
+                        help="Select the model suite (timm, torchvision, huggingface).")
+    parser.add_argument("--model", type=str, required=True, help="Model name to be loaded.")
+    parser.add_argument("--input_shape", type=eval, required=True,
+                        help="Input shape for the model as a list or tuple (e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).")
+    parser.add_argument("--device", type=str, default="CPU",
+                        help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.")
+
+    args = parser.parse_args()
+
+    # Run the main function with parsed arguments.
+    main(args.suite, args.model, args.input_shape, args.device)
diff --git a/examples/openvino/executor_runner/openvino_executor_runner.cpp b/examples/openvino/executor_runner/openvino_executor_runner.cpp
new file mode 100644
index 00000000000..7615b63649a
--- /dev/null
+++ b/examples/openvino/executor_runner/openvino_executor_runner.cpp
@@ -0,0 +1,267 @@
+/*
+ * Copyright (c) Intel Corporation
+ *
+ * Licensed under the BSD License (the "License"); you may not use this file
+ * except in compliance with the License. See the license file in the root
+ * directory of this source tree for more details.
+ */
+
+#include <chrono>
+#include <cinttypes>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <gflags/gflags.h>
+
+#include <executorch/extension/data_loader/file_data_loader.h>
+#include <executorch/extension/runner_util/inputs.h>
+#include <executorch/runtime/executor/memory_manager.h>
+#include <executorch/runtime/executor/method.h>
+#include <executorch/runtime/executor/method_meta.h>
+#include <executorch/runtime/executor/program.h>
+#include <executorch/runtime/platform/log.h>
+#include <executorch/runtime/platform/runtime.h>
+
+// Define a fixed-size memory pool for the method allocator (4 MB).
+static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB
+
+// Command-line flags for the model path, the number of iterations, the input
+// list path, and the output folder path.
+DEFINE_string(
+    model_path,
+    "",
+    "Path to the model serialized in flatbuffer format (required).");
+DEFINE_int32(
+    num_iter,
+    1,
+    "Number of inference iterations (default is 1).");
+DEFINE_string(
+    input_list_path,
+    "",
+    "Path to the input list file which lists the raw input tensor files (optional).");
+DEFINE_string(
+    output_folder_path,
+    "",
+    "Path to the output folder where raw output tensor files are saved (optional).");
+
+using executorch::extension::FileDataLoader;
+using executorch::extension::prepare_input_tensors;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::HierarchicalAllocator;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::MemoryManager;
+using executorch::runtime::Method;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Program;
+using executorch::runtime::Result;
+using executorch::runtime::Span;
+using executorch::runtime::TensorInfo;
+
+int main(int argc, char** argv) {
+  // Initialize the runtime environment.
+  executorch::runtime::runtime_init();
+
+  // Parse command-line arguments and flags.
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  // Check that the model path is provided.
+  if (FLAGS_model_path.empty()) {
+    std::cerr << "Error: --model_path is required." << std::endl;
+    std::cerr << "Usage: " << argv[0]
+              << " --model_path=<model.pte> --num_iter=<n>" << std::endl;
+    return 1;
+  }
+
+  // Retrieve the model path and the number of iterations.
+  const char* model_path = FLAGS_model_path.c_str();
+  int num_iterations = FLAGS_num_iter;
+  std::cout << "Model path: " << model_path << std::endl;
+  std::cout << "Number of iterations: " << num_iterations << std::endl;
+
+  // Load the model using FileDataLoader.
+  Result<FileDataLoader> loader = FileDataLoader::from(model_path);
+  ET_CHECK_MSG(
+      loader.ok(),
+      "FileDataLoader::from() failed: 0x%" PRIx32,
+      static_cast<uint32_t>(loader.error()));
+
+  // Load the program from the loaded model.
+  Result<Program> program = Program::load(&loader.get());
+  if (!program.ok()) {
+    ET_LOG(Error, "Failed to parse model file %s", model_path);
+    return 1;
+  }
+  ET_LOG(Info, "Model file %s is loaded.", model_path);
+
+  // Retrieve the method name from the program (the first method is used).
+  const char* method_name = nullptr;
+  {
+    const auto method_name_result = program->get_method_name(0);
+    ET_CHECK_MSG(method_name_result.ok(), "Program has no methods");
+    method_name = *method_name_result;
+  }
+  ET_LOG(Info, "Using method %s", method_name);
+
+  // Retrieve metadata about the method.
+  Result<MethodMeta> method_meta = program->method_meta(method_name);
+  ET_CHECK_MSG(
+      method_meta.ok(),
+      "Failed to get method_meta for %s: 0x%" PRIx32,
+      method_name,
+      static_cast<uint32_t>(method_meta.error()));
+
+  // Set up a memory allocator for the method.
+  MemoryAllocator method_allocator{
+      MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)};
+
+  // Prepare planned buffers for memory planning.
+  std::vector<std::unique_ptr<uint8_t[]>> planned_buffers;
+  std::vector<Span<uint8_t>> planned_spans;
+  size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers();
+  for (size_t id = 0; id < num_memory_planned_buffers; ++id) {
+    size_t buffer_size =
+        static_cast<size_t>(method_meta->memory_planned_buffer_size(id).get());
+    ET_LOG(Info, "Setting up planned buffer %zu, size %zu.", id, buffer_size);
+    planned_buffers.push_back(std::make_unique<uint8_t[]>(buffer_size));
+    planned_spans.push_back({planned_buffers.back().get(), buffer_size});
+  }
+  HierarchicalAllocator planned_memory(
+      {planned_spans.data(), planned_spans.size()});
+
+  // Set up a memory manager using the method allocator and planned memory.
+  MemoryManager memory_manager(&method_allocator, &planned_memory);
+
+  // Load the method from the program.
+  Result<Method> method = program->load_method(method_name, &memory_manager);
+  ET_CHECK_MSG(
+      method.ok(),
+      "Loading of method %s failed with status 0x%" PRIx32,
+      method_name,
+      static_cast<uint32_t>(method.error()));
+  ET_LOG(Info, "Method loaded.");
+
+  // Prepare the input tensors for the method.
+  auto inputs = prepare_input_tensors(*method);
+  ET_CHECK_MSG(
+      inputs.ok(),
+      "Could not prepare inputs: 0x%" PRIx32,
+      static_cast<uint32_t>(inputs.error()));
+
+  // If an input list path is provided, read the input tensors from the files.
+  if (!(FLAGS_input_list_path.empty())) {
+    const char* input_list_path = FLAGS_input_list_path.c_str();
+    ET_LOG(Info, "Loading input tensors from the list provided in %s.", input_list_path);
+    Error status = Error::Ok;
+    std::vector<EValue> input_values(method->inputs_size());
+    ET_LOG(Info, "Number of inputs: %zu.", input_values.size());
+    status = method->get_inputs(input_values.data(), input_values.size());
+    ET_CHECK(status == Error::Ok);
+
+    // Split a string on the given delimiter.
+    auto split = [](std::string s, std::string delimiter) {
+      size_t pos_start = 0, pos_end, delim_len = delimiter.length();
+      std::string token;
+      std::vector<std::string> res;
+
+      while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
+        token = s.substr(pos_start, pos_end - pos_start);
+        pos_start = pos_end + delim_len;
+        res.push_back(token);
+      }
+      res.push_back(s.substr(pos_start));
+      return res;
+    };
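+
+    // Each line of the input list file is expected to contain space-separated
+    // raw tensor file names, one file per model input, e.g. (hypothetical
+    // file names):
+    //   input_0.raw input_1.raw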
+    // Read the raw input tensor file names from the input list file and
+    // load the values from each file into the corresponding input tensor.
+    std::ifstream input_list(input_list_path);
+    if (input_list.is_open()) {
+      size_t num_inputs = method->inputs_size();
+      std::string file_path;
+      while (std::getline(input_list, file_path)) {
+        auto input_files = split(file_path, " ");
+        if (input_files.size() == 0) {
+          break;
+        }
+        for (size_t input_index = 0; input_index < num_inputs; ++input_index) {
+          MethodMeta method_meta = method->method_meta();
+          Result<TensorInfo> tensor_meta =
+              method_meta.input_tensor_meta(input_index);
+          auto input_data_ptr = input_values[input_index].toTensor().data_ptr();
+
+          std::ifstream fin(input_files[input_index], std::ios::binary);
+          fin.seekg(0, fin.end);
+          size_t file_size = fin.tellg();
+
+          ET_CHECK_MSG(
+              file_size == tensor_meta->nbytes(),
+              "Input(%zu) size mismatch. file bytes: %zu, tensor bytes: %zu",
+              input_index,
+              file_size,
+              tensor_meta->nbytes());
+
+          fin.seekg(0, fin.beg);
+          fin.read(static_cast<char*>(input_data_ptr), file_size);
+          fin.close();
+        }
+      }
+    } else {
+      ET_CHECK_MSG(false,
+          "Failed to read input list file: %s",
+          input_list_path);
+    }
+  }
+  ET_LOG(Info, "Inputs prepared.");
+
+  // Measure the execution time of inference.
+  auto before_exec = std::chrono::high_resolution_clock::now();
+  Error status = Error::Ok;
+  for (int i = 0; i < num_iterations; ++i) {
+    status = method->execute();
+  }
+  auto after_exec = std::chrono::high_resolution_clock::now();
+  double elapsed_time =
+      std::chrono::duration_cast<std::chrono::microseconds>(
+          after_exec - before_exec)
+          .count() /
+      1000.0;
+
+  // Log the total execution time and the average time per iteration.
+  ET_LOG(
+      Info,
+      "%d inference iterations took %f ms, avg %f ms",
+      num_iterations,
+      elapsed_time,
+      elapsed_time / static_cast<double>(num_iterations));
+  ET_CHECK_MSG(
+      status == Error::Ok,
+      "Execution of method %s failed with status 0x%" PRIx32,
+      method_name,
+      static_cast<uint32_t>(status));
+  ET_LOG(Info, "Model executed successfully.");
+
+  // Retrieve the method outputs.
+  std::vector<EValue> outputs(method->outputs_size());
+  ET_LOG(Info, "Number of outputs: %zu.", outputs.size());
+  status = method->get_outputs(outputs.data(), outputs.size());
+  ET_CHECK(status == Error::Ok);
+
+  // If an output folder path is provided, save the output tensors as raw
+  // tensor files.
+  if (!(FLAGS_output_folder_path.empty())) {
+    const char* output_folder_path = FLAGS_output_folder_path.c_str();
+    ET_LOG(Info, "Saving output tensors into the output folder: %s.", output_folder_path);
+    for (size_t output_index = 0; output_index < method->outputs_size();
+         output_index++) {
+      auto output_tensor = outputs[output_index].toTensor();
+      auto output_file_name = std::string(output_folder_path) + "/output_" +
+          std::to_string(output_index) + ".raw";
+      std::ofstream fout(output_file_name.c_str(), std::ios::binary);
+      fout.write(output_tensor.const_data_ptr<char>(), output_tensor.nbytes());
+      fout.close();
+    }
+  }
+
+  return 0;
+}
+
diff --git a/examples/openvino/openvino_build.sh b/examples/openvino/openvino_build.sh
new file mode 100755
index 00000000000..0d2703e5646
--- /dev/null
+++ b/examples/openvino/openvino_build.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+# Exit immediately if a command exits with a non-zero status.
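+#
+# Builds the ExecuTorch core libraries with the OpenVINO backend enabled, then
+# builds the example runner in examples/openvino.
+#
+# Usage (can be run from any directory; paths are resolved relative to this
+# script):
+#   ./openvino_build.sh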
+set -e
+
+# Resolve the executorch root directory relative to this script.
+EXECUTORCH_ROOT=$(realpath "$(dirname "$0")/../..")
+echo "EXECUTORCH_ROOT=${EXECUTORCH_ROOT}"
+
+main() {
+  # Set the build directory.
+  local build_dir="cmake-openvino-out"
+
+  # Start from a clean build directory under the executorch root.
+  cd "$EXECUTORCH_ROOT"
+  rm -rf "${build_dir}"
+
+  # Configure the project with CMake.
+  # Note: Add any additional configuration options you need here.
+  cmake -DCMAKE_INSTALL_PREFIX="${build_dir}" \
+    -DEXECUTORCH_BUILD_OPENVINO=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+    -B"${build_dir}"
+
+  # Build and install the core libraries.
+  cmake --build "${build_dir}" --target install --config Release -j5
+
+  ## Build the example.
+  local example_dir=examples/openvino
+  local example_build_dir="${build_dir}/${example_dir}"
+  local cmake_prefix_path="${PWD}/${build_dir}/lib/cmake/ExecuTorch;${PWD}/${build_dir}/third-party/gflags;"
+  rm -rf "${example_build_dir}"
+
+  # Configure the example against the freshly installed libraries.
+  cmake -DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \
+    -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
+    -B"${example_build_dir}" \
+    "$EXECUTORCH_ROOT/$example_dir"
+
+  cmake --build "${example_build_dir}" -j5
+
+  # Switch back to the original directory.
+  cd - > /dev/null
+
+  # Print a success message.
+  echo "Build completed successfully."
+}
+
+main "$@"
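+
+# Example end-to-end flow (hypothetical model and paths), assuming OpenVINO
+# has been set up via setupvars.sh as described in examples/openvino/ReadMe.md:
+#   python aot/aot_openvino_compiler.py --suite torchvision --model resnet50 \
+#     --input_shape "(1, 3, 224, 224)" --device CPU
+#   ./openvino_build.sh
+#   <executorch_root>/cmake-openvino-out/examples/openvino/openvino_executor_runner \
+#     --model_path=resnet50.pte --num_iter=10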