diff --git a/httpfs/CMakeLists.txt b/httpfs/CMakeLists.txt index 602beae..15148eb 100644 --- a/httpfs/CMakeLists.txt +++ b/httpfs/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(httpfs_extension_source src/httpfs.cpp src/httpfs_extension.cpp src/s3fs.cpp + src/xetfs.cpp src/crypto.cpp src/http_config.cpp src/cached_file_manager.cpp diff --git a/httpfs/src/httpfs_extension.cpp b/httpfs/src/httpfs_extension.cpp index 5bf031d..13707a4 100644 --- a/httpfs/src/httpfs_extension.cpp +++ b/httpfs/src/httpfs_extension.cpp @@ -6,6 +6,7 @@ #include "main/database.h" #include "s3fs.h" #include "s3fs_config.h" +#include "xetfs.h" namespace lbug { namespace httpfs_extension { @@ -26,6 +27,7 @@ static void registerExtensionOptions(main::Database* db) { static void registerFileSystem(main::Database* db) { db->registerFileSystem(std::make_unique()); + db->registerFileSystem(std::make_unique()); for (auto& fsConfig : S3FileSystemConfig::getAvailableConfigs()) { db->registerFileSystem(std::make_unique(fsConfig)); } diff --git a/httpfs/src/include/xetfs.h b/httpfs/src/include/xetfs.h new file mode 100644 index 0000000..d79aa4b --- /dev/null +++ b/httpfs/src/include/xetfs.h @@ -0,0 +1,32 @@ +#pragma once + +#include "httpfs.h" + +namespace lbug { +namespace httpfs_extension { + +class XetFileSystem final : public HTTPFileSystem { +public: + std::unique_ptr openFile(const std::string& path, common::FileOpenFlags flags, + main::ClientContext* context = nullptr) override; + + std::vector glob(main::ClientContext* context, + const std::string& path) const override; + + bool canHandleFile(const std::string_view path) const override; + + bool fileOrPathExists(const std::string& path, main::ClientContext* context = nullptr) override; + + static std::string toHuggingFaceURL(const std::string& path); + +protected: + std::unique_ptr headRequest(common::FileInfo* fileInfo, const std::string& url, + HeaderMap headerMap) const override; + + std::unique_ptr getRangeRequest(common::FileInfo* fileInfo, + const std::string& url, HeaderMap headerMap, uint64_t fileOffset, char* buffer, + uint64_t bufferLen) const override; +}; + +} // namespace httpfs_extension +} // namespace lbug diff --git a/httpfs/src/xetfs.cpp b/httpfs/src/xetfs.cpp new file mode 100644 index 0000000..e5073e9 --- /dev/null +++ b/httpfs/src/xetfs.cpp @@ -0,0 +1,221 @@ +#include "xetfs.h" + +#include "common/exception/io.h" +#include "common/string_utils.h" +#include + +namespace lbug { +namespace httpfs_extension { + +using namespace common; + +namespace { + +static constexpr std::string_view XET_PREFIX = "xet://"; +static constexpr std::string_view HF_BASE_URL = "https://huggingface.co/"; + +std::vector splitPath(std::string_view path) { + std::vector result; + size_t start = 0; + while (start <= path.size()) { + auto end = path.find('/', start); + if (end == std::string_view::npos) { + end = path.size(); + } + result.emplace_back(path.substr(start, end - start)); + start = end + 1; + if (end == path.size()) { + break; + } + } + return result; +} + +std::string joinSegments(const std::vector& segments, size_t start) { + std::string result; + for (auto i = start; i < segments.size(); ++i) { + if (!result.empty()) { + result += "/"; + } + result += segments[i]; + } + return result; +} + +std::string buildResolveURL(std::string_view repoPrefix, const std::vector& segments) { + if (segments.size() < 4) { + throw IOException{ + "Xet URL must include namespace, repository, revision, and file path components."}; + } + const auto filePath = joinSegments(segments, 3); + if (filePath.empty()) { + throw IOException{"Xet URL must include a file path."}; + } + return std::format("{}{}{}{}/{}/resolve/{}/{}", HF_BASE_URL, repoPrefix, + repoPrefix.empty() ? "" : "/", segments[0], segments[1], segments[2], filePath); +} + +std::string buildResolveURLWithExplicitResolve(std::string_view repoPrefix, + const std::vector& segments) { + if (segments.size() < 5 || segments[2] != "resolve") { + return buildResolveURL(repoPrefix, segments); + } + const auto filePath = joinSegments(segments, 4); + if (filePath.empty()) { + throw IOException{"Xet URL must include a file path."}; + } + return std::format("{}{}{}{}/{}/resolve/{}/{}", HF_BASE_URL, repoPrefix, + repoPrefix.empty() ? "" : "/", segments[0], segments[1], segments[3], filePath); +} + +std::string makeAbsoluteRedirectURL(const std::string& sourceURL, const std::string& location) { + if (location.rfind("http://", 0) == 0 || location.rfind("https://", 0) == 0) { + return location; + } + const auto [host, hostPath] = HTTPFileSystem::parseUrl(sourceURL); + if (location.starts_with('/')) { + return host + location; + } + const auto lastSlash = hostPath.find_last_of('/'); + const auto basePath = + lastSlash == std::string::npos ? std::string{"/"} : hostPath.substr(0, lastSlash + 1); + return host + basePath + location; +} + +std::unique_ptr getNoRedirectClient(const std::string& host) { + auto client = HTTPFileSystem::getClient(host); + client->set_follow_location(false); + client->set_url_encode(false); + return client; +} + +std::unique_ptr synthesizeHeadResponse(const HTTPResponse& response, + const std::string& url, const std::string& contentLength) { + httplib::Response res; + res.status = 200; + res.reason = "OK"; + for (auto& [name, value] : response.headers) { + if (StringUtils::caseInsensitiveEquals(name, "Content-Length")) { + continue; + } + res.headers.emplace(name, value); + } + res.headers.emplace("Content-Length", contentLength); + return std::make_unique(res, url); +} + +} // namespace + +std::unique_ptr XetFileSystem::openFile(const std::string& path, + common::FileOpenFlags flags, main::ClientContext* context) { + if (flags.flags & FileFlags::WRITE) { + throw IOException{"Writing to xet:// URLs is not supported."}; + } + return HTTPFileSystem::openFile(toHuggingFaceURL(path), flags, context); +} + +std::vector XetFileSystem::glob(main::ClientContext* /*context*/, + const std::string& path) const { + // Keep xet:// paths routed to XetFileSystem after bind-time glob expansion. + return {path}; +} + +bool XetFileSystem::canHandleFile(const std::string_view path) const { + return path.rfind(XET_PREFIX, 0) == 0; +} + +bool XetFileSystem::fileOrPathExists(const std::string& path, main::ClientContext* context) { + return HTTPFileSystem::fileOrPathExists(toHuggingFaceURL(path), context); +} + +std::string XetFileSystem::toHuggingFaceURL(const std::string& path) { + if (path.rfind(XET_PREFIX, 0) != 0) { + throw IOException{"Xet URL needs to start with xet://"}; + } + + auto suffix = std::string_view{path}.substr(XET_PREFIX.size()); + if (suffix.rfind("huggingface.co/", 0) == 0) { + return std::format("{}{}", HF_BASE_URL, + suffix.substr(std::string_view{"huggingface.co/"}.size())); + } + if (suffix.rfind("hf.co/", 0) == 0) { + return std::format("{}{}", HF_BASE_URL, suffix.substr(std::string_view{"hf.co/"}.size())); + } + + const auto segments = splitPath(suffix); + if (segments.empty() || segments[0].empty()) { + throw IOException{"Xet URL must include a Hugging Face repository path."}; + } + if (segments[0] == "models" || segments[0] == "model") { + return buildResolveURLWithExplicitResolve("", + std::vector{segments.begin() + 1, segments.end()}); + } + if (segments[0] == "datasets" || segments[0] == "dataset") { + return buildResolveURLWithExplicitResolve("datasets", + std::vector{segments.begin() + 1, segments.end()}); + } + if (segments[0] == "spaces" || segments[0] == "space") { + return buildResolveURLWithExplicitResolve("spaces", + std::vector{segments.begin() + 1, segments.end()}); + } + return buildResolveURLWithExplicitResolve("", segments); +} + +std::unique_ptr XetFileSystem::headRequest(common::FileInfo* /*fileInfo*/, + const std::string& url, HeaderMap headerMap) const { + const auto [host, hostPath] = HTTPFileSystem::parseUrl(url); + auto headers = getHTTPHeaders(headerMap); + auto client = getNoRedirectClient(host); + + std::function request( + [&]() { return client->Head(hostPath.c_str(), *headers); }); + std::function retry([&]() { client = getNoRedirectClient(host); }); + + auto response = runRequestWithRetry(request, url, "HEAD", retry); + if (response->code >= 300 && response->code < 400 && + response->headers.contains("x-linked-size")) { + return synthesizeHeadResponse(*response, url, response->headers["x-linked-size"]); + } + if (response->code >= 300 && response->code < 400 && response->headers.contains("Location")) { + return headRequest(nullptr, makeAbsoluteRedirectURL(url, response->headers["Location"]), + headerMap); + } + return response; +} + +std::unique_ptr XetFileSystem::getRangeRequest(common::FileInfo* /*fileInfo*/, + const std::string& url, HeaderMap headerMap, uint64_t fileOffset, char* buffer, + uint64_t bufferLen) const { + const auto [host, hostPath] = HTTPFileSystem::parseUrl(url); + auto headers = getHTTPHeaders(headerMap); + headers->insert(std::make_pair("Range", + std::format("bytes={}-{}", fileOffset, fileOffset + bufferLen - 1))); + auto client = getNoRedirectClient(host); + + std::function request( + [&]() { return client->Get(hostPath.c_str(), *headers); }); + std::function retry([&]() { client = getNoRedirectClient(host); }); + + auto response = runRequestWithRetry(request, url, "GET Range", retry); + if (response->code >= 300 && response->code < 400 && response->headers.contains("Location")) { + return getRangeRequest(nullptr, makeAbsoluteRedirectURL(url, response->headers["Location"]), + headerMap, fileOffset, buffer, bufferLen); + } + if (response->code >= 400) { + throw IOException(std::format("HTTP GET error on '{}' (HTTP {})", url, response->code)); + } + if (response->code < 300) { + if (response->body.size() != bufferLen) { + throw IOException(std::format("HTTP GET error: response body size {} does not match " + "requested range size {}.", + response->body.size(), bufferLen)); + } + if (buffer != nullptr) { + memcpy(buffer, response->body.data(), bufferLen); + } + } + return response; +} + +} // namespace httpfs_extension +} // namespace lbug diff --git a/httpfs/test/CMakeLists.txt b/httpfs/test/CMakeLists.txt index e69de29..9987a3e 100644 --- a/httpfs/test/CMakeLists.txt +++ b/httpfs/test/CMakeLists.txt @@ -0,0 +1,4 @@ +if (${BUILD_EXTENSION_TESTS}) + add_lbug_test(httpfs_xetfs_test xetfs_test.cpp) + target_link_libraries(httpfs_xetfs_test PRIVATE httpfs_extension_source ${OPENSSL_LIBRARIES}) +endif () diff --git a/httpfs/test/xetfs_test.cpp b/httpfs/test/xetfs_test.cpp new file mode 100644 index 0000000..9351692 --- /dev/null +++ b/httpfs/test/xetfs_test.cpp @@ -0,0 +1,28 @@ +#include "gtest/gtest.h" +#include "xetfs.h" + +using namespace lbug::httpfs_extension; + +TEST(XetFileSystemTest, MapsModelResolveURL) { + EXPECT_EQ("https://huggingface.co/Qwen/Qwen-Image-Edit/resolve/main/model.safetensors", + XetFileSystem::toHuggingFaceURL("xet://models/Qwen/Qwen-Image-Edit/main/" + "model.safetensors")); +} + +TEST(XetFileSystemTest, MapsDatasetResolveURL) { + EXPECT_EQ("https://huggingface.co/datasets/org/repo/resolve/main/data/train.parquet", + XetFileSystem::toHuggingFaceURL("xet://datasets/org/repo/resolve/main/" + "data/train.parquet")); +} + +TEST(XetFileSystemTest, MapsExplicitHubURL) { + EXPECT_EQ("https://huggingface.co/org/repo/resolve/main/file.csv", + XetFileSystem::toHuggingFaceURL("xet://huggingface.co/org/repo/resolve/main/file.csv")); +} + +TEST(XetFileSystemTest, GlobKeepsXetPath) { + XetFileSystem fs; + const auto path = + std::string{"xet://datasets/ladybugdb/small-kgs/main/kg_history/icebug-disk/schema.cypher"}; + EXPECT_EQ(std::vector{path}, fs.glob(nullptr, path)); +}