Skip to content

Commit a607727

Browse files
committed
sdk/python: Simplify worker and PolosClient into a single Polos instance
1 parent acb8450 commit a607727

File tree

5 files changed

+352
-5
lines changed

5 files changed

+352
-5
lines changed

sdk/python/polos/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from .features.schedules import SchedulePayload
5353
from .middleware.guardrail import GuardrailContext, GuardrailResult, guardrail
5454
from .middleware.hook import HookAction, HookContext, HookResult, hook
55+
from .polos import Polos
5556
from .runtime.batch import batch_agent_invoke, batch_invoke
5657
from .runtime.client import ExecutionHandle, PolosClient
5758
from .runtime.queue import Queue, queue
@@ -78,6 +79,7 @@
7879
)
7980

8081
__all__ = [
82+
"Polos",
8183
"workflow",
8284
"Workflow",
8385
"get_workflow",

sdk/python/polos/features/tracing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,27 @@ def initialize_otel():
470470
_tracer = trace.NoOpTracer()
471471

472472

473+
def shutdown_otel():
    """Flush pending spans and shut down the OpenTelemetry tracer provider.

    Must be called before the worker context is torn down so that the
    DatabaseSpanExporter can still reach the PolosClient.
    """
    global _tracer_provider, _tracer

    if _tracer_provider is None:
        return

    try:
        # Flush first so buffered spans are exported while the client is
        # still reachable.
        try:
            _tracer_provider.force_flush(timeout_millis=5000)
        finally:
            # Fix: previously a failed force_flush() skipped shutdown(),
            # leaving the provider's exporter worker threads running.
            # Always attempt shutdown, even when the flush raised.
            _tracer_provider.shutdown()
    except Exception as e:
        logger.warning(f"Error during OpenTelemetry shutdown: {e}")
    finally:
        # Drop references regardless of outcome so a later initialize_otel()
        # starts from a clean slate.
        _tracer_provider = None
        _tracer = None
492+
493+
473494
def get_current_span() -> Span | None:
474495
"""Get the current active span."""
475496
if not OTELEMETRY_AVAILABLE:

sdk/python/polos/polos.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
"""Unified Polos class that combines PolosClient and Worker into a single object."""
2+
3+
import asyncio
4+
import logging
5+
import os
6+
from typing import Any
7+
8+
from .agents.agent import Agent
9+
from .core.workflow import get_all_workflows
10+
from .runtime.client import ExecutionHandle, PolosClient
11+
from .runtime.worker import Worker
12+
from .tools.tool import Tool
13+
from .types.types import BatchWorkflowInput
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def _configure_file_logging(log_file: str) -> None:
    """Redirect all SDK logs to a file instead of stdout/stderr."""
    file_handler = logging.FileHandler(log_file, mode="a")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)-8s %(name)s: %(message)s")
    )

    # Attach to the root logger so every logger in the process
    # (polos.*, httpx, etc.) inherits the file destination instead
    # of writing to the console.
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(file_handler)
    root_logger.setLevel(logging.INFO)
