Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
@Test
@Throws(IOException::class, URISyntaxException::class)
fun testGenerate() {
val loadResult = llmModule.load()
// Check that the model can be load successfully
assertEquals(OK.toLong(), loadResult.toLong())
llmModule.load()

llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
Expand Down Expand Up @@ -277,7 +275,6 @@ class LlmModuleInstrumentationTest : LlmCallback {
private const val TEST_FILE_NAME = "/stories.pte"
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
private const val TEST_PROMPT = "Hello"
private const val OK = 0x00
private const val SEQ_LEN = 32
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ class ModuleInstrumentationTest {
fun testModuleLoadMethodAndForward() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
module.loadMethod(FORWARD_METHOD)

val results = module.forward()
Assert.assertTrue(results[0].isTensor)
Expand Down Expand Up @@ -96,17 +95,22 @@ class ModuleInstrumentationTest {
fun testModuleLoadMethodNonExistantMethod() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val loadMethod = module.loadMethod(NONE_METHOD)
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
val exception =
Assert.assertThrows(ExecutorchRuntimeException::class.java) {
module.loadMethod(NONE_METHOD)
}
Assert.assertEquals(
ExecutorchRuntimeException.INVALID_ARGUMENT,
exception.getErrorCode(),
)
}

@Test(expected = RuntimeException::class)
@Throws(IOException::class)
fun testNonPteFile() {
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
module.loadMethod(FORWARD_METHOD)
}

@Test
Expand All @@ -116,22 +120,19 @@ class ModuleInstrumentationTest {

module.destroy()

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong())
Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) }
}

@Test
@Throws(IOException::class)
fun testForwardOnDestroyedModule() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

val loadMethod = module.loadMethod(FORWARD_METHOD)
Assert.assertEquals(loadMethod.toLong(), OK.toLong())
module.loadMethod(FORWARD_METHOD)
Comment thread
psiddh marked this conversation as resolved.

module.destroy()

val results = module.forward()
Assert.assertEquals(0, results.size.toLong())
Assert.assertThrows(IllegalStateException::class.java) { module.forward() }
}

