diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 7ea95ed14..a2e919409 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1601,6 +1601,54 @@ struct SDGenerationParams { return true; } + static bool sanitize_lora_path(const std::string& lora_model_dir, + const std::string& raw_path_str, + fs::path& full_path) { + if (lora_model_dir.empty()) { + return false; + } + + fs::path raw_path(raw_path_str); + + // Disallow absolute paths and '..' components + if (raw_path.is_absolute()) { + LOG_WARN("lora path must be relative: %s", raw_path_str.c_str()); + return false; + } + + for (const auto& part : raw_path) { + if (part == "..") { + LOG_WARN("lora path cannot contain '..': %s", raw_path_str.c_str()); + return false; + } + } + + // Construct and canonicalize paths + fs::path lora_dir(lora_model_dir); + full_path = lora_dir / raw_path; + + auto canonical_lora_dir = fs::weakly_canonical(lora_dir); + auto canonical_full_path = fs::weakly_canonical(full_path); + + // Check if path is a directory + if (fs::is_directory(canonical_full_path)) { + LOG_WARN("lora path resolved to a directory, not a file: %s", raw_path_str.c_str()); + return false; + } + + // Verify path stays within lora directory + auto [root_end, nothing] = std::mismatch( + canonical_lora_dir.begin(), canonical_lora_dir.end(), + canonical_full_path.begin(), canonical_full_path.end()); + + if (root_end != canonical_lora_dir.end()) { + LOG_WARN("lora path is outside of the lora model directory: %s", raw_path_str.c_str()); + return false; + } + + return true; + } + void extract_and_remove_lora(const std::string& lora_model_dir) { if (lora_model_dir.empty()) { return; @@ -1632,10 +1680,10 @@ struct SDGenerationParams { } fs::path final_path; - if (is_absolute_path(raw_path)) { - final_path = raw_path; - } else { - final_path = fs::path(lora_model_dir) / raw_path; + if (!sanitize_lora_path(lora_model_dir, raw_path, final_path)) { + tmp = m.suffix().str(); + prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only); + continue; } if (!fs::exists(final_path)) { bool found = false; diff --git a/examples/server/main.cpp b/examples/server/main.cpp index c540958f8..69c75d322 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -293,6 +293,7 @@ int main(int argc, const char** argv) { LOG_DEBUG("%s", default_gen_params.to_string().c_str()); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); + ctx_params.lora_apply_mode = LORA_APPLY_AT_RUNTIME; sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); if (sd_ctx == nullptr) { @@ -414,7 +415,7 @@ int main(int argc, const char** argv) { return; } - if (!gen_params.process_and_check(IMG_GEN, "")) { + if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) { res.status = 400; res.set_content(R"({"error":"invalid params"})", "application/json"); return; @@ -592,7 +593,7 @@ int main(int argc, const char** argv) { return; } - if (!gen_params.process_and_check(IMG_GEN, "")) { + if (!gen_params.process_and_check(IMG_GEN, ctx_params.lora_model_dir)) { res.status = 400; res.set_content(R"({"error":"invalid params"})", "application/json"); return;