28+
29+
30+
class Polos:
    """Unified Polos client + worker.

    Combines PolosClient (submit/stream work) and Worker (receive/execute work)
    into a single object. One process, one import.

    For production deployments where you need to scale workers independently
    from clients, use the separate PolosClient and Worker classes directly.

    Usage::

        from polos import Polos, Agent, tool

        @tool(description="Get weather")
        async def get_weather(ctx, input):
            return {"temp": 72}

        weather_agent = Agent(
            id="weather", provider="openai", model="gpt-4o",
            system_prompt="You are a weather assistant.",
            tools=[get_weather],
        )

        # Script mode: context manager
        async with Polos() as polos:
            result = await weather_agent.run(polos, "What's the weather in Paris?")
            print(result)

        # Script mode: start/stop
        polos = Polos()
        await polos.start()
        result = await weather_agent.run(polos, "What's the weather?")
        await polos.stop()

        # Server mode: blocks until SIGINT/SIGTERM
        polos = Polos()
        await polos.serve()
    """

    def __init__(
        self,
        project_id: str | None = None,
        api_url: str | None = None,
        api_key: str | None = None,
        deployment_id: str | None = None,
        port: int | None = None,
        max_concurrent_workflows: int | None = None,
        channels: list | None = None,
        log_file: str | None = None,
    ):
        """Build the client/worker pair. Explicit arguments win over env vars.

        Args:
            project_id: Polos project id; falls back to POLOS_PROJECT_ID.
            api_url: Orchestrator URL; falls back to POLOS_API_URL,
                then "http://localhost:8080".
            api_key: API key; falls back to POLOS_API_KEY.
            deployment_id: Deployment name; falls back to POLOS_DEPLOYMENT_ID,
                then "default".
            port: Local worker server port; falls back to POLOS_WORKER_PORT,
                then 8000.
            max_concurrent_workflows: Cap on concurrently executing workflows.
            channels: Optional channel list.
            log_file: If set, redirect all SDK logging to this file.
        """
        self._project_id = project_id or os.getenv("POLOS_PROJECT_ID")
        self._api_url = api_url or os.getenv("POLOS_API_URL", "http://localhost:8080")
        self._api_key = api_key or os.getenv("POLOS_API_KEY")
        self._deployment_id = deployment_id or os.getenv("POLOS_DEPLOYMENT_ID", "default")
        # Fix: the previous `port or os.getenv(...)` never consulted the env
        # var (the old default of 8000 is truthy), and when it did (port=0)
        # os.getenv returned a str, which would crash the "%d" log format in
        # start(). Honor POLOS_WORKER_PORT only when the caller did not pass
        # a port, and always coerce to int.
        self._port = int(port if port is not None else os.getenv("POLOS_WORKER_PORT", "8000"))
        # NOTE(review): _channels is stored but not forwarded to Worker —
        # confirm whether that is intended.
        self._channels = channels
        self._log_file = log_file

        # Redirect SDK logs to file if requested
        if log_file:
            _configure_file_logging(log_file)

        # Create client
        self._client = PolosClient(
            project_id=self._project_id,
            api_url=self._api_url,
            api_key=self._api_key,
            deployment_id=self._deployment_id,
        )

        # Discover agents, tools, and workflows from the global WORKFLOW_REGISTRY.
        # They auto-register when defined (Agent(), Tool(), @workflow), so no need
        # to pass them explicitly.
        all_workflows = get_all_workflows()
        agents = [w for w in all_workflows.values() if isinstance(w, Agent)]
        tools = [w for w in all_workflows.values() if isinstance(w, Tool)]
        workflows = [w for w in all_workflows.values() if not isinstance(w, (Agent, Tool))]

        # Create worker
        self._worker = Worker(
            client=self._client,
            deployment_id=self._deployment_id,
            agents=agents,
            tools=tools,
            workflows=workflows,
            max_concurrent_workflows=max_concurrent_workflows,
            mode="push",
            worker_server_url=f"http://localhost:{self._port}",
            log_file=log_file,
        )

        # Track state
        self._started = False
        self._server_task: asyncio.Task | None = None

    # ── Lifecycle ──

    async def start(self):
        """Start the worker in background (non-blocking).

        Registers with orchestrator, starts FastAPI server, begins heartbeat.
        Returns once registration is complete so the caller can immediately
        invoke workflows. Idempotent: a second call is a no-op.
        """
        if self._started:
            return

        # Phase 1: register (blocking — must complete before we return)
        await self._worker._register_all()
        self._started = True

        # Phase 2: start server in background (non-blocking)
        self._server_task = asyncio.create_task(self._run_server_safe())

        logger.info(
            "Polos started (orchestrator=%s, worker=:%d, deployment=%s)",
            self._api_url,
            self._port,
            self._deployment_id,
        )

    async def _run_server_safe(self):
        """Internal: run the worker server, logging errors instead of crashing."""
        try:
            await self._worker._run_server()
        except Exception as err:
            logger.error("Worker server error: %s", err)

    async def serve(self):
        """Start the worker and block until shutdown signal (SIGINT/SIGTERM).

        This is the deployment mode — equivalent to Worker.run().
        Use for servers, Kubernetes, Docker, etc.
        """
        await self.start()
        if self._server_task:
            await self._server_task

    async def stop(self):
        """Gracefully stop the worker and clean up. No-op if never started."""
        if not self._started:
            return

        await self._worker.shutdown()

        if self._server_task:
            # Wait for server to finish shutting down
            await self._server_task
            self._server_task = None

        self._started = False
        logger.info("Polos stopped")

    # ── Context manager ──

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.stop()

    # ── Client interface ──
    # Delegate to internal PolosClient so agent.run(polos, ...) works.
    # agent.run() calls client._submit_workflow() and other methods internally.

    @property
    def project_id(self):
        return self._client.project_id

    @property
    def api_url(self):
        return self._client.api_url

    @property
    def api_key(self):
        return self._client.api_key

    @property
    def deployment_id(self):
        return self._client.deployment_id

    def _get_headers(self) -> dict[str, str]:
        return self._client._get_headers()

    async def _submit_workflow(self, *args, **kwargs) -> ExecutionHandle:
        return await self._client._submit_workflow(*args, **kwargs)

    async def _submit_workflows(self, *args, **kwargs) -> list[ExecutionHandle]:
        return await self._client._submit_workflows(*args, **kwargs)

    async def invoke(
        self,
        workflow_id: str,
        payload: Any = None,
        queue_name: str | None = None,
        queue_concurrency_limit: int | None = None,
        concurrency_key: str | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        initial_state: dict[str, Any] | None = None,
        run_timeout_seconds: int | None = None,
    ) -> ExecutionHandle:
        """Invoke a workflow and return an execution handle."""
        return await self._client.invoke(
            workflow_id=workflow_id,
            payload=payload,
            queue_name=queue_name,
            queue_concurrency_limit=queue_concurrency_limit,
            concurrency_key=concurrency_key,
            session_id=session_id,
            user_id=user_id,
            initial_state=initial_state,
            run_timeout_seconds=run_timeout_seconds,
        )

    async def batch_invoke(
        self,
        workflows: list[BatchWorkflowInput],
        session_id: str | None = None,
        user_id: str | None = None,
    ) -> list[ExecutionHandle]:
        """Invoke multiple workflows in batch."""
        return await self._client.batch_invoke(
            workflows=workflows,
            session_id=session_id,
            user_id=user_id,
        )

    async def resume(
        self,
        suspend_workflow_id: str,
        suspend_execution_id: str,
        suspend_step_key: str,
        data: Any,
    ) -> None:
        """Resume a suspended execution."""
        return await self._client.resume(
            suspend_workflow_id=suspend_workflow_id,
            suspend_execution_id=suspend_execution_id,
            suspend_step_key=suspend_step_key,
            data=data,
        )

    async def get_execution(self, execution_id: str) -> dict[str, Any]:
        """Get execution details."""
        return await self._client.get_execution(execution_id)

    async def cancel_execution(self, execution_id: str) -> bool:
        """Cancel an execution."""
        return await self._client.cancel_execution(execution_id)

0 commit comments

Comments
 (0)