Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions flocks/provider/sdk/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Azure OpenAI provider implementation
"""

from typing import List, AsyncIterator
from typing import Any, Dict, List, AsyncIterator
import os

from flocks.provider.provider import (
Expand Down Expand Up @@ -210,11 +210,62 @@ async def chat_stream(
else:
raise

tool_calls: Dict[int, Dict[str, Any]] = {}

async for chunk in stream:
if chunk.choices:
choice = chunk.choices[0]
if choice.delta.content:
if not chunk.choices:
continue

choice = chunk.choices[0]
delta = choice.delta

if delta is None:
if choice.finish_reason:
if tool_calls:
yield StreamChunk(
delta="",
finish_reason="tool_calls",
tool_calls=[tool_calls[i] for i in sorted(tool_calls.keys())],
)
else:
yield StreamChunk(delta="", finish_reason=choice.finish_reason)
continue

delta_text = getattr(delta, "content", None)
if delta_text:
yield StreamChunk(delta=delta_text, finish_reason=None)

delta_tool_calls = getattr(delta, "tool_calls", None)
if delta_tool_calls:
for tool_call_delta in delta_tool_calls:
index = getattr(tool_call_delta, "index", 0)
if index not in tool_calls:
tool_calls[index] = {
"id": getattr(tool_call_delta, "id", None) or "",
"type": "function",
"function": {"name": "", "arguments": ""},
}

tool_call_id = getattr(tool_call_delta, "id", None)
if tool_call_id:
tool_calls[index]["id"] = tool_call_id

function_delta = getattr(tool_call_delta, "function", None)
if function_delta:
function_name = getattr(function_delta, "name", None)
if function_name:
tool_calls[index]["function"]["name"] = function_name

function_arguments = getattr(function_delta, "arguments", None)
if function_arguments:
tool_calls[index]["function"]["arguments"] += function_arguments

if choice.finish_reason:
if tool_calls:
yield StreamChunk(
delta=choice.delta.content,
finish_reason=choice.finish_reason,
delta="",
finish_reason="tool_calls",
tool_calls=[tool_calls[i] for i in sorted(tool_calls.keys())],
)
else:
yield StreamChunk(delta="", finish_reason=choice.finish_reason)
149 changes: 149 additions & 0 deletions tests/provider/test_azure_provider.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,63 @@
from types import SimpleNamespace

import pytest

from flocks.provider.provider import ChatMessage
from flocks.provider.provider import ModelCapabilities, ModelInfo
from flocks.provider.sdk.azure import AzureProvider


class _FakeAzureStream:
def __init__(self, chunks):
self._chunks = chunks

def __aiter__(self):
self._index = 0
return self

async def __anext__(self):
if self._index >= len(self._chunks):
raise StopAsyncIteration
chunk = self._chunks[self._index]
self._index += 1
return chunk


class _FakeAzureCompletions:
def __init__(self, chunks):
self._chunks = chunks
self.last_request = None

async def create(self, **kwargs):
self.last_request = kwargs
return _FakeAzureStream(self._chunks)


class _FakeAzureClient:
def __init__(self, chunks):
self.completions = _FakeAzureCompletions(chunks)
self.chat = SimpleNamespace(completions=self.completions)


def _chunk(delta=None, finish_reason=None):
return SimpleNamespace(
choices=[
SimpleNamespace(
delta=delta,
finish_reason=finish_reason,
)
]
)


def _tool_call_delta(index=0, call_id=None, name=None, arguments=None):
return SimpleNamespace(
index=index,
id=call_id,
function=SimpleNamespace(name=name, arguments=arguments),
)


def test_azure_provider_returns_configured_deployment_models():
provider = AzureProvider()
provider._config_models = [
Expand Down Expand Up @@ -31,3 +87,96 @@ def test_azure_provider_returns_fallback_models_without_config():

assert {m.id for m in models} == {"gpt-5.4", "gpt-5-mini"}
assert all(m.provider_id == "azure" for m in models)


@pytest.mark.asyncio
async def test_azure_chat_stream_emits_tool_calls():
chunks = [
_chunk(
delta=SimpleNamespace(
content=None,
tool_calls=[
_tool_call_delta(
index=0,
call_id="call_1",
name="delegate_task",
arguments='{"subagent_type":"explore",',
)
],
)
),
_chunk(
delta=SimpleNamespace(
content=None,
tool_calls=[
_tool_call_delta(
index=0,
arguments='"prompt":"say ok"}',
)
],
)
),
_chunk(
delta=SimpleNamespace(content=None, tool_calls=None),
finish_reason="tool_calls",
),
]
client = _FakeAzureClient(chunks)
provider = AzureProvider()
provider._get_client = lambda: client

emitted = [
chunk
async for chunk in provider.chat_stream(
model_id="gpt-5.4-mini",
messages=[ChatMessage(role="user", content="call a sub agent")],
tools=[
{
"type": "function",
"function": {
"name": "delegate_task",
"description": "delegate work",
"parameters": {"type": "object", "properties": {}},
},
}
],
)
]

assert client.completions.last_request["tools"]
assert len(emitted) == 1
assert emitted[0].finish_reason == "tool_calls"
assert emitted[0].tool_calls == [
{
"id": "call_1",
"type": "function",
"function": {
"name": "delegate_task",
"arguments": '{"subagent_type":"explore","prompt":"say ok"}',
},
}
]


@pytest.mark.asyncio
async def test_azure_chat_stream_still_emits_text_chunks():
client = _FakeAzureClient([
_chunk(delta=SimpleNamespace(content="hello", tool_calls=None)),
_chunk(
delta=SimpleNamespace(content=None, tool_calls=None),
finish_reason="stop",
),
])
provider = AzureProvider()
provider._get_client = lambda: client

emitted = [
chunk
async for chunk in provider.chat_stream(
model_id="gpt-5.4-mini",
messages=[ChatMessage(role="user", content="hi")],
)
]

assert [chunk.delta for chunk in emitted] == ["hello", ""]
assert emitted[-1].finish_reason == "stop"