88 changes: 88 additions & 0 deletions examples/openvino/aot/README.md
@@ -0,0 +1,88 @@
# **Model Export Script for ExecuTorch**

This script exports deep learning models from several model suites (TIMM, Torchvision, Hugging Face) to the OpenVINO backend using **ExecuTorch**. The model, input shape, and target device are specified on the command line.


## **Usage**

### **Command Structure**
```bash
python aot_openvino_compiler.py --suite <MODEL_SUITE> --model <MODEL_NAME> --input_shape <INPUT_SHAPE> --device <DEVICE>
```

### **Arguments**
- **`--suite`** (required):
  Specifies the model suite to use.
  Supported values:
  - `timm` (e.g., `vgg16`, `resnet50`)
  - `torchvision` (e.g., `resnet18`, `mobilenet_v2`)
  - `huggingface` (e.g., `bert-base-uncased`)

- **`--model`** (required):
  Name of the model to export.
  Examples:
  - For `timm`: `vgg16`, `resnet50`
  - For `torchvision`: `resnet18`, `mobilenet_v2`
  - For `huggingface`: `bert-base-uncased`, `distilbert-base-uncased`

- **`--input_shape`** (required):
  Input shape for the model, written as a Python **list** or **tuple** literal. Quote the value so your shell passes it as a single argument.
  Examples:
  - `"[1, 3, 224, 224]"`
  - `"(1, 3, 224, 224)"`

- **`--device`** (optional):
  Target device for the compiled model. Default is `CPU`.
  Examples: `CPU`, `GPU`

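Internally, the `--input_shape` string is parsed into a Python object. A minimal sketch of a safe parser using `ast.literal_eval` (which accepts list/tuple literals but never executes code) with basic validation — the helper name `parse_shape` is illustrative and not part of the script:

```python
import ast

def parse_shape(text: str) -> tuple:
    """Parse a shape literal like "[1, 3, 224, 224]" into a tuple of positive ints."""
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"Not a valid shape literal: {text!r}") from exc
    if not isinstance(value, (list, tuple)) or not all(
        isinstance(d, int) and d > 0 for d in value
    ):
        raise ValueError("Input shape must be a list or tuple of positive ints.")
    return tuple(value)

print(parse_shape("[1, 3, 224, 224]"))  # → (1, 3, 224, 224)
```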
## **Examples**

### Export a TIMM VGG16 model for the CPU
```bash
python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape "[1, 3, 224, 224]" --device CPU
```

### Export a Torchvision ResNet50 model for the GPU
```bash
python aot_openvino_compiler.py --suite torchvision --model resnet50 --input_shape "(1, 3, 256, 256)" --device GPU
```

### Export a Hugging Face BERT model for the CPU
```bash
python aot_openvino_compiler.py --suite huggingface --model bert-base-uncased --input_shape "(1, 512)" --device CPU
```

## **Notes**
1. **Quoting the Input Shape**:
   The shape literal contains spaces, so it must be quoted in any shell; in Zsh the square brackets are additionally subject to globbing, making the quotes doubly necessary:
   ```bash
   --input_shape '[1, 3, 224, 224]'
   --input_shape "(1, 3, 224, 224)"
   ```

2. **Model Compatibility**:
   Ensure the name passed via `--model` exists in the selected `--suite`. Use the corresponding library's documentation to verify model availability.

3. **Output File**:
The exported model will be saved as `<MODEL_NAME>.pte` in the current directory.
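   Note that Hugging Face model ids may contain a slash (e.g. `org/model`), which the output path would treat as a directory. A small sketch of sanitizing the name before writing the `.pte` file — the `safe_filename` helper is hypothetical, not in the script:

   ```python
   def safe_filename(model_name: str) -> str:
       """Replace path separators so the .pte file lands in the current directory."""
       return model_name.replace("/", "_").replace("\\", "_") + ".pte"

   print(safe_filename("distilbert-base-uncased"))  # → distilbert-base-uncased.pte
   print(safe_filename("org/model"))                # → org_model.pte
   ```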

4. **Dependencies**:
- Python 3.8+
- PyTorch
- Executorch
- TIMM (`pip install timm`)
- Torchvision
- Transformers (`pip install transformers`)
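   A quick, stdlib-only sketch for checking that these dependencies are importable before running the export (the package list simply mirrors the one above):

   ```python
   import importlib.util

   def missing_packages(packages):
       """Return the subset of packages that cannot be imported."""
       return [p for p in packages if importlib.util.find_spec(p) is None]

   required = ["torch", "executorch", "timm", "torchvision", "transformers"]
   print(missing_packages(required))  # prints [] when everything is installed
   ```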

## **Error Handling**
- **Model Not Found**:
If the script raises an error such as:
```bash
ValueError: Model <MODEL_NAME> not found
```
Verify that the model name is correct for the chosen suite.

- **Unsupported Input Shape**:
Ensure `--input_shape` is provided as a valid list or tuple.


74 changes: 74 additions & 0 deletions examples/openvino/aot/aot_openvino_compiler.py
@@ -0,0 +1,74 @@
import argparse
import ast

import executorch
import timm
import torch
import torchvision.models as torchvision_models
from executorch.backends.openvino.partitioner import OpenvinoPartitioner
from executorch.backends.openvino.preprocess import OpenvinoBackend
from executorch.exir import EdgeProgramManager, to_edge
from executorch.exir.backend.backend_details import CompileSpec
from torch.export import export, ExportedProgram
from transformers import AutoModel


# Load a model from the selected suite
def load_model(suite: str, model_name: str):
    if suite == "timm":
        return timm.create_model(model_name, pretrained=True)
    elif suite == "torchvision":
        if not hasattr(torchvision_models, model_name):
            raise ValueError(f"Model {model_name} not found in torchvision.")
        return getattr(torchvision_models, model_name)(pretrained=True)
    elif suite == "huggingface":
        return AutoModel.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported model suite: {suite}")


def main(suite: str, model_name: str, input_shape, device: str):
    # Ensure input_shape is a tuple
    if isinstance(input_shape, list):
        input_shape = tuple(input_shape)
    elif not isinstance(input_shape, tuple):
        raise ValueError("Input shape must be a list or tuple.")

    # Load the selected model and switch to inference mode
    model = load_model(suite, model_name)
    model = model.eval()

    # Build an example input matching the requested shape
    example_args = (torch.randn(*input_shape),)

    # Export to ATen dialect using torch.export
    aten_dialect: ExportedProgram = export(model, example_args)

    # Convert to edge dialect
    edge_program: EdgeProgramManager = to_edge(aten_dialect)

    # Lower the program to the backend with a custom partitioner
    compile_spec = [CompileSpec("device", device.encode())]
    lowered_module = edge_program.to_backend(OpenvinoPartitioner(compile_spec))

    # Apply backend-specific passes and produce the final program
    exec_prog = lowered_module.to_executorch(config=executorch.exir.ExecutorchBackendConfig())

    # Serialize the program to a .pte file
    with open(f"{model_name}.pte", "wb") as file:
        exec_prog.write_to_file(file)
    print(f"Model exported and saved as {model_name}.pte on {device}.")


if __name__ == "__main__":
    # Argument parser for dynamic inputs
    parser = argparse.ArgumentParser(description="Export models with ExecuTorch.")
    parser.add_argument("--suite", type=str, required=True,
                        choices=["timm", "torchvision", "huggingface"],
                        help="Select the model suite (timm, torchvision, huggingface).")
    parser.add_argument("--model", type=str, required=True, help="Model name to be loaded.")
    # ast.literal_eval parses the shape literal without executing arbitrary code (unlike eval)
    parser.add_argument("--input_shape", type=ast.literal_eval, required=True,
                        help="Input shape for the model as a list or tuple "
                             "(e.g., [1, 3, 224, 224] or (1, 3, 224, 224)).")
    parser.add_argument("--device", type=str, default="CPU",
                        help="Target device for compiling the model (e.g., CPU, GPU). Default is CPU.")

    args = parser.parse_args()

    # Run the main function with parsed arguments
    main(args.suite, args.model, args.input_shape, args.device)