Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bin/pytorch_inference/CSupportedOperations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
// deepset/tinyroberta-squad2, typeform/squeezebert-mnli,
// facebook/bart-large-mnli, valhalla/distilbart-mnli-12-6,
// distilbert-base-uncased-finetuned-sst-2-english,
// sentence-transformers/all-distilroberta-v1.
// sentence-transformers/all-distilroberta-v1,
// jinaai/jina-embeddings-v5-text-nano (EuroBERT + LoRA).
// Eland-deployed variants of the above models (with pooling/normalization layers).
// Additional ops from Elasticsearch integration test models
// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
Expand All @@ -68,6 +69,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::clone"sv,
"aten::contiguous"sv,
"aten::copy_"sv,
"aten::cos"sv,
"aten::cumsum"sv,
"aten::detach"sv,
"aten::div"sv,
Expand Down Expand Up @@ -117,10 +119,13 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::relu"sv,
"aten::repeat"sv,
"aten::reshape"sv,
"aten::rsqrt"sv,
"aten::rsub"sv,
"aten::scaled_dot_product_attention"sv,
"aten::select"sv,
"aten::sign"sv,
"aten::silu"sv,
"aten::sin"sv,
"aten::size"sv,
"aten::slice"sv,
"aten::softmax"sv,
Expand Down
38 changes: 21 additions & 17 deletions bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,25 +259,25 @@ BOOST_AUTO_TEST_CASE(testValidModuleWithAllowedOps) {
}

BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) {
// torch.sin is not in the transformer allowlist.
// torch.logit is not in the transformer allowlist.
::torch::jit::Module m("__torch__.UnknownOps");
m.define(R"(
def forward(self, x: Tensor) -> Tensor:
return torch.sin(x)
return torch.logit(x)
)");

auto result = CModelGraphValidator::validate(m);

BOOST_REQUIRE(result.s_IsValid == false);
BOOST_REQUIRE(result.s_ForbiddenOps.empty());
BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false);
bool foundSin = false;
bool foundLogit = false;
for (const auto& op : result.s_UnrecognisedOps) {
if (op == "aten::sin") {
foundSin = true;
if (op == "aten::logit") {
foundLogit = true;
}
}
BOOST_REQUIRE(foundSin);
BOOST_REQUIRE(foundLogit);
}

BOOST_AUTO_TEST_CASE(testModuleNodeCountPopulated) {
Expand All @@ -301,7 +301,7 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) {
::torch::jit::Module child("__torch__.Child");
child.define(R"(
def forward(self, x: Tensor) -> Tensor:
return torch.sin(x)
return torch.logit(x)
)");

::torch::jit::Module parent("__torch__.Parent");
Expand All @@ -314,13 +314,13 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) {
auto result = CModelGraphValidator::validate(parent);

BOOST_REQUIRE(result.s_IsValid == false);
bool foundSin = false;
bool foundLogit = false;
for (const auto& op : result.s_UnrecognisedOps) {
if (op == "aten::sin") {
foundSin = true;
if (op == "aten::logit") {
foundLogit = true;
}
}
BOOST_REQUIRE(foundSin);
BOOST_REQUIRE(foundLogit);
}

// --- Integration tests with malicious .pt model fixtures ---
Expand Down Expand Up @@ -363,34 +363,38 @@ BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) {
BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) {
// Unrecognised ops buried three levels deep in nested submodules.
// The validator must inline through all submodules to find them.
// The model uses aten::sin which is now allowed (EuroBERT/Jina v5),
// but also contains other ops that remain unrecognised.
auto module = ::torch::jit::load("testfiles/malicious_models/malicious_hidden_in_submodule.pt");
auto result = CModelGraphValidator::validate(module);

BOOST_REQUIRE(result.s_IsValid == false);
BOOST_REQUIRE(result.s_ForbiddenOps.empty());
BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin"));
BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false);
}

BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) {
// An unrecognised op hidden inside a conditional branch. The
// validator must recurse into prim::If blocks to detect it.
// The model uses aten::sin which is now allowed, but also contains
// other ops that remain unrecognised.
auto module = ::torch::jit::load("testfiles/malicious_models/malicious_conditional.pt");
auto result = CModelGraphValidator::validate(module);

BOOST_REQUIRE(result.s_IsValid == false);
BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin"));
BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false);
}

BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) {
// A model using many different unrecognised ops (sin, cos, tan, exp).
// A model using many different ops (sin, cos, tan, exp).
// sin and cos are now allowed (EuroBERT/Jina v5), but tan and exp
// remain unrecognised.
auto module = ::torch::jit::load("testfiles/malicious_models/malicious_many_unrecognised.pt");
auto result = CModelGraphValidator::validate(module);

BOOST_REQUIRE(result.s_IsValid == false);
BOOST_REQUIRE(result.s_ForbiddenOps.empty());
BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 4);
BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin"));
BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::cos"));
BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 2);
BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::tan"));
BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::exp"));
}
Expand Down