Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ async def _recv_loop(self) -> None:
continue

data = json.loads(msg.data)
context_id = data.get("contextId")
# ElevenLabs currently sends snake_case context IDs on the websocket API,
# while older responses and some examples use camelCase.
context_id = data.get("contextId") or data.get("context_id")
ctx = self._context_data.get(context_id) if context_id is not None else None

if error := data.get("error"):
Expand All @@ -733,6 +735,13 @@ async def _recv_loop(self) -> None:
continue

if ctx is None:
if data.get("type") == "flush_done":
logger.debug(
"ignoring elevenlabs flush_done message for inactive context",
extra={"context_id": context_id, "data": data},
)
continue

logger.warning(
"unexpected message received from elevenlabs tts", extra={"data": data}
)
Expand Down
193 changes: 192 additions & 1 deletion tests/test_plugin_elevenlabs_tts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,81 @@
"""Unit tests for ElevenLabs TTS plugin configuration behavior."""
"""Unit tests for ElevenLabs TTS plugin configuration and websocket behavior."""

import asyncio
import base64
import json
from types import SimpleNamespace

import aiohttp
import pytest

from livekit.plugins.elevenlabs import tts as elevenlabs_tts


class _FakeWebSocket:
def __init__(self, messages: list[object]) -> None:
self._messages = messages
self.closed = False

async def receive(self) -> object:
if self._messages:
return self._messages.pop(0)
return SimpleNamespace(type=aiohttp.WSMsgType.CLOSE, data="")

async def close(self) -> None:
self.closed = True


class _FakeEmitter:
def __init__(self) -> None:
self.audio_chunks: list[bytes] = []
self.timed_transcript_pushes = 0

def push(self, audio: bytes) -> None:
self.audio_chunks.append(audio)

def push_timed_transcript(self, _timed_words: object) -> None:
self.timed_transcript_pushes += 1


class _FakeStream:
def __init__(self) -> None:
self._text_buffer = ""
self._start_times_ms: list[int] = []
self._durations_ms: list[int] = []


class _FakeConnection:
    """Stand-in passed as ``self`` to ``elevenlabs_tts._Connection._recv_loop``.

    It registers a single context (``context_id``) backed by a fake emitter,
    a fake stream and a waiter future, and serves ``messages`` through a fake
    websocket.

    Must be constructed while an event loop is running (e.g. inside an
    ``async def`` test), because the waiter future is created on the running
    loop.
    """

    def __init__(self, context_id: str, messages: list[object]) -> None:
        self._closed = False
        self._ws = _FakeWebSocket(messages)
        self._is_current = True
        self._active_contexts = {context_id}
        self.emitter = _FakeEmitter()
        # get_running_loop() binds the waiter to the loop that will actually
        # run the test; get_event_loop() is deprecated when no loop is running
        # and could silently create or pick a different loop.
        self.waiter: asyncio.Future[None] = asyncio.get_running_loop().create_future()
        self._context_data = {
            context_id: elevenlabs_tts._StreamData(
                emitter=self.emitter,
                stream=_FakeStream(),
                waiter=self.waiter,
            )
        }
        self.preferred_alignment = "normalized"

    def _cleanup_context(self, context_id: str) -> None:
        """Forget *context_id*, cancelling its timeout timer if one is set."""
        ctx = self._context_data.pop(context_id, None)
        if ctx and ctx.timeout_timer:
            ctx.timeout_timer.cancel()
        self._active_contexts.discard(context_id)

    async def aclose(self) -> None:
        """Mark the connection closed and close the fake websocket."""
        self._closed = True
        await self._ws.close()


def _websocket_text_message(payload: dict[str, object]) -> object:
    """Build a fake TEXT websocket frame whose ``data`` is *payload* encoded as JSON."""
    return SimpleNamespace(type=aiohttp.WSMsgType.TEXT, data=json.dumps(payload))


def test_auto_mode_defaults_to_true_without_chunk_length_schedule() -> None:
tts = elevenlabs_tts.TTS(api_key="test-key")
assert tts._opts.auto_mode is True
Expand Down Expand Up @@ -62,3 +135,121 @@ def test_build_context_init_packet_includes_pronunciation_dictionaries() -> None
"version_id": "v1",
}
]


