58 changes: 58 additions & 0 deletions backends/openvino/CMakeLists.txt
@@ -0,0 +1,58 @@
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

include(${EXECUTORCH_ROOT}/build/Utils.cmake)

# Common include directory: the parent of the executorch source root.
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
include_directories(BEFORE ${_common_include_directories})

# Locate the OpenVINO installation from the environment.
set(OPENVINO_DIR "$ENV{INTEL_OPENVINO_DIR}")
set(OPENVINO_INCLUDE_DIRS
    ${OPENVINO_DIR}/deployment_tools/inference_engine/include
    ${OPENVINO_DIR}/runtime/include
)

# Add the OpenVINO backend library
add_library(openvino_backend SHARED)
target_compile_options(openvino_backend PRIVATE "-frtti" "-fexceptions")

# Include directories for ExecuTorch and OpenVINO.
target_include_directories(
  openvino_backend PUBLIC ${_common_include_directories} ${OPENVINO_INCLUDE_DIRS}
)

set(OPENVINO_LIB_PATH ${OPENVINO_DIR}/runtime/lib/intel64)
set(OPENVINO_LIBS
    ${OPENVINO_LIB_PATH}/libopenvino.so
    ${OPENVINO_LIB_PATH}/libopenvino_ir_frontend.so.2450
    ${OPENVINO_LIB_PATH}/libopenvino_c.so
    ${OPENVINO_LIB_PATH}/libopenvino_intel_cpu_plugin.so
    ${OPENVINO_LIB_PATH}/libopenvino_intel_gpu_plugin.so
    ${OPENVINO_LIB_PATH}/libopenvino_auto_plugin.so
)

# Link the OpenVINO library to the backend
target_link_libraries(openvino_backend PRIVATE ${OPENVINO_LIBS} executorch_core)

target_sources(
openvino_backend
PRIVATE ${CMAKE_CURRENT_LIST_DIR}/runtime/OpenvinoBackend.cpp
)

target_link_options_shared_lib(openvino_backend)
install(TARGETS openvino_backend DESTINATION lib)



4 changes: 4 additions & 0 deletions backends/openvino/__init__.py
@@ -0,0 +1,4 @@
from .partitioner import OpenvinoPartitioner
from .preprocess import OpenvinoBackend

__all__ = ["OpenvinoBackend", "OpenvinoPartitioner"]
112 changes: 112 additions & 0 deletions backends/openvino/partitioner.py
@@ -0,0 +1,112 @@
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.openvino.preprocess import OpenvinoBackend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from openvino.frontend.pytorch.torchdynamo.op_support import OperatorSupport
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class OpenvinoOperatorsSupport(OperatorSupportBase):

    def __init__(
        self,
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        if op_types_to_skip is None:
            op_types_to_skip = set()
        if op_names_to_skip is None:
            op_names_to_skip = set()

        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        if node.op != "call_function":
            return False

        op_type = node.target.__name__

        # Honor explicit skip lists before consulting OpenVINO's support table,
        # so that supported-but-skipped ops are actually excluded.
        if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
            print(
                f"[OpenVINO Backend] The {op_type} operator with name '{node.name}' is skipped."
            )
            return False

        if op_type == "getitem":
            return True

        supported_ops = OperatorSupport([])._support_dict
        if "torch.ops." + str(op_type) in supported_ops:
            return True

        print("Op not supported: ", "torch.ops." + str(op_type))
        return False


@final
class OpenvinoPartitioner(Partitioner):

    def __init__(
        self,
        compile_spec: List[CompileSpec],
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        self.delegation_spec = DelegationSpec(OpenvinoBackend.__name__, compile_spec)
        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        ops_not_decompose = [
            torch.ops.aten.pixel_shuffle.default,
            torch.ops.aten.upsample_bilinear2d.default,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.upsample_nearest2d.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_not_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            OpenvinoOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
            allows_single_node_partition=True,
        )
        partition_list = partitioner.propose_partitions()

        # Tag every node in each proposed partition so the delegation pass can
        # group them into OpenVINO subgraphs.
        partition_tags = {}
        for partition in partition_list:
            for node in partition.nodes:
                tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
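
For reviewers who want to try the partitioner end to end, here is a minimal sketch of the intended lowering flow. It assumes the standard ExecuTorch export pipeline (`torch.export.export`, `to_edge`, `to_backend`, `to_executorch`); the `CompileSpec` key/value pair is an illustrative placeholder, since this backend does not yet read any compile-spec keys.

```python
import torch

from executorch.backends.openvino import OpenvinoPartitioner
from executorch.exir import to_edge
from executorch.exir.backend.compile_spec_schema import CompileSpec

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 4),)

# Export to an ExportedProgram, convert to the Edge dialect, then delegate
# supported subgraphs to the OpenVINO backend.
exported = torch.export.export(model, example_inputs)
edge = to_edge(exported)
edge = edge.to_backend(
    OpenvinoPartitioner([CompileSpec("device", b"CPU")])  # key is illustrative
)

# Serialize the lowered program to a .pte file for the ExecuTorch runtime.
with open("model.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)
```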
45 changes: 45 additions & 0 deletions backends/openvino/preprocess.py
@@ -0,0 +1,45 @@
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import final, List

from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from openvino.frontend.pytorch.torchdynamo.compile import openvino_compile


@final
class OpenvinoBackend(BackendDetails):

    @classmethod
    def preprocess(
        cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
    ) -> PreprocessResult:
        # Collect example inputs (the FakeTensor metadata) for the user inputs,
        # in graph order, to drive the OpenVINO compilation.
        input_names = edge_program.graph_signature.user_inputs
        args = []
        for node in edge_program.graph.nodes:
            if node.target in input_names:
                args.append(node.meta["val"])

        # Compile with OpenVINO and serialize the compiled blob; these bytes
        # are consumed by OpenvinoBackend::init at runtime.
        compiled = openvino_compile(edge_program.module(), *args)
        model_bytes = compiled.export_model()

        return PreprocessResult(processed_bytes=model_bytes)
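
For context on what `preprocess()` hands to the runtime: it compiles the delegated subgraph with `openvino_compile` and serializes the compiled blob, which `OpenvinoBackend::init` later imports via `ov::Core::import_model`. Below is a rough standalone sketch of the ahead-of-time half, assuming `openvino_compile` accepts a module plus example inputs as it is called above.

```python
import torch

from openvino.frontend.pytorch.torchdynamo.compile import openvino_compile

model = torch.nn.Linear(4, 4).eval()
example = torch.randn(1, 4)

# Compile ahead of time, as preprocess() does for a delegated subgraph, then
# serialize the compiled model to bytes.
compiled = openvino_compile(torch.export.export(model, (example,)).module(), example)
blob = compiled.export_model()  # consumed by OpenvinoBackend::init at runtime
```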
8 changes: 8 additions & 0 deletions backends/openvino/requirements.txt
@@ -0,0 +1,8 @@
datasets
huggingface-hub
safetensors
sentencepiece
tokenizers
transformers
piq
pillow
139 changes: 139 additions & 0 deletions backends/openvino/runtime/OpenvinoBackend.cpp
@@ -0,0 +1,139 @@
#include <cstring>
#include <memory>
#include <new>
#include <sstream>
#include <stdexcept>

#include <openvino/openvino.hpp>

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

using executorch::aten::ScalarType;
using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;

namespace executorch {
namespace backends {
namespace openvino {

struct ExecutionHandle {
  std::shared_ptr<ov::CompiledModel> compiled_model;
  std::shared_ptr<ov::InferRequest> infer_request;
};

class OpenvinoBackend final : public ::executorch::runtime::BackendInterface {
 public:
  OpenvinoBackend() = default;
  ~OpenvinoBackend() = default;

  bool is_available() const override {
    // The backend is only registered when OpenVINO is linked in, so report
    // availability unconditionally.
    return true;
  }

  Result<DelegateHandle*> init(
      BackendInitContext& context,
      FreeableBuffer* processed,
      ArrayRef<CompileSpec> compile_specs) const override {
    ET_LOG(Info, "OpenvinoBackend::init %p", processed->data());

    ov::Core core;

    const char* data_ptr = static_cast<const char*>(processed->data());
    size_t data_size = processed->size();

    // Copy the compiled blob into a string and wrap it in a stream so that
    // ov::Core::import_model can consume it.
    std::string data_string(data_ptr, data_size);
    std::istringstream compiled_stream(data_string);

    // NOTE: compile_specs are currently unused; the device is fixed to CPU.
    auto compiled_model = core.import_model(compiled_stream, "CPU");

    // Allocate an infer request for the compiled model.
    auto infer_request = std::make_shared<ov::InferRequest>(
        compiled_model.create_infer_request());

    // Allocate the execution handle from the runtime allocator. The macro
    // returns raw memory, so construct the handle in place before assigning
    // to its shared_ptr members.
    MemoryAllocator* allocator = context.get_runtime_allocator();
    ExecutionHandle* handle =
        ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
    new (handle) ExecutionHandle;
    handle->compiled_model = std::make_shared<ov::CompiledModel>(compiled_model);
    handle->infer_request = infer_request;

    return handle;
  }

  Error execute(
      BackendExecutionContext& context,
      DelegateHandle* input_handle,
      EValue** args) const override {
    ExecutionHandle* execution_handle =
        static_cast<ExecutionHandle*>(input_handle);

    auto infer_request = execution_handle->infer_request;

    // Assume the first argument is the (single) input tensor.
    auto input_tensor = args[0]->toTensor();
    ov::Shape input_shape(
        input_tensor.sizes().begin(), input_tensor.sizes().end());

    // Wrap the input in an OpenVINO tensor without copying the data.
    ov::element::Type ov_type =
        convert_to_openvino_type(input_tensor.scalar_type());
    ov::Tensor ov_input_tensor(
        ov_type, input_shape, input_tensor.mutable_data_ptr());

    infer_request->set_input_tensor(0, ov_input_tensor);

    // Execute the inference.
    infer_request->infer();

    // Assume the second argument is the (single) output tensor; copy the
    // result back into it.
    auto output_tensor = args[1]->toTensor();
    ov::Tensor ov_output_tensor = infer_request->get_output_tensor(0);
    std::memcpy(
        output_tensor.mutable_data_ptr(),
        ov_output_tensor.data(),
        ov_output_tensor.get_byte_size());

    return Error::Ok;
  }

  void destroy(DelegateHandle* handle) const override {
    if (handle != nullptr) {
      // The handle was constructed in place in allocator-owned memory, so run
      // the destructor explicitly to release the OpenVINO objects.
      auto* execution_handle = static_cast<ExecutionHandle*>(handle);
      execution_handle->~ExecutionHandle();
    }
  }

 private:
  ov::element::Type convert_to_openvino_type(ScalarType scalar_type) const {
    // Convert ExecuTorch scalar types to OpenVINO element types.
    switch (scalar_type) {
      case ScalarType::Float:
        return ov::element::f32;
      case ScalarType::Int:
        return ov::element::i32;
      case ScalarType::Char:
        return ov::element::i8;
      default:
        throw std::runtime_error("Unsupported scalar type");
    }
  }
};

} // namespace openvino
} // namespace backends
} // namespace executorch

namespace {
auto backend = executorch::backends::openvino::OpenvinoBackend();
executorch::runtime::Backend backend_id{"OpenvinoBackend", &backend};
static auto registered = executorch::runtime::register_backend(backend_id);
} // namespace

