Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def __init__(

self._request_id = ""
self._reconnect_event = asyncio.Event()
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._configure_task: asyncio.Task[None] | None = None

def update_options(
self,
Expand All @@ -313,6 +315,20 @@ def update_options(
# deprecated
keyterms: NotGivenOr[list[str]] = NOT_GIVEN,
) -> None:
if is_given(keyterms):
logger.warning(
"`keyterms` is deprecated, use `keyterm` instead for consistency with Deepgram API."
)
keyterm = keyterms

requires_reconnect = (
is_given(model)
or is_given(sample_rate)
or is_given(mip_opt_out)
or is_given(endpoint_url)
or is_given(tags)
)

if is_given(model):
self._opts.model = model
if is_given(sample_rate):
Expand All @@ -321,11 +337,6 @@ def update_options(
self._opts.eot_threshold = eot_threshold
if is_given(eot_timeout_ms):
self._opts.eot_timeout_ms = eot_timeout_ms
if is_given(keyterms):
logger.warning(
"`keyterms` is deprecated, use `keyterm` instead for consistency with Deepgram API."
)
keyterm = keyterms
if is_given(keyterm):
self._opts.keyterm = keyterm
if is_given(mip_opt_out):
Expand All @@ -339,6 +350,62 @@ def update_options(
if is_given(eager_eot_threshold):
self._opts.eager_eot_threshold = eager_eot_threshold

if requires_reconnect:
self._reconnect_event.set()
elif self._ws is not None and not self._ws.closed:
self._send_configure(
keyterm=keyterm,
eot_threshold=eot_threshold,
eot_timeout_ms=eot_timeout_ms,
eager_eot_threshold=eager_eot_threshold,
language_hint=language_hint,
)
else:
self._reconnect_event.set()

def _send_configure(
self,
*,
keyterm: NotGivenOr[str | list[str]] = NOT_GIVEN,
eot_threshold: NotGivenOr[float] = NOT_GIVEN,
eot_timeout_ms: NotGivenOr[int] = NOT_GIVEN,
eager_eot_threshold: NotGivenOr[float] = NOT_GIVEN,
language_hint: NotGivenOr[list[str]] = NOT_GIVEN,
) -> None:
"""Send a Configure control message to update settings mid-stream without reconnecting."""
configure_msg: dict[str, Any] = {"type": "Configure"}

if is_given(keyterm):
terms = [keyterm] if isinstance(keyterm, str) else list(keyterm)
configure_msg["keyterms"] = terms

thresholds: dict[str, Any] = {}
if is_given(eot_threshold):
thresholds["eot_threshold"] = eot_threshold
if is_given(eot_timeout_ms):
thresholds["eot_timeout_ms"] = eot_timeout_ms
if is_given(eager_eot_threshold):
thresholds["eager_eot_threshold"] = eager_eot_threshold
if thresholds:
configure_msg["thresholds"] = thresholds

if is_given(language_hint):
configure_msg["language_hints"] = language_hint

if len(configure_msg) <= 1:
return

self._configure_task = asyncio.create_task(
self._do_send_configure(json.dumps(configure_msg))
)

async def _do_send_configure(self, msg_str: str) -> None:
try:
if self._ws is not None and not self._ws.closed:
await self._ws.send_str(msg_str)
return
except Exception:
logger.debug("failed to send Configure message, falling back to reconnect")
self._reconnect_event.set()

async def _run(self) -> None:
Expand Down Expand Up @@ -424,6 +491,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
while True:
try:
ws = await self._connect_ws()
self._ws = ws
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
Expand Down Expand Up @@ -451,6 +519,10 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
tasks_group.cancel()
tasks_group.exception() # retrieve the exception
finally:
self._ws = None
if self._configure_task is not None and not self._configure_task.done():
self._configure_task.cancel()
self._configure_task = None
if ws is not None:
await ws.close()

Expand Down Expand Up @@ -568,6 +640,13 @@ def _process_stream_event(self, data: dict) -> None:
end_event = stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
self._event_ch.send_nowait(end_event)

elif data["type"] == "ConfigureSuccess":
logger.debug("Configure message applied", extra={"data": data})

elif data["type"] == "ConfigureFailure":
logger.warning("Configure message rejected by Deepgram", extra={"data": data})
self._reconnect_event.set()

elif data["type"] == "Error":
logger.warning("deepgram sent an error", extra={"data": data})
desc = data.get("description") or "unknown error from deepgram"
Expand Down
Loading