From 221f83f90f00efa5900aaf6bdf6c048e8e3c0ded Mon Sep 17 00:00:00 2001 From: John Yin <10972267+john-yin2333@user.noreply.gitee.com> Date: Thu, 21 May 2026 15:57:40 +0800 Subject: [PATCH] fix(provider): preserve Azure streaming tool calls Co-authored-by: Cursor --- flocks/provider/sdk/azure.py | 63 +++++++++-- tests/provider/test_azure_provider.py | 149 ++++++++++++++++++++++++++ 2 files changed, 206 insertions(+), 6 deletions(-) diff --git a/flocks/provider/sdk/azure.py b/flocks/provider/sdk/azure.py index 752c6e660..a555479ab 100644 --- a/flocks/provider/sdk/azure.py +++ b/flocks/provider/sdk/azure.py @@ -2,7 +2,7 @@ Azure OpenAI provider implementation """ -from typing import List, AsyncIterator +from typing import Any, Dict, List, AsyncIterator import os from flocks.provider.provider import ( @@ -210,11 +210,62 @@ async def chat_stream( else: raise + tool_calls: Dict[int, Dict[str, Any]] = {} + async for chunk in stream: - if chunk.choices: - choice = chunk.choices[0] - if choice.delta.content: + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + + if delta is None: + if choice.finish_reason: + if tool_calls: + yield StreamChunk( + delta="", + finish_reason="tool_calls", + tool_calls=[tool_calls[i] for i in sorted(tool_calls.keys())], + ) + else: + yield StreamChunk(delta="", finish_reason=choice.finish_reason) + continue + + delta_text = getattr(delta, "content", None) + if delta_text: + yield StreamChunk(delta=delta_text, finish_reason=None) + + delta_tool_calls = getattr(delta, "tool_calls", None) + if delta_tool_calls: + for tool_call_delta in delta_tool_calls: + index = getattr(tool_call_delta, "index", 0) + if index not in tool_calls: + tool_calls[index] = { + "id": getattr(tool_call_delta, "id", None) or "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + + tool_call_id = getattr(tool_call_delta, "id", None) + if tool_call_id: + tool_calls[index]["id"] = tool_call_id + + function_delta = getattr(tool_call_delta, "function", None) + if function_delta: + function_name = getattr(function_delta, "name", None) + if function_name: + tool_calls[index]["function"]["name"] = function_name + + function_arguments = getattr(function_delta, "arguments", None) + if function_arguments: + tool_calls[index]["function"]["arguments"] += function_arguments + + if choice.finish_reason: + if tool_calls: yield StreamChunk( - delta=choice.delta.content, - finish_reason=choice.finish_reason, + delta="", + finish_reason="tool_calls", + tool_calls=[tool_calls[i] for i in sorted(tool_calls.keys())], ) + else: + yield StreamChunk(delta="", finish_reason=choice.finish_reason) diff --git a/tests/provider/test_azure_provider.py b/tests/provider/test_azure_provider.py index d23d426db..286cb7436 100644 --- a/tests/provider/test_azure_provider.py +++ b/tests/provider/test_azure_provider.py @@ -1,7 +1,63 @@ +from types import SimpleNamespace + +import pytest + +from flocks.provider.provider import ChatMessage from flocks.provider.provider import ModelCapabilities, ModelInfo from flocks.provider.sdk.azure import AzureProvider +class _FakeAzureStream: + def __init__(self, chunks): + self._chunks = chunks + + def __aiter__(self): + self._index = 0 + return self + + async def __anext__(self): + if self._index >= len(self._chunks): + raise StopAsyncIteration + chunk = self._chunks[self._index] + self._index += 1 + return chunk + + +class _FakeAzureCompletions: + def __init__(self, chunks): + self._chunks = chunks + self.last_request = None + + async def create(self, **kwargs): + self.last_request = kwargs + return _FakeAzureStream(self._chunks) + + +class _FakeAzureClient: + def __init__(self, chunks): + self.completions = _FakeAzureCompletions(chunks) + self.chat = SimpleNamespace(completions=self.completions) + + +def _chunk(delta=None, finish_reason=None): + return SimpleNamespace( + choices=[ + SimpleNamespace( + delta=delta, + finish_reason=finish_reason, + ) + ] + ) + + +def _tool_call_delta(index=0, call_id=None, name=None, arguments=None): + return SimpleNamespace( + index=index, + id=call_id, + function=SimpleNamespace(name=name, arguments=arguments), + ) + + def test_azure_provider_returns_configured_deployment_models(): provider = AzureProvider() provider._config_models = [ @@ -31,3 +87,96 @@ def test_azure_provider_returns_fallback_models_without_config(): assert {m.id for m in models} == {"gpt-5.4", "gpt-5-mini"} assert all(m.provider_id == "azure" for m in models) + + +@pytest.mark.asyncio +async def test_azure_chat_stream_emits_tool_calls(): + chunks = [ + _chunk( + delta=SimpleNamespace( + content=None, + tool_calls=[ + _tool_call_delta( + index=0, + call_id="call_1", + name="delegate_task", + arguments='{"subagent_type":"explore",', + ) + ], + ) + ), + _chunk( + delta=SimpleNamespace( + content=None, + tool_calls=[ + _tool_call_delta( + index=0, + arguments='"prompt":"say ok"}', + ) + ], + ) + ), + _chunk( + delta=SimpleNamespace(content=None, tool_calls=None), + finish_reason="tool_calls", + ), + ] + client = _FakeAzureClient(chunks) + provider = AzureProvider() + provider._get_client = lambda: client + + emitted = [ + chunk + async for chunk in provider.chat_stream( + model_id="gpt-5.4-mini", + messages=[ChatMessage(role="user", content="call a sub agent")], + tools=[ + { + "type": "function", + "function": { + "name": "delegate_task", + "description": "delegate work", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + ) + ] + + assert client.completions.last_request["tools"] + assert len(emitted) == 1 + assert emitted[0].finish_reason == "tool_calls" + assert emitted[0].tool_calls == [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "delegate_task", + "arguments": '{"subagent_type":"explore","prompt":"say ok"}', + }, + } + ] + + +@pytest.mark.asyncio +async def test_azure_chat_stream_still_emits_text_chunks(): + client = _FakeAzureClient([ + _chunk(delta=SimpleNamespace(content="hello", tool_calls=None)), + _chunk( + delta=SimpleNamespace(content=None, tool_calls=None), + finish_reason="stop", + ), + ]) + provider = AzureProvider() + provider._get_client = lambda: client + + emitted = [ + chunk + async for chunk in provider.chat_stream( + model_id="gpt-5.4-mini", + messages=[ChatMessage(role="user", content="hi")], + ) + ] + + assert [chunk.delta for chunk in emitted] == ["hello", ""] + assert emitted[-1].finish_reason == "stop"