diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 56dbbaa84..61229a5e0 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -41,7 +41,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA // deepset/tinyroberta-squad2, typeform/squeezebert-mnli, // facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6, // distilbert-base-uncased-finetuned-sst-2-english, -// sentence-transformers/all-distilroberta-v1. +// sentence-transformers/all-distilroberta-v1, +// jinaai/jina-embeddings-v5-text-nano (EuroBERT + LoRA). // Eland-deployed variants of the above models (with pooling/normalization layers). // Additional ops from Elasticsearch integration test models // (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT). @@ -68,6 +69,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::clone"sv, "aten::contiguous"sv, "aten::copy_"sv, + "aten::cos"sv, "aten::cumsum"sv, "aten::detach"sv, "aten::div"sv, @@ -117,10 +119,13 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::relu"sv, "aten::repeat"sv, "aten::reshape"sv, + "aten::rsqrt"sv, "aten::rsub"sv, "aten::scaled_dot_product_attention"sv, "aten::select"sv, "aten::sign"sv, + "aten::silu"sv, + "aten::sin"sv, "aten::size"sv, "aten::slice"sv, "aten::softmax"sv, diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 5180fb403..e292b78b9 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -259,11 +259,11 @@ BOOST_AUTO_TEST_CASE(testValidModuleWithAllowedOps) { } BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) { - // torch.sin is not in the transformer allowlist. + // torch.logit is not in the transformer allowlist. ::torch::jit::Module m("__torch__.UnknownOps"); m.define(R"( def forward(self, x: Tensor) -> Tensor: - return torch.sin(x) + return torch.logit(x) )"); auto result = CModelGraphValidator::validate(m); @@ -271,13 +271,13 @@ BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) { BOOST_REQUIRE(result.s_IsValid == false); BOOST_REQUIRE(result.s_ForbiddenOps.empty()); BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false); - bool foundSin = false; + bool foundLogit = false; for (const auto& op : result.s_UnrecognisedOps) { - if (op == "aten::sin") { - foundSin = true; + if (op == "aten::logit") { + foundLogit = true; } } - BOOST_REQUIRE(foundSin); + BOOST_REQUIRE(foundLogit); } BOOST_AUTO_TEST_CASE(testModuleNodeCountPopulated) { @@ -301,7 +301,7 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { ::torch::jit::Module child("__torch__.Child"); child.define(R"( def forward(self, x: Tensor) -> Tensor: - return torch.sin(x) + return torch.logit(x) )"); ::torch::jit::Module parent("__torch__.Parent"); @@ -314,13 +314,13 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { auto result = CModelGraphValidator::validate(parent); BOOST_REQUIRE(result.s_IsValid == false); - bool foundSin = false; + bool foundLogit = false; for (const auto& op : result.s_UnrecognisedOps) { - if (op == "aten::sin") { - foundSin = true; + if (op == "aten::logit") { + foundLogit = true; } } - BOOST_REQUIRE(foundSin); + BOOST_REQUIRE(foundLogit); } // --- Integration tests with malicious .pt model fixtures --- @@ -363,34 +363,38 @@ BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) { BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) { // Unrecognised ops buried three levels deep in nested submodules. // The validator must inline through all submodules to find them. + // The model uses aten::sin which is now allowed (EuroBERT/Jina v5), + // but also contains other ops that remain unrecognised. auto module = ::torch::jit::load("testfiles/malicious_models/malicious_hidden_in_submodule.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false); } BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) { // An unrecognised op hidden inside a conditional branch. The // validator must recurse into prim::If blocks to detect it. + // The model uses aten::sin which is now allowed, but also contains + // other ops that remain unrecognised. auto module = ::torch::jit::load("testfiles/malicious_models/malicious_conditional.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false); } BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) { - // A model using many different unrecognised ops (sin, cos, tan, exp). + // A model using many different ops (sin, cos, tan, exp). + // sin and cos are now allowed (EuroBERT/Jina v5), but tan and exp + // remain unrecognised. auto module = ::torch::jit::load("testfiles/malicious_models/malicious_many_unrecognised.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 4); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::cos")); + BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 2); BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::tan")); BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::exp")); }