@pytest.mark.asyncio
async def test_recv_loop_accepts_snake_case_context_id() -> None:
    """A final audio frame keyed by ``context_id`` reaches the registered context.

    Expects the decoded audio to be pushed to the emitter, the waiter to be
    resolved with ``None``, and the context to be removed afterwards.
    """
    context_id = "ctx_123"
    audio_chunk = b"hello-audio"
    connection = _FakeConnection(
        context_id,
        [
            _websocket_text_message(
                {
                    "context_id": context_id,
                    "audio": base64.b64encode(audio_chunk).decode("ascii"),
                    "isFinal": True,
                }
            ),
        ],
    )

    # _recv_loop is called unbound so the fake connection can act as `self`.
    await elevenlabs_tts._Connection._recv_loop(connection)

    assert connection.emitter.audio_chunks == [audio_chunk]
    assert connection.waiter.done()
    assert connection.waiter.result() is None
    assert connection._context_data == {}


@pytest.mark.asyncio
async def test_recv_loop_still_accepts_camel_case_context_id() -> None:
    """A final audio frame keyed by ``contextId`` reaches the registered context.

    Expects the decoded audio to be pushed to the emitter, the waiter to be
    resolved with ``None``, and the context to be removed afterwards.
    """
    context_id = "ctx_123"
    audio_chunk = b"hello-audio"
    connection = _FakeConnection(
        context_id,
        [
            _websocket_text_message(
                {
                    "contextId": context_id,
                    "audio": base64.b64encode(audio_chunk).decode("ascii"),
                    "isFinal": True,
                }
            ),
        ],
    )

    # _recv_loop is called unbound so the fake connection can act as `self`.
    await elevenlabs_tts._Connection._recv_loop(connection)

    assert connection.emitter.audio_chunks == [audio_chunk]
    assert connection.waiter.done()
    assert connection.waiter.result() is None
    assert connection._context_data == {}


@pytest.mark.asyncio
async def test_recv_loop_ignores_flush_done_for_active_context() -> None:
    """A ``flush_done`` frame for the active context must not disturb delivery.

    The frame is followed by a final audio frame for the same context; only
    that audio may reach the emitter and the waiter must resolve with ``None``.
    """
    context_id = "ctx_123"
    audio_chunk = b"hello-audio"
    connection = _FakeConnection(
        context_id,
        [
            _websocket_text_message(
                {
                    "type": "flush_done",
                    "context_id": context_id,
                    "status_code": 206,
                    "done": False,
                    "data": "",
                    "flush_done": True,
                }
            ),
            _websocket_text_message(
                {
                    "context_id": context_id,
                    "audio": base64.b64encode(audio_chunk).decode("ascii"),
                    "isFinal": True,
                }
            ),
        ],
    )

    # _recv_loop is called unbound so the fake connection can act as `self`.
    await elevenlabs_tts._Connection._recv_loop(connection)

    assert connection.emitter.audio_chunks == [audio_chunk]
    assert connection.waiter.done()
    assert connection.waiter.result() is None


@pytest.mark.asyncio
async def test_recv_loop_ignores_flush_done_for_inactive_context() -> None:
    """A ``flush_done`` frame for an unregistered context must not disturb delivery.

    The frame names a context that is not registered on the connection; the
    following final audio frame for the registered context must still reach
    the emitter and the waiter must resolve with ``None``.
    """
    context_id = "ctx_123"
    audio_chunk = b"hello-audio"
    connection = _FakeConnection(
        context_id,
        [
            _websocket_text_message(
                {
                    "type": "flush_done",
                    "context_id": "already_closed_context",
                    "status_code": 206,
                    "done": False,
                    "data": "",
                    "flush_done": True,
                }
            ),
            _websocket_text_message(
                {
                    "context_id": context_id,
                    "audio": base64.b64encode(audio_chunk).decode("ascii"),
                    "isFinal": True,
                }
            ),
        ],
    )

    # _recv_loop is called unbound so the fake connection can act as `self`.
    await elevenlabs_tts._Connection._recv_loop(connection)

    assert connection.emitter.audio_chunks == [audio_chunk]
    assert connection.waiter.done()
    assert connection.waiter.result() is None
Loading