@Ignore(
Expand Down Expand Up @@ -175,9 +176,5 @@ class ModuleInstrumentationTest {
private const val NON_PTE_FILE_NAME = "/test.txt"
private const val FORWARD_METHOD = "forward"
private const val NONE_METHOD = "none"
private const val OK = 0x00
private const val INVALID_STATE = 0x2
private const val INVALID_ARGUMENT = 0x12
private const val ACCESS_FAILED = 0x22
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,83 @@
import java.util.HashMap;
import java.util.Map;

/**
* Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code
* corresponding to the native {@code runtime/core/error.h} values, accessible via {@link
* #getErrorCode()}.
*/
public class ExecutorchRuntimeException extends RuntimeException {
// Error code constants - keep in sync with runtime/core/error.h

// System errors

/** Operation completed successfully. */
public static final int OK = 0x00;

/** An unexpected internal error occurred in the runtime. */
public static final int INTERNAL = 0x01;

/** The runtime or method is in an invalid state for the requested operation. */
public static final int INVALID_STATE = 0x02;

/** The method has finished execution and has no more work to do. */
public static final int END_OF_METHOD = 0x03;

/** A required resource has already been loaded. */
public static final int ALREADY_LOADED = 0x04;

// Logical errors

/** The requested operation is not supported by this build or backend. */
public static final int NOT_SUPPORTED = 0x10;

/** The requested operation has not been implemented. */
public static final int NOT_IMPLEMENTED = 0x11;

/** One or more arguments passed to the operation are invalid. */
public static final int INVALID_ARGUMENT = 0x12;

/** A value or tensor has an unexpected type. */
public static final int INVALID_TYPE = 0x13;

/** A required operator kernel is not registered. */
public static final int OPERATOR_MISSING = 0x14;

/** The maximum number of registered kernels has been exceeded. */
public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15;

/** A kernel with the same name is already registered. */
public static final int REGISTRATION_ALREADY_REGISTERED = 0x16;

// Resource errors

/** A required resource (file, tensor, program) was not found. */
public static final int NOT_FOUND = 0x20;

/** A memory allocation failed. */
public static final int MEMORY_ALLOCATION_FAILED = 0x21;

/** Access to a resource was denied or failed. */
public static final int ACCESS_FAILED = 0x22;

/** The loaded program is malformed or incompatible. */
public static final int INVALID_PROGRAM = 0x23;

/** External data referenced by the program is invalid or missing. */
public static final int INVALID_EXTERNAL_DATA = 0x24;

/** The system has run out of a required resource. */
public static final int OUT_OF_RESOURCES = 0x25;

// Delegate errors

/** A delegate reported an incompatible model or configuration. */
public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30;

/** A delegate failed to allocate required memory. */
public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31;

/** A delegate received an invalid or stale handle. */
public static final int DELEGATE_INVALID_HANDLE = 0x32;

private static final Map<Integer, String> ERROR_CODE_MESSAGES;
Expand All @@ -52,6 +101,7 @@ public class ExecutorchRuntimeException extends RuntimeException {
map.put(INTERNAL, "Internal error");
map.put(INVALID_STATE, "Invalid state");
map.put(END_OF_METHOD, "End of method reached");
map.put(ALREADY_LOADED, "Already loaded");
// Logical errors
map.put(NOT_SUPPORTED, "Operation not supported");
map.put(NOT_IMPLEMENTED, "Operation not implemented");
Expand Down Expand Up @@ -83,7 +133,7 @@ static String formatMessage(int errorCode, String details) {

String safeDetails = details != null ? details : "No details provided";
return String.format(
"[Executorch Error 0x%s] %s: %s",
"[ExecuTorch Error 0x%s] %s: %s",
Integer.toHexString(errorCode), baseMessage, safeDetails);
}

Expand Down Expand Up @@ -111,10 +161,12 @@ public ExecutorchRuntimeException(int errorCode, String details) {
this.errorCode = errorCode;
}

/** Returns the numeric error code from {@code runtime/core/error.h}. */
public int getErrorCode() {
return errorCode;
}

/** Returns detailed log output captured from the native runtime, if available. */
public String getDetailedError() {
return ErrorHelper.getDetailedErrorLogs();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

package org.pytorch.executorch;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
Expand Down Expand Up @@ -130,11 +129,10 @@ public EValue[] forward(EValue... inputs) {
* @return return value from the method.
*/
public EValue[] execute(String methodName, EValue... inputs) {
mLock.lock();
try {
mLock.lock();
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
throw new IllegalStateException("Module has been destroyed");
}
return executeNative(methodName, inputs);
} finally {
Expand All @@ -151,17 +149,17 @@ public EValue[] execute(String methodName, EValue... inputs) {
* synchronous, and will block until the method is loaded. Therefore, it is recommended to call
* this on a background thread. However, users need to make sure that they don't execute before
* this function returns.
*
* @return the Error code if there was an error loading the method
*/
public int loadMethod(String methodName) {
public void loadMethod(String methodName) {
mLock.lock();
try {
mLock.lock();
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return 0x2; // InvalidState
throw new IllegalStateException("Module has been destroyed");
}
int errorCode = loadMethodNative(methodName);
if (errorCode != 0) {
throw new ExecutorchRuntimeException(errorCode, "Failed to load method: " + methodName);
}
return loadMethodNative(methodName);
} finally {
mLock.unlock();
}
Expand All @@ -184,8 +182,20 @@ public int loadMethod(String methodName) {
*
* @return name of methods in this Module
*/
public String[] getMethods() {
mLock.lock();
try {
if (!mHybridData.isValid()) {
throw new IllegalStateException("Module has been destroyed");
}
return getMethodsNative();
} finally {
mLock.unlock();
}
}

@DoNotStrip
public native String[] getMethods();
private native String[] getMethodsNative();

/**
* Get the corresponding @MethodMetadata for a method
Expand All @@ -194,11 +204,19 @@ public int loadMethod(String methodName) {
* @return @MethodMetadata for this method
*/
public MethodMetadata getMethodMetadata(String name) {
MethodMetadata methodMetadata = mMethodMetadata.get(name);
if (methodMetadata == null) {
throw new IllegalArgumentException("method " + name + " does not exist for this module");
mLock.lock();
try {
if (!mHybridData.isValid()) {
throw new IllegalStateException("Module has been destroyed");
}
MethodMetadata methodMetadata = mMethodMetadata.get(name);
if (methodMetadata == null) {
throw new IllegalArgumentException("method " + name + " does not exist for this module");
}
return methodMetadata;
} finally {
mLock.unlock();
}
return methodMetadata;
}

@DoNotStrip
Expand All @@ -210,7 +228,15 @@ public static String[] readLogBufferStatic() {

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
return readLogBufferNative();
mLock.lock();
try {
if (!mHybridData.isValid()) {
throw new IllegalStateException("Module has been destroyed");
}
return readLogBufferNative();
} finally {
mLock.unlock();
}
}

@DoNotStrip
Expand All @@ -224,8 +250,20 @@ public String[] readLogBuffer() {
* @return true if the etdump was successfully written, false otherwise.
*/
@Experimental
public boolean etdump() {
mLock.lock();
try {
if (!mHybridData.isValid()) {
throw new IllegalStateException("Module has been destroyed");
}
return etdumpNative();
} finally {
mLock.unlock();
}
}

@DoNotStrip
public native boolean etdump();
private native boolean etdumpNative();

/**
* Explicitly destroys the native Module object. Calling this method is not required, as the
Expand All @@ -241,10 +279,7 @@ public void destroy() {
mLock.unlock();
}
} else {
Log.w(
"ExecuTorch",
"Destroy was called while the module was in use. Resources will not be immediately"
+ " released.");
throw new IllegalStateException("Cannot destroy module while method is executing");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,11 @@ default void onStats(String stats) {}
* @param message Human-readable error description
*/
@DoNotStrip
default void onError(int errorCode, String message) {}
default void onError(int errorCode, String message) {
try {
android.util.Log.e("ExecuTorch", "LLM error " + errorCode + ": " + message);
} catch (Throwable t) {
System.err.println("ExecuTorch LLM error " + errorCode + ": " + message);
}
Comment thread
psiddh marked this conversation as resolved.
}
}
Loading
Loading