Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions backends/apple/mps/runtime/MPSCompiler.mm
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ void printLoadedGraph(MPSGraphExecutable* executable) {
size_t num_bytes) {
ExirMPSGraphPackage* exirMPSGraphPackage = (ExirMPSGraphPackage*)buffer_pointer;
NSData *new_manifest_plist_data = [NSData dataWithBytes:exirMPSGraphPackage->data length:exirMPSGraphPackage->model_0_offset];
NSData *new_model_0_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset length:exirMPSGraphPackage->model_1_offset - exirMPSGraphPackage->model_0_offset];
NSData *new_model_1_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_1_offset length:exirMPSGraphPackage->total_bytes - sizeof(ExirMPSGraphPackage) - exirMPSGraphPackage->model_1_offset];
NSData *new_model_0_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset length:exirMPSGraphPackage->total_bytes - sizeof(ExirMPSGraphPackage) - exirMPSGraphPackage->model_0_offset];

NSError* error = nil;
NSString* packageName = [NSString stringWithUTF8String:(
Expand All @@ -52,14 +51,12 @@ void printLoadedGraph(MPSGraphExecutable* executable) {

NSString* manifestFileStr = [NSString stringWithFormat:@"%@/manifest.plist", dataFileNSStr];
NSString* model0FileStr = [NSString stringWithFormat:@"%@/model_0.mpsgraph", dataFileNSStr];
NSString* model1FileStr = [NSString stringWithFormat:@"%@/model_1.mpsgraph", dataFileNSStr];

NSFileManager *fileManager= [NSFileManager defaultManager];
[fileManager createDirectoryAtPath:dataFileNSStr withIntermediateDirectories:NO attributes:nil error:&error];

[new_manifest_plist_data writeToFile:manifestFileStr options:NSDataWritingAtomic error:&error];
[new_model_0_data writeToFile:model0FileStr options:NSDataWritingAtomic error:&error];
[new_model_1_data writeToFile:model1FileStr options:NSDataWritingAtomic error:&error];

NSURL *bundleURL = [NSURL fileURLWithPath:dataFileNSStr];
MPSGraphCompilationDescriptor *compilationDescriptor = [MPSGraphCompilationDescriptor new];
Expand Down
3 changes: 0 additions & 3 deletions backends/apple/mps/utils/MPSGraphInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ class MPSGraphModule {
std::vector<MPSGraphTensor*> outputTensors_;
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphExecutable* executable_;

id<MTLDevice> device_;
id<MTLCommandQueue> commandQueue_;
};

} // namespace mps
12 changes: 2 additions & 10 deletions backends/apple/mps/utils/MPSGraphInterface.mm
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
"MPS Executorch backend is supported only from macOS 14.0 and above.");

mpsGraph = [MPSGraph new];
device_ = MTLCreateSystemDefaultDevice();
commandQueue_ = [device_ newCommandQueue];
}

MPSGraphModule::~MPSGraphModule() {
Expand Down Expand Up @@ -95,7 +93,7 @@
[targetTensors addObject:outputTensor];
});

MPSGraphExecutable *exec = [mpsGraph compileWithDevice:[MPSGraphDevice deviceWithMTLDevice:device_]
MPSGraphExecutable *exec = [mpsGraph compileWithDevice:nil
feeds:feeds
targetTensors:targetTensors
targetOperations:nil
Expand All @@ -111,7 +109,6 @@

std::string name = "mpsgraphmodule_" + std::to_string(arc4random_uniform(INT_MAX));
std::string mpsgraphpackagePath = dataFolder + name + ".mpsgraphpackage";

NSString *mpsgraphpackageFileStr = [NSString stringWithUTF8String:mpsgraphpackagePath.c_str()];
NSURL *bundleURL = [NSURL fileURLWithPath:mpsgraphpackageFileStr];

Expand All @@ -122,28 +119,23 @@

NSString* mpsgraphpackage_manifest_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/manifest.plist").c_str()];
NSString* mpsgraphpackage_model_0_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/model_0.mpsgraph").c_str()];
NSString* mpsgraphpackage_model_1_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/model_1.mpsgraph").c_str()];

NSURL* manifestPlistURL = [NSURL fileURLWithPath:mpsgraphpackage_manifest_file];
NSURL* model0URL = [NSURL fileURLWithPath:mpsgraphpackage_model_0_file];
NSURL* model1URL = [NSURL fileURLWithPath:mpsgraphpackage_model_1_file];

NSData* manifest_plist_data = [NSData dataWithContentsOfURL:manifestPlistURL];
NSData* model_0_data = [NSData dataWithContentsOfURL:model0URL];
NSData* model_1_data = [NSData dataWithContentsOfURL:model1URL];

int64_t total_package_size = sizeof(ExirMPSGraphPackage) + [manifest_plist_data length] + [model_0_data length] + [model_1_data length];
int64_t total_package_size = sizeof(ExirMPSGraphPackage) + [manifest_plist_data length] + [model_0_data length];
ExirMPSGraphPackage *exirMPSGraphPackage = (ExirMPSGraphPackage*)malloc(total_package_size);
assert(exirMPSGraphPackage != nil);

exirMPSGraphPackage->manifest_plist_offset = 0;
exirMPSGraphPackage->model_0_offset = [manifest_plist_data length];
exirMPSGraphPackage->model_1_offset = exirMPSGraphPackage->model_0_offset + [model_0_data length];
exirMPSGraphPackage->total_bytes = total_package_size;

memcpy(exirMPSGraphPackage->data, [manifest_plist_data bytes], [manifest_plist_data length]);
memcpy(exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset, [model_0_data bytes], [model_0_data length]);
memcpy(exirMPSGraphPackage->data + exirMPSGraphPackage->model_1_offset, [model_1_data bytes], [model_1_data length]);

std::vector<uint8_t> data((uint8_t*)exirMPSGraphPackage, (uint8_t*)exirMPSGraphPackage + total_package_size);
free(exirMPSGraphPackage);
Expand Down
1 change: 0 additions & 1 deletion backends/apple/mps/utils/MPSGraphPackageExport.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
struct ExirMPSGraphPackage {
int64_t manifest_plist_offset;
int64_t model_0_offset;
int64_t model_1_offset;
int64_t total_bytes;
uint8_t data[];
};