Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,22 @@ public static ExecuTorchRuntime getRuntime() {
/**
* Validates that the given path points to a readable file.
*
* @throws RuntimeException if the file does not exist or is not readable.
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
* readable.
*/
public static void validateFilePath(String path, String description) {
if (path == null) {
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
}
File file = new File(path);
if (!file.canRead() || !file.isFile()) {
throw new RuntimeException("Cannot load " + description + " " + path);
if (!file.exists()) {
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
}
if (!file.isFile()) {
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
}
if (!file.canRead()) {
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
Comment thread
psiddh marked this conversation as resolved.
}
Comment thread
psiddh marked this conversation as resolved.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package org.pytorch.executorch.extension.asr
import java.io.Closeable
import java.io.File
import java.util.concurrent.atomic.AtomicLong
import org.pytorch.executorch.ExecutorchRuntimeException
import org.pytorch.executorch.annotations.Experimental

/**
Expand Down Expand Up @@ -53,7 +54,10 @@ class AsrModule(

val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
if (handle == 0L) {
throw RuntimeException("Failed to create native AsrModule")
throw ExecutorchRuntimeException(
ExecutorchRuntimeException.INTERNAL,
"Failed to create native AsrModule",
)
}
nativeHandle.set(handle)
}
Expand Down Expand Up @@ -129,7 +133,7 @@ class AsrModule(
* @param callback Optional callback to receive tokens as they are generated (can be null)
* @return The complete transcribed text
* @throws IllegalStateException if the module has been destroyed
* @throws RuntimeException if transcription fails (non-zero result code)
* @throws ExecutorchRuntimeException if transcription fails (error code carried in exception)
*/
@JvmOverloads
fun transcribe(
Expand Down Expand Up @@ -160,7 +164,7 @@ class AsrModule(
)

Comment thread
psiddh marked this conversation as resolved.
if (status != 0) {
throw RuntimeException("Transcription failed with error code: $status")
throw ExecutorchRuntimeException(status, "Transcription failed")
}
Comment thread
psiddh marked this conversation as resolved.

return result.toString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static SGD create(Map<String, Tensor> namedParameters, double learningRat
*/
public void step(Map<String, Tensor> namedGradients) {
if (!mHybridData.isValid()) {
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
throw new IllegalStateException("SGD optimizer has been destroyed");
}
stepNative(namedGradients);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

package org.pytorch.executorch.training;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.HashMap;
import java.io.Closeable;
import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.ExecuTorchRuntime;
Expand All @@ -26,7 +25,7 @@
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public class TrainingModule {
public class TrainingModule implements Closeable {

static {
if (!NativeLoader.isInitialized()) {
Expand All @@ -37,6 +36,7 @@ public class TrainingModule {
}

private final HybridData mHybridData;
private boolean mDestroyed = false;

@DoNotStrip
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
Expand All @@ -45,6 +45,10 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
}

/** Guard used by every public method: fails fast once {@code close()} has run. */
private void checkNotDestroyed() {
  if (mDestroyed) {
    throw new IllegalStateException("TrainingModule has been destroyed");
  }
}

/**
* Loads a serialized ExecuTorch Training Module from the specified path on the disk.
*
Expand Down Expand Up @@ -78,35 +82,33 @@ public static TrainingModule load(final String modelPath) {
* @return return value(s) from the method.
*/
public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
checkNotDestroyed();
return executeForwardBackwardNative(methodName, inputs);
}

@DoNotStrip
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);

public Map<String, Tensor> namedParameters(String methodName) {
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new HashMap<String, Tensor>();
}
checkNotDestroyed();
return namedParametersNative(methodName);
}

@DoNotStrip
private native Map<String, Tensor> namedParametersNative(String methodName);

public Map<String, Tensor> namedGradients(String methodName) {
if (!mHybridData.isValid()) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new HashMap<String, Tensor>();
}
checkNotDestroyed();
return namedGradientsNative(methodName);
}

@DoNotStrip
private native Map<String, Tensor> namedGradientsNative(String methodName);

/**
 * Releases the native resources backing this module. Idempotent: only the
 * first invocation resets the native handle; later calls are no-ops.
 */
@Override
public void close() {
  if (!mDestroyed) {
    mDestroyed = true;
    mHybridData.resetNative();
  }
}
Comment thread
psiddh marked this conversation as resolved.
}
14 changes: 12 additions & 2 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,18 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
#else
auto etdump_gen = nullptr;
#endif
module_ = std::make_unique<Module>(
modelPath->toStdString(), load_mode, std::move(etdump_gen));
try {
module_ = std::make_unique<Module>(
modelPath->toStdString(), load_mode, std::move(etdump_gen));
} catch (const std::exception& e) {
executorch::jni_helper::throwExecutorchException(
static_cast<uint32_t>(Error::Internal),
std::string("Failed to create Module: ") + e.what());
} catch (...) {
executorch::jni_helper::throwExecutorchException(
static_cast<uint32_t>(Error::Internal),
"Failed to create Module: unknown native error");
}
Comment thread
psiddh marked this conversation as resolved.

#ifdef ET_USE_THREADPOOL
// Default to using cores/2 threadpool threads. The long-term plan is to
Expand Down
190 changes: 102 additions & 88 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,103 +149,117 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
jint num_bos = 0,
jint num_eos = 0,
jint load_mode = 1) {
temperature_ = temperature;
num_bos_ = num_bos;
num_eos_ = num_eos;
try {
temperature_ = temperature;
num_bos_ = num_bos;
num_eos_ = num_eos;
#if defined(ET_USE_THREADPOOL)
// Reserve 1 thread for the main thread.
int32_t num_performant_cores =
::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
if (num_performant_cores > 0) {
ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
::executorch::extension::threadpool::get_threadpool()
->_unsafe_reset_threadpool(num_performant_cores);
}
// Reserve 1 thread for the main thread.
int32_t num_performant_cores =
::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
if (num_performant_cores > 0) {
ET_LOG(
Info, "Resetting threadpool to %d threads", num_performant_cores);
::executorch::extension::threadpool::get_threadpool()
->_unsafe_reset_threadpool(num_performant_cores);
}
#endif

model_type_category_ = model_type_category;
auto cpp_load_mode = load_mode_from_int(load_mode);
std::vector<std::string> data_files_vector;
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
runner_ = llm::create_multimodal_runner(
model_path->toStdString().c_str(),
llm::load_tokenizer(tokenizer_path->toStdString()),
std::nullopt,
cpp_load_mode);
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
if (data_files != nullptr) {
// Convert Java List<String> to C++ std::vector<string>
auto list_class = facebook::jni::findClassStatic("java/util/List");
auto size_method = list_class->getMethod<jint()>("size");
auto get_method =
list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
"get");

jint size = size_method(data_files);
for (jint i = 0; i < size; ++i) {
auto str_obj = get_method(data_files, i);
auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
data_files_vector.push_back(jstr->toStdString());
model_type_category_ = model_type_category;
auto cpp_load_mode = load_mode_from_int(load_mode);
std::vector<std::string> data_files_vector;
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
runner_ = llm::create_multimodal_runner(
model_path->toStdString().c_str(),
llm::load_tokenizer(tokenizer_path->toStdString()),
std::nullopt,
cpp_load_mode);
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
if (data_files != nullptr) {
// Convert Java List<String> to C++ std::vector<string>
auto list_class = facebook::jni::findClassStatic("java/util/List");
auto size_method = list_class->getMethod<jint()>("size");
auto get_method =
list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
"get");

jint size = size_method(data_files);
for (jint i = 0; i < size; ++i) {
auto str_obj = get_method(data_files, i);
auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
data_files_vector.push_back(jstr->toStdString());
}
}
}
runner_ = executorch::extension::llm::create_text_llm_runner(
model_path->toStdString(),
llm::load_tokenizer(tokenizer_path->toStdString()),
data_files_vector,
/*temperature=*/-1.0f,
/*event_tracer=*/nullptr,
/*method_name=*/"forward",
cpp_load_mode);
runner_ = executorch::extension::llm::create_text_llm_runner(
model_path->toStdString(),
llm::load_tokenizer(tokenizer_path->toStdString()),
data_files_vector,
/*temperature=*/-1.0f,
/*event_tracer=*/nullptr,
/*method_name=*/"forward",
cpp_load_mode);
#if defined(EXECUTORCH_BUILD_QNN)
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
std::unique_ptr<executorch::extension::Module> module =
std::make_unique<executorch::extension::Module>(
model_path->toStdString().c_str(),
data_files_vector,
cpp_load_mode);
std::string decoder_model = "llama3"; // use llama3 for now
// Using 8bit as default since this meta is introduced with 16bit kv io
// support and older models only have 8bit kv io.
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
}
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
std::unique_ptr<executorch::extension::Module> module =
std::make_unique<executorch::extension::Module>(
model_path->toStdString().c_str(),
data_files_vector,
cpp_load_mode);
std::string decoder_model = "llama3"; // use llama3 for now
// Using 8bit as default since this meta is introduced with 16bit kv io
// support and older models only have 8bit kv io.
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
module->get("get_kv_io_bit_width")
.get()
.toScalar()
.to<int64_t>());
}

if (kv_bitwidth == example::KvBitWidth::kWidth8) {
runner_ = std::make_unique<example::Runner<uint8_t>>(
std::move(module),
decoder_model.c_str(),
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
"",
"",
temperature_);
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
runner_ = std::make_unique<example::Runner<uint16_t>>(
std::move(module),
decoder_model.c_str(),
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
"",
"",
temperature_);
} else {
ET_CHECK_MSG(
false,
"Unsupported kv bitwidth: %ld",
static_cast<int64_t>(kv_bitwidth));
}
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
if (kv_bitwidth == example::KvBitWidth::kWidth8) {
runner_ = std::make_unique<example::Runner<uint8_t>>(
std::move(module),
decoder_model.c_str(),
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
"",
"",
temperature_);
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
runner_ = std::make_unique<example::Runner<uint16_t>>(
std::move(module),
decoder_model.c_str(),
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
"",
"",
temperature_);
} else {
ET_CHECK_MSG(
false,
"Unsupported kv bitwidth: %ld",
static_cast<int64_t>(kv_bitwidth));
}
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
#endif
#if defined(EXECUTORCH_BUILD_MEDIATEK)
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
runner_ = std::make_unique<MTKLlamaRunner>(
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str());
// Interpret the model type as LLM
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
runner_ = std::make_unique<MTKLlamaRunner>(
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str());
// Interpret the model type as LLM
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
#endif
}
} catch (const std::exception& e) {
executorch::jni_helper::throwExecutorchException(
static_cast<uint32_t>(Error::Internal),
std::string("Failed to create LlmModule: ") + e.what());
Comment thread
psiddh marked this conversation as resolved.
} catch (...) {
executorch::jni_helper::throwExecutorchException(
static_cast<uint32_t>(Error::Internal),
"Failed to create LlmModule: unknown native error");
}
Comment thread
psiddh marked this conversation as resolved.
}

Expand Down
Loading