From 577dcf6b6fb396da2f04216b0c2c6fd5a2dd49fb Mon Sep 17 00:00:00 2001 From: OhYee Date: Fri, 22 May 2026 11:17:06 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix(tool,=20toolset):=20=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=20additionalProperties=20schema=20=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 additional_properties 字段类型从 bool 扩展为 Union[bool, ToolSchema] - 递归解析 schema-valued additionalProperties,序列化时保留为对象 - 修正 from_any_openapi_schema 入口条件:if not schema → if schema is None, 使空 dict {} 作为合法 JSON Schema(any)正确解析,而非降级为 string - 将嵌套三元表达式重构为 if/elif/else 提升可读性 - 添加 OpenAPI、MCP tool schema 及空 schema 的回归测试 Change-Id: I1e14f80b169adbb7693f0ee668ed6462e2a4803c Co-developed-by: Claude Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: OhYee --- agentrun/tool/model.py | 24 ++++-- agentrun/toolset/model.py | 22 ++++-- tests/unittests/tool/test_model.py | 106 +++++++++++++++++++++++++- tests/unittests/toolset/test_model.py | 68 ++++++++++++++++- 4 files changed, 203 insertions(+), 17 deletions(-) diff --git a/agentrun/tool/model.py b/agentrun/tool/model.py index 0bdb001..971606f 100644 --- a/agentrun/tool/model.py +++ b/agentrun/tool/model.py @@ -5,7 +5,7 @@ """ from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from agentrun.utils.model import BaseModel @@ -194,8 +194,8 @@ class ToolSchema(BaseModel): required: Optional[List[str]] = None """必填字段 / Required fields""" - additional_properties: Optional[bool] = None - """是否允许额外属性 / Whether additional properties are allowed""" + additional_properties: Optional[Union[bool, "ToolSchema"]] = None + """额外属性约束 / Additional properties constraint""" items: Optional["ToolSchema"] = None """数组元素类型 / Array item type""" @@ -252,7 +252,7 @@ def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": 递归解析所有嵌套结构,保留完整的 schema 信息。 Recursively parses all nested structures, preserving complete schema information. """ - if not schema or not isinstance(schema, dict): + if schema is None or not isinstance(schema, dict): return cls(type="string") from pydash import get as pydash_get @@ -291,13 +291,21 @@ def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": else None ) + additional_properties_raw = pydash_get(schema, "additionalProperties") + if isinstance(additional_properties_raw, dict): + additional_properties = cls.from_any_openapi_schema( + additional_properties_raw + ) + else: + additional_properties = additional_properties_raw + return cls( type=pydash_get(schema, "type"), description=pydash_get(schema, "description"), title=pydash_get(schema, "title"), properties=properties, required=pydash_get(schema, "required"), - additional_properties=pydash_get(schema, "additionalProperties"), + additional_properties=additional_properties, items=items, min_items=pydash_get(schema, "minItems"), max_items=pydash_get(schema, "maxItems"), @@ -334,7 +342,11 @@ def to_json_schema(self) -> Dict[str, Any]: if self.required: result["required"] = self.required if self.additional_properties is not None: - result["additionalProperties"] = self.additional_properties + result["additionalProperties"] = ( + self.additional_properties.to_json_schema() + if isinstance(self.additional_properties, ToolSchema) + else self.additional_properties + ) if self.items: result["items"] = self.items.to_json_schema() diff --git a/agentrun/toolset/model.py b/agentrun/toolset/model.py index 29ddbe7..35ea2d4 100644 --- a/agentrun/toolset/model.py +++ b/agentrun/toolset/model.py @@ -5,7 +5,7 @@ """ from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from agentrun.utils.model import BaseModel, Field, PageableInput @@ -103,7 +103,7 @@ class ToolSchema(BaseModel): # 对象类型字段 properties: Optional[Dict[str, "ToolSchema"]] = None required: Optional[List[str]] = None - additional_properties: Optional[bool] = None + additional_properties: Optional[Union[bool, "ToolSchema"]] = None # 数组类型字段 items: Optional["ToolSchema"] = None @@ -137,7 +137,7 @@ def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": 递归解析所有嵌套结构,保留完整的 schema 信息。 """ - if not schema or not isinstance(schema, dict): + if schema is None or not isinstance(schema, dict): return cls(type="string") from pydash import get as pg @@ -179,6 +179,14 @@ def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": else None ) + additional_properties_raw = pg(schema, "additionalProperties") + if isinstance(additional_properties_raw, dict): + additional_properties = cls.from_any_openapi_schema( + additional_properties_raw + ) + else: + additional_properties = additional_properties_raw + return cls( # 基本字段 type=pg(schema, "type"), @@ -187,7 +195,7 @@ def from_any_openapi_schema(cls, schema: Any) -> "ToolSchema": # 对象类型 properties=properties, required=pg(schema, "required"), - additional_properties=pg(schema, "additionalProperties"), + additional_properties=additional_properties, # 数组类型 items=items, min_items=pg(schema, "minItems"), @@ -231,7 +239,11 @@ def to_json_schema(self) -> Dict[str, Any]: if self.required: result["required"] = self.required if self.additional_properties is not None: - result["additionalProperties"] = self.additional_properties + result["additionalProperties"] = ( + self.additional_properties.to_json_schema() + if isinstance(self.additional_properties, ToolSchema) + else self.additional_properties + ) # 数组类型 if self.items: diff --git a/tests/unittests/tool/test_model.py b/tests/unittests/tool/test_model.py index b195d0e..11682a5 100644 --- a/tests/unittests/tool/test_model.py +++ b/tests/unittests/tool/test_model.py @@ -265,10 +265,9 @@ def test_from_any_openapi_schema_empty(self): schema = ToolSchema.from_any_openapi_schema(None) assert schema.type == "string" - # Empty dict creates a ToolSchema with None type (pydash_get returns None for missing keys) + # Empty dict is a valid JSON Schema meaning "any" — should not collapse to "string" schema = ToolSchema.from_any_openapi_schema({}) - # Actually returns "string" due to the check at the beginning of the method - assert schema.type == "string" + assert schema.type is None def test_from_any_openapi_schema_non_dict(self): """测试从非 dict 输入创建""" @@ -335,6 +334,72 @@ def test_from_any_openapi_schema_allof(self): assert schema.all_of[0].type == "object" assert schema.all_of[1].type == "object" + def test_from_any_openapi_schema_additional_properties_schema(self): + """测试 additionalProperties 为 schema 对象时的解析""" + openapi_schema = { + "type": "object", + "properties": { + "filters": { + "type": "object", + "additionalProperties": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + }, + } + }, + } + + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + + assert schema.properties is not None + assert "filters" in schema.properties + filters_schema = schema.properties["filters"] + assert filters_schema.additional_properties is not None + assert filters_schema.additional_properties.any_of is not None + assert len(filters_schema.additional_properties.any_of) == 2 + assert filters_schema.additional_properties.any_of[0].type == "string" + assert filters_schema.additional_properties.any_of[1].type == "integer" + + json_schema = schema.to_json_schema() + assert ( + json_schema["properties"]["filters"]["additionalProperties"][ + "anyOf" + ][0]["type"] + == "string" + ) + assert ( + json_schema["properties"]["filters"]["additionalProperties"][ + "anyOf" + ][1]["type"] + == "integer" + ) + + def test_from_any_openapi_schema_empty_additional_properties_schema(self): + """测试 additionalProperties 为空 schema 时保留原语义""" + openapi_schema = { + "type": "object", + "properties": { + "metadata": { + "type": "object", + "additionalProperties": {}, + } + }, + } + + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + + assert schema.properties is not None + metadata_schema = schema.properties["metadata"] + assert metadata_schema.additional_properties is not None + assert metadata_schema.additional_properties.type is None + + json_schema = schema.to_json_schema() + assert ( + json_schema["properties"]["metadata"]["additionalProperties"] == {} + ) + def test_to_json_schema_simple(self): """测试转换为 JSON Schema - 简单情况""" schema = ToolSchema( @@ -659,3 +724,38 @@ def model_dump(self): assert "param1" in info.parameters.properties assert "param2" in info.parameters.properties assert info.parameters.required == ["param1"] + + def test_from_mcp_tool_with_schema_additional_properties(self): + """测试 MCP tool schema 中 additionalProperties 为对象时的解析""" + mcp_tool = { + "name": "get_news_by_date", + "description": "A tool with MCP schema-style date_range", + "inputSchema": { + "type": "object", + "properties": { + "date_range": { + "anyOf": [ + {"type": "string"}, + { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + {"type": "null"}, + ] + } + }, + }, + } + + info = ToolInfo.from_mcp_tool(mcp_tool) + + assert info.name == "get_news_by_date" + assert info.parameters is not None + assert info.parameters.properties is not None + date_range_schema = info.parameters.properties["date_range"] + assert date_range_schema.any_of is not None + assert date_range_schema.any_of[1].type == "object" + assert date_range_schema.any_of[1].additional_properties is not None + assert ( + date_range_schema.any_of[1].additional_properties.type == "string" + ) diff --git a/tests/unittests/toolset/test_model.py b/tests/unittests/toolset/test_model.py index d183587..5cd388a 100644 --- a/tests/unittests/toolset/test_model.py +++ b/tests/unittests/toolset/test_model.py @@ -419,10 +419,9 @@ def test_from_any_openapi_schema_none(self): assert schema.type == "string" def test_from_any_openapi_schema_empty_dict(self): - """测试从空字典创建""" + """测试从空字典创建 — 空 dict 是合法的 JSON Schema,表示 'any'""" schema = ToolSchema.from_any_openapi_schema({}) - # 空字典被视为 falsy,所以返回默认的 string 类型 - assert schema.type == "string" + assert schema.type is None def test_from_any_openapi_schema_non_dict(self): """测试从非字典 schema 创建""" @@ -448,6 +447,69 @@ def test_from_any_openapi_schema_with_properties(self): assert schema.required == ["name"] assert schema.additional_properties is False + def test_from_any_openapi_schema_with_schema_additional_properties(self): + """测试 additionalProperties 为 schema 对象时的解析""" + openapi_schema = { + "type": "object", + "properties": { + "filters": { + "type": "object", + "additionalProperties": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] + }, + } + }, + } + + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + + assert schema.properties is not None + filters_schema = schema.properties["filters"] + assert filters_schema.additional_properties is not None + assert filters_schema.additional_properties.any_of is not None + assert len(filters_schema.additional_properties.any_of) == 2 + + json_schema = schema.to_json_schema() + assert ( + json_schema["properties"]["filters"]["additionalProperties"][ + "anyOf" + ][0]["type"] + == "string" + ) + assert ( + json_schema["properties"]["filters"]["additionalProperties"][ + "anyOf" + ][1]["type"] + == "integer" + ) + + def test_from_any_openapi_schema_with_empty_additional_properties(self): + """测试 additionalProperties 为空 schema 时保留原语义""" + openapi_schema = { + "type": "object", + "properties": { + "metadata": { + "type": "object", + "additionalProperties": {}, + } + }, + } + + schema = ToolSchema.from_any_openapi_schema(openapi_schema) + + assert schema.properties is not None + metadata_schema = schema.properties["metadata"] + assert metadata_schema.additional_properties is not None + assert metadata_schema.additional_properties.type is None + + json_schema = schema.to_json_schema() + assert ( + json_schema["properties"]["metadata"]["additionalProperties"] == {} + ) + def test_from_any_openapi_schema_with_items(self): """测试从带 items 的数组 schema 创建""" openapi_schema = { From e9b7d2f75d5418876f1715edd4746052266f7a9c Mon Sep 17 00:00:00 2001 From: OhYee Date: Fri, 22 May 2026 11:17:38 +0800 Subject: [PATCH 2/2] style(credential, knowledgebase, memory_collection, model, sandbox): format code for better readability Change-Id: If7b4e36b9005765d075e7ac2e535a10c5beb4f37 Co-developed-by: Claude Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: OhYee --- .../credential/__credential_async_template.py | 4 ++- agentrun/credential/credential.py | 12 ++++++-- .../__knowledgebase_async_template.py | 4 ++- agentrun/knowledgebase/knowledgebase.py | 8 ++++-- .../__memory_collection_async_template.py | 4 ++- .../memory_collection/memory_collection.py | 4 ++- .../model/__model_proxy_async_template.py | 4 ++- .../model/__model_service_async_template.py | 4 ++- agentrun/model/model_proxy.py | 4 ++- agentrun/model/model_service.py | 4 ++- agentrun/sandbox/client.py | 6 ++-- agentrun/sandbox/template.py | 8 ++---- tests/e2e/test_workspace_id.py | 28 ++++++++----------- .../test_memory_collection.py | 12 ++++++-- tests/unittests/test_workspace_id.py | 9 ++---- 15 files changed, 65 insertions(+), 50 deletions(-) diff --git a/agentrun/credential/__credential_async_template.py b/agentrun/credential/__credential_async_template.py index c496b21..622153c 100644 --- a/agentrun/credential/__credential_async_template.py +++ b/agentrun/credential/__credential_async_template.py @@ -62,7 +62,9 @@ async def create_async( Returns: Credential: 创建的凭证对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod async def delete_by_name_async( diff --git a/agentrun/credential/credential.py b/agentrun/credential/credential.py index 3beb8c8..b904612 100644 --- a/agentrun/credential/credential.py +++ b/agentrun/credential/credential.py @@ -72,7 +72,9 @@ async def create_async( Returns: Credential: 创建的凭证对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod def create( @@ -113,7 +115,9 @@ def delete_by_name( credential_name: 凭证名称 config: 配置 """ - return cls.__get_client(config=config).delete(credential_name, config=config) + return cls.__get_client(config=config).delete( + credential_name, config=config + ) @classmethod async def update_by_name_async( @@ -185,7 +189,9 @@ def get_by_name(cls, credential_name: str, config: Optional[Config] = None): Returns: Credential: 凭证对象 """ - return cls.__get_client(config=config).get(credential_name, config=config) + return cls.__get_client(config=config).get( + credential_name, config=config + ) @classmethod async def _list_page_async( diff --git a/agentrun/knowledgebase/__knowledgebase_async_template.py b/agentrun/knowledgebase/__knowledgebase_async_template.py index 07d94a5..114496f 100644 --- a/agentrun/knowledgebase/__knowledgebase_async_template.py +++ b/agentrun/knowledgebase/__knowledgebase_async_template.py @@ -81,7 +81,9 @@ async def create_async( Returns: KnowledgeBase: 创建的知识库对象 / Created knowledge base object """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod async def delete_by_name_async( diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index a6db0e6..5b075a1 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -91,7 +91,9 @@ async def create_async( Returns: KnowledgeBase: 创建的知识库对象 / Created knowledge base object """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod def create( @@ -208,7 +210,9 @@ def get_by_name( Returns: KnowledgeBase: 知识库对象 / KnowledgeBase object """ - return cls.__get_client(config=config).get(knowledge_base_name, config=config) + return cls.__get_client(config=config).get( + knowledge_base_name, config=config + ) @classmethod async def _list_page_async( diff --git a/agentrun/memory_collection/__memory_collection_async_template.py b/agentrun/memory_collection/__memory_collection_async_template.py index d1cdacb..2823516 100644 --- a/agentrun/memory_collection/__memory_collection_async_template.py +++ b/agentrun/memory_collection/__memory_collection_async_template.py @@ -60,7 +60,9 @@ async def create_async( Returns: MemoryCollection: 创建的记忆集合对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod async def delete_by_name_async( diff --git a/agentrun/memory_collection/memory_collection.py b/agentrun/memory_collection/memory_collection.py index a0f9b7e..e12903e 100644 --- a/agentrun/memory_collection/memory_collection.py +++ b/agentrun/memory_collection/memory_collection.py @@ -70,7 +70,9 @@ async def create_async( Returns: MemoryCollection: 创建的记忆集合对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod def create( diff --git a/agentrun/model/__model_proxy_async_template.py b/agentrun/model/__model_proxy_async_template.py index 4c7b29b..35afe62 100644 --- a/agentrun/model/__model_proxy_async_template.py +++ b/agentrun/model/__model_proxy_async_template.py @@ -57,7 +57,9 @@ async def create_async( Returns: ModelProxy: 创建的模型服务对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod async def delete_by_name_async( diff --git a/agentrun/model/__model_service_async_template.py b/agentrun/model/__model_service_async_template.py index 053aa5c..72d838f 100644 --- a/agentrun/model/__model_service_async_template.py +++ b/agentrun/model/__model_service_async_template.py @@ -52,7 +52,9 @@ async def create_async( Returns: ModelService: 创建的模型服务对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod async def delete_by_name_async( diff --git a/agentrun/model/model_proxy.py b/agentrun/model/model_proxy.py index 846a5d6..d9e1642 100644 --- a/agentrun/model/model_proxy.py +++ b/agentrun/model/model_proxy.py @@ -67,7 +67,9 @@ async def create_async( Returns: ModelProxy: 创建的模型服务对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod def create( diff --git a/agentrun/model/model_service.py b/agentrun/model/model_service.py index 1a9568a..f04a8b6 100644 --- a/agentrun/model/model_service.py +++ b/agentrun/model/model_service.py @@ -62,7 +62,9 @@ async def create_async( Returns: ModelService: 创建的模型服务对象 """ - return await cls.__get_client(config=config).create_async(input, config=config) + return await cls.__get_client(config=config).create_async( + input, config=config + ) @classmethod def create( diff --git a/agentrun/sandbox/client.py b/agentrun/sandbox/client.py index e2bc63d..fefd01b 100644 --- a/agentrun/sandbox/client.py +++ b/agentrun/sandbox/client.py @@ -748,8 +748,7 @@ async def delete_sandbox_async( raise ClientError( status_code=0, message=( - "Failed to stop sandbox:" - f" {message or 'Unknown error'}" + f"Failed to stop sandbox: {message or 'Unknown error'}" ), ) @@ -805,8 +804,7 @@ def delete_sandbox( raise ClientError( status_code=0, message=( - "Failed to stop sandbox:" - f" {message or 'Unknown error'}" + f"Failed to stop sandbox: {message or 'Unknown error'}" ), ) diff --git a/agentrun/sandbox/template.py b/agentrun/sandbox/template.py index 93611c4..d95cdc0 100644 --- a/agentrun/sandbox/template.py +++ b/agentrun/sandbox/template.py @@ -118,9 +118,7 @@ async def create_async( ) @classmethod - def create( - cls, input: TemplateInput, config: Optional[Config] = None - ): + def create(cls, input: TemplateInput, config: Optional[Config] = None): return cls.__get_client(config=config).create_template( input, config=config ) @@ -172,9 +170,7 @@ async def get_by_name_async( ) @classmethod - def get_by_name( - cls, template_name: str, config: Optional[Config] = None - ): + def get_by_name(cls, template_name: str, config: Optional[Config] = None): return cls.__get_client(config=config).get_template( template_name=template_name, config=config ) diff --git a/tests/e2e/test_workspace_id.py b/tests/e2e/test_workspace_id.py index 752e7cd..ff14db3 100644 --- a/tests/e2e/test_workspace_id.py +++ b/tests/e2e/test_workspace_id.py @@ -32,18 +32,16 @@ CredentialListInput, ) from agentrun.sandbox import Template -from agentrun.sandbox.model import ( - PageableInput, - TemplateInput, - TemplateType, -) +from agentrun.sandbox.model import PageableInput, TemplateInput, TemplateType from agentrun.utils.exception import ResourceNotExistError WORKSPACE_ID = os.getenv("AGENTRUN_TEST_WORKSPACE_ID") pytestmark = pytest.mark.skipif( not WORKSPACE_ID, - reason="AGENTRUN_TEST_WORKSPACE_ID not configured; skipping workspace_id E2E", + reason=( + "AGENTRUN_TEST_WORKSPACE_ID not configured; skipping workspace_id E2E" + ), ) @@ -98,7 +96,8 @@ async def test_credential_with_workspace_id_async( ) names = [item.credential_name for item in list_results] assert credential_name in names, ( - f"list(workspace_id={ws!r}) 未返回刚创建的凭证 {credential_name!r}," + f"list(workspace_id={ws!r}) 未返回刚创建的凭证" + f" {credential_name!r}," f"实际返回 {names!r}" ) # 列表项的 workspace_id 也应该是同一个 @@ -112,9 +111,7 @@ async def test_credential_with_workspace_id_async( except ResourceNotExistError: pass - def test_credential_with_workspace_id( - self, credential_name: str - ): + def test_credential_with_workspace_id(self, credential_name: str): """凭证创建时指定 workspace_id,回读与列举均能拿到该 workspace_id""" client = CredentialClient() ws = WORKSPACE_ID # type: ignore[assignment] @@ -139,20 +136,17 @@ def test_credential_with_workspace_id( ), f"create 返回的 workspace_id 不匹配: {cred.workspace_id!r}" # 2. get 接口回读 workspace_id - cred_fetched = client.get( - credential_name=credential_name - ) + cred_fetched = client.get(credential_name=credential_name) assert ( cred_fetched.workspace_id == ws ), f"get 返回的 workspace_id 不匹配: {cred_fetched.workspace_id!r}" # 3. list 接口按 workspace_id 过滤,本次创建的资源应在结果中 - list_results = client.list( - CredentialListInput(workspace_id=ws) - ) + list_results = client.list(CredentialListInput(workspace_id=ws)) names = [item.credential_name for item in list_results] assert credential_name in names, ( - f"list(workspace_id={ws!r}) 未返回刚创建的凭证 {credential_name!r}," + f"list(workspace_id={ws!r}) 未返回刚创建的凭证" + f" {credential_name!r}," f"实际返回 {names!r}" ) # 列表项的 workspace_id 也应该是同一个 diff --git a/tests/unittests/memory_collection/test_memory_collection.py b/tests/unittests/memory_collection/test_memory_collection.py index 70b8894..1c90517 100644 --- a/tests/unittests/memory_collection/test_memory_collection.py +++ b/tests/unittests/memory_collection/test_memory_collection.py @@ -630,7 +630,9 @@ def test_build_mem0_config_with_mysql_sync(self, mock_get_credential): assert vs_config["port"] == 3307 assert vs_config["embedding_model_dims"] == 1024 - @patch("agentrun.memory_collection.memory_collection.MemoryCollection._resolve_model_service_config") + @patch( + "agentrun.memory_collection.memory_collection.MemoryCollection._resolve_model_service_config" + ) @patch("agentrun.credential.Credential.get_by_name") def test_build_mem0_config_mysql_embedder_dims_sync( self, mock_get_credential, mock_resolve @@ -660,10 +662,14 @@ def test_build_mem0_config_mysql_embedder_dims_sync( config=EmbedderConfigConfig(model="text-embedding-v3"), ), ) - config = MemoryCollection._build_mem0_config(memory_collection, None, None) + config = MemoryCollection._build_mem0_config( + memory_collection, None, None + ) assert config["embedder"]["config"]["embedding_dims"] == 1024 - @patch("agentrun.memory_collection.memory_collection.MemoryCollection._resolve_model_service_config_async") + @patch( + "agentrun.memory_collection.memory_collection.MemoryCollection._resolve_model_service_config_async" + ) @patch("agentrun.credential.Credential.get_by_name_async") @pytest.mark.asyncio async def test_build_mem0_config_mysql_embedder_dims_async( diff --git a/tests/unittests/test_workspace_id.py b/tests/unittests/test_workspace_id.py index 3795640..e6e311f 100644 --- a/tests/unittests/test_workspace_id.py +++ b/tests/unittests/test_workspace_id.py @@ -36,13 +36,8 @@ ModelServiceCreateInput, ModelServiceListInput, ) -from agentrun.sandbox.model import ( - PageableInput as SandboxPageableInput, -) -from agentrun.sandbox.model import ( - TemplateInput, - TemplateType, -) +from agentrun.sandbox.model import PageableInput as SandboxPageableInput +from agentrun.sandbox.model import TemplateInput, TemplateType from agentrun.utils.model import BaseModel WORKSPACE_ID = "ws-test-12345"