diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 30ebf1a2c1d..53ee4d3f33a 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -36,12 +36,22 @@ public static ExecuTorchRuntime getRuntime() { /** * Validates that the given path points to a readable file. * - * @throws RuntimeException if the file does not exist or is not readable. + * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not + * readable. */ public static void validateFilePath(String path, String description) { + if (path == null) { + throw new IllegalArgumentException("Cannot load " + description + ": path is null"); + } File file = new File(path); - if (!file.canRead() || !file.isFile()) { - throw new RuntimeException("Cannot load " + description + " " + path); + if (!file.exists()) { + throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + } + if (!file.isFile()) { + throw new IllegalArgumentException("Cannot load " + description + "!! " + path); + } + if (!file.canRead()) { + throw new IllegalArgumentException("Cannot load " + description + "!! " + path); } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt index 987cb3ec3be..ab9099ba405 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt @@ -11,6 +11,7 @@ package org.pytorch.executorch.extension.asr import java.io.Closeable import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.pytorch.executorch.ExecutorchRuntimeException import org.pytorch.executorch.annotations.Experimental /** @@ -53,7 +54,10 @@ class AsrModule( val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath) if (handle == 0L) { - throw RuntimeException("Failed to create native AsrModule") + throw ExecutorchRuntimeException( + ExecutorchRuntimeException.INTERNAL, + "Failed to create native AsrModule", + ) } nativeHandle.set(handle) } @@ -129,7 +133,7 @@ class AsrModule( * @param callback Optional callback to receive tokens as they are generated (can be null) * @return The complete transcribed text * @throws IllegalStateException if the module has been destroyed - * @throws RuntimeException if transcription fails (non-zero result code) + * @throws ExecutorchRuntimeException if transcription fails (error code carried in exception) */ @JvmOverloads fun transcribe( @@ -160,7 +164,7 @@ class AsrModule( ) if (status != 0) { - throw RuntimeException("Transcription failed with error code: $status") + throw ExecutorchRuntimeException(status, "Transcription failed") } return result.toString() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 8f4292c1bc8..58c7704b83e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -93,7 +93,7 @@ public static SGD create(Map namedParameters, double learningRat */ public void step(Map namedGradients) { if (!mHybridData.isValid()) { - throw new RuntimeException("Attempt to use a destroyed SGD optimizer"); + throw new IllegalStateException("SGD optimizer has been destroyed"); } stepNative(namedGradients); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index 4a6653cb7a1..ca4bac9aa54 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -8,12 +8,11 @@ package org.pytorch.executorch.training; -import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.util.HashMap; +import java.io.Closeable; import java.util.Map; import org.pytorch.executorch.EValue; import org.pytorch.executorch.ExecuTorchRuntime; @@ -26,7 +25,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class TrainingModule { +public class TrainingModule implements Closeable { static { if (!NativeLoader.isInitialized()) { @@ -37,6 +36,7 @@ public class TrainingModule { } private final HybridData mHybridData; + private boolean mDestroyed = false; @DoNotStrip private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); @@ -45,6 +45,10 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); } + private void checkNotDestroyed() { + if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed"); + } + /** * Loads a serialized ExecuTorch Training Module from the specified path on the disk. * @@ -78,10 +82,7 @@ public static TrainingModule load(final String modelPath) { * @return return value(s) from the method. */ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new EValue[0]; - } + checkNotDestroyed(); return executeForwardBackwardNative(methodName, inputs); } @@ -89,10 +90,7 @@ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); public Map namedParameters(String methodName) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new HashMap(); - } + checkNotDestroyed(); return namedParametersNative(methodName); } @@ -100,13 +98,17 @@ public Map namedParameters(String methodName) { private native Map namedParametersNative(String methodName); public Map namedGradients(String methodName) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new HashMap(); - } + checkNotDestroyed(); return namedGradientsNative(methodName); } @DoNotStrip private native Map namedGradientsNative(String methodName); + + @Override + public void close() { + if (mDestroyed) return; + mDestroyed = true; + mHybridData.resetNative(); + } } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 88e9f9e2a12..0cf08e41983 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -284,8 +284,18 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); + try { + module_ = std::make_unique( + modelPath->toStdString(), load_mode, std::move(etdump_gen)); + } catch (const std::exception& e) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + std::string("Failed to create Module: ") + e.what()); + } catch (...) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + "Failed to create Module: unknown native error"); + } #ifdef ET_USE_THREADPOOL // Default to using cores/2 threadpool threads. The long-term plan is to diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 94c0efff335..0c1ff5c67b9 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -149,103 +149,117 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint num_bos = 0, jint num_eos = 0, jint load_mode = 1) { - temperature_ = temperature; - num_bos_ = num_bos; - num_eos_ = num_eos; + try { + temperature_ = temperature; + num_bos_ = num_bos; + num_eos_ = num_eos; #if defined(ET_USE_THREADPOOL) - // Reserve 1 thread for the main thread. - int32_t num_performant_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; - if (num_performant_cores > 0) { - ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG( + Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } #endif - model_type_category_ = model_type_category; - auto cpp_load_mode = load_mode_from_int(load_mode); - std::vector data_files_vector; - if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { - runner_ = llm::create_multimodal_runner( - model_path->toStdString().c_str(), - llm::load_tokenizer(tokenizer_path->toStdString()), - std::nullopt, - cpp_load_mode); - } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_files != nullptr) { - // Convert Java List to C++ std::vector - auto list_class = facebook::jni::findClassStatic("java/util/List"); - auto size_method = list_class->getMethod("size"); - auto get_method = - list_class->getMethod(jint)>( - "get"); - - jint size = size_method(data_files); - for (jint i = 0; i < size; ++i) { - auto str_obj = get_method(data_files, i); - auto jstr = facebook::jni::static_ref_cast(str_obj); - data_files_vector.push_back(jstr->toStdString()); + model_type_category_ = model_type_category; + auto cpp_load_mode = load_mode_from_int(load_mode); + std::vector data_files_vector; + if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { + runner_ = llm::create_multimodal_runner( + model_path->toStdString().c_str(), + llm::load_tokenizer(tokenizer_path->toStdString()), + std::nullopt, + cpp_load_mode); + } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { + if (data_files != nullptr) { + // Convert Java List to C++ std::vector + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_vector.push_back(jstr->toStdString()); + } } - } - runner_ = executorch::extension::llm::create_text_llm_runner( - model_path->toStdString(), - llm::load_tokenizer(tokenizer_path->toStdString()), - data_files_vector, - /*temperature=*/-1.0f, - /*event_tracer=*/nullptr, - /*method_name=*/"forward", - cpp_load_mode); + runner_ = executorch::extension::llm::create_text_llm_runner( + model_path->toStdString(), + llm::load_tokenizer(tokenizer_path->toStdString()), + data_files_vector, + /*temperature=*/-1.0f, + /*event_tracer=*/nullptr, + /*method_name=*/"forward", + cpp_load_mode); #if defined(EXECUTORCH_BUILD_QNN) - } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { - std::unique_ptr module = - std::make_unique( - model_path->toStdString().c_str(), - data_files_vector, - cpp_load_mode); - std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width").get().toScalar().to()); - } + } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { + std::unique_ptr module = + std::make_unique( + model_path->toStdString().c_str(), + data_files_vector, + cpp_load_mode); + std::string decoder_model = "llama3"; // use llama3 for now + // Using 8bit as default since this meta is introduced with 16bit kv io + // support and older models only have 8bit kv io. + example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; + if (module->method_names()->count("get_kv_io_bit_width") > 0) { + kv_bitwidth = static_cast( + module->get("get_kv_io_bit_width") + .get() + .toScalar() + .to()); + } - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + if (kv_bitwidth == example::KvBitWidth::kWidth8) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else { + ET_CHECK_MSG( + false, + "Unsupported kv bitwidth: %ld", + static_cast(kv_bitwidth)); + } + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) - } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); - // Interpret the model type as LLM - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { + runner_ = std::make_unique( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str()); + // Interpret the model type as LLM + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif + } + } catch (const std::exception& e) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + std::string("Failed to create LlmModule: ") + e.what()); + } catch (...) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + "Failed to create LlmModule: unknown native error"); } }