|
| 1 | +"""Unified Polos class that combines PolosClient and Worker into a single object.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import logging |
| 5 | +import os |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +from .agents.agent import Agent |
| 9 | +from .core.workflow import get_all_workflows |
| 10 | +from .runtime.client import ExecutionHandle, PolosClient |
| 11 | +from .runtime.worker import Worker |
| 12 | +from .tools.tool import Tool |
| 13 | +from .types.types import BatchWorkflowInput |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +def _configure_file_logging(log_file: str) -> None: |
| 19 | + """Redirect all SDK logs to a file instead of stdout/stderr.""" |
| 20 | + handler = logging.FileHandler(log_file, mode="a") |
| 21 | + handler.setLevel(logging.INFO) |
| 22 | + handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)-8s %(name)s: %(message)s")) |
| 23 | + # Configure root logger so all SDK loggers (polos.*, httpx, etc.) use the file |
| 24 | + root = logging.getLogger() |
| 25 | + root.handlers.clear() |
| 26 | + root.addHandler(handler) |
| 27 | + root.setLevel(logging.INFO) |
| 28 | + |
| 29 | + |
| 30 | +class Polos: |
| 31 | + """Unified Polos client + worker. |
| 32 | +
|
| 33 | + Combines PolosClient (submit/stream work) and Worker (receive/execute work) |
| 34 | + into a single object. One process, one import. |
| 35 | +
|
| 36 | + For production deployments where you need to scale workers independently |
| 37 | + from clients, use the separate PolosClient and Worker classes directly. |
| 38 | +
|
| 39 | + Usage:: |
| 40 | +
|
| 41 | + from polos import Polos, Agent, tool |
| 42 | +
|
| 43 | + @tool(description="Get weather") |
| 44 | + async def get_weather(ctx, input): |
| 45 | + return {"temp": 72} |
| 46 | +
|
| 47 | + weather_agent = Agent( |
| 48 | + id="weather", provider="openai", model="gpt-4o", |
| 49 | + system_prompt="You are a weather assistant.", |
| 50 | + tools=[get_weather], |
| 51 | + ) |
| 52 | +
|
| 53 | + # Script mode: context manager |
| 54 | + async with Polos() as polos: |
| 55 | + result = await weather_agent.run(polos, "What's the weather in Paris?") |
| 56 | + print(result) |
| 57 | +
|
| 58 | + # Script mode: start/stop |
| 59 | + polos = Polos() |
| 60 | + await polos.start() |
| 61 | + result = await weather_agent.run(polos, "What's the weather?") |
| 62 | + await polos.stop() |
| 63 | +
|
| 64 | + # Server mode: blocks until SIGINT/SIGTERM |
| 65 | + polos = Polos() |
| 66 | + await polos.serve() |
| 67 | + """ |
| 68 | + |
| 69 | + def __init__( |
| 70 | + self, |
| 71 | + project_id: str | None = None, |
| 72 | + api_url: str | None = None, |
| 73 | + api_key: str | None = None, |
| 74 | + deployment_id: str | None = None, |
| 75 | + port: int = 8000, |
| 76 | + max_concurrent_workflows: int | None = None, |
| 77 | + channels: list | None = None, |
| 78 | + log_file: str | None = None, |
| 79 | + ): |
| 80 | + self._project_id = project_id or os.getenv("POLOS_PROJECT_ID") |
| 81 | + self._api_url = api_url or os.getenv("POLOS_API_URL", "http://localhost:8080") |
| 82 | + self._api_key = api_key or os.getenv("POLOS_API_KEY") |
| 83 | + self._deployment_id = deployment_id or os.getenv("POLOS_DEPLOYMENT_ID", "default") |
| 84 | + self._port = port or os.getenv("POLOS_WORKER_PORT", 8000) |
| 85 | + self._channels = channels |
| 86 | + self._log_file = log_file |
| 87 | + |
| 88 | + # Redirect SDK logs to file if requested |
| 89 | + if log_file: |
| 90 | + _configure_file_logging(log_file) |
| 91 | + |
| 92 | + # Create client |
| 93 | + self._client = PolosClient( |
| 94 | + project_id=self._project_id, |
| 95 | + api_url=self._api_url, |
| 96 | + api_key=self._api_key, |
| 97 | + deployment_id=self._deployment_id, |
| 98 | + ) |
| 99 | + |
| 100 | + # Discover agents, tools, and workflows from the global WORKFLOW_REGISTRY. |
| 101 | + # They auto-register when defined (Agent(), Tool(), @workflow), so no need |
| 102 | + # to pass them explicitly. |
| 103 | + all_workflows = get_all_workflows() |
| 104 | + agents = [w for w in all_workflows.values() if isinstance(w, Agent)] |
| 105 | + tools = [w for w in all_workflows.values() if isinstance(w, Tool)] |
| 106 | + workflows = [w for w in all_workflows.values() if not isinstance(w, (Agent, Tool))] |
| 107 | + |
| 108 | + # Create worker |
| 109 | + self._worker = Worker( |
| 110 | + client=self._client, |
| 111 | + deployment_id=self._deployment_id, |
| 112 | + agents=agents, |
| 113 | + tools=tools, |
| 114 | + workflows=workflows, |
| 115 | + max_concurrent_workflows=max_concurrent_workflows, |
| 116 | + mode="push", |
| 117 | + worker_server_url=f"http://localhost:{self._port}", |
| 118 | + log_file=log_file, |
| 119 | + ) |
| 120 | + |
| 121 | + # Track state |
| 122 | + self._started = False |
| 123 | + self._server_task: asyncio.Task | None = None |
| 124 | + |
| 125 | + # ── Lifecycle ── |
| 126 | + |
| 127 | + async def start(self): |
| 128 | + """Start the worker in background (non-blocking). |
| 129 | +
|
| 130 | + Registers with orchestrator, starts FastAPI server, begins heartbeat. |
| 131 | + Returns once registration is complete so the caller can immediately |
| 132 | + invoke workflows. |
| 133 | + """ |
| 134 | + if self._started: |
| 135 | + return |
| 136 | + |
| 137 | + # Phase 1: register (blocking — must complete before we return) |
| 138 | + await self._worker._register_all() |
| 139 | + self._started = True |
| 140 | + |
| 141 | + # Phase 2: start server in background (non-blocking) |
| 142 | + self._server_task = asyncio.create_task(self._run_server_safe()) |
| 143 | + |
| 144 | + logger.info( |
| 145 | + "Polos started (orchestrator=%s, worker=:%d, deployment=%s)", |
| 146 | + self._api_url, |
| 147 | + self._port, |
| 148 | + self._deployment_id, |
| 149 | + ) |
| 150 | + |
| 151 | + async def _run_server_safe(self): |
| 152 | + """Internal: run the worker server, logging errors instead of crashing.""" |
| 153 | + try: |
| 154 | + await self._worker._run_server() |
| 155 | + except Exception as err: |
| 156 | + logger.error("Worker server error: %s", err) |
| 157 | + |
| 158 | + async def serve(self): |
| 159 | + """Start the worker and block until shutdown signal (SIGINT/SIGTERM). |
| 160 | +
|
| 161 | + This is the deployment mode — equivalent to Worker.run(). |
| 162 | + Use for servers, Kubernetes, Docker, etc. |
| 163 | + """ |
| 164 | + await self.start() |
| 165 | + if self._server_task: |
| 166 | + await self._server_task |
| 167 | + |
| 168 | + async def stop(self): |
| 169 | + """Gracefully stop the worker and clean up.""" |
| 170 | + if not self._started: |
| 171 | + return |
| 172 | + |
| 173 | + await self._worker.shutdown() |
| 174 | + |
| 175 | + if self._server_task: |
| 176 | + # Wait for server to finish shutting down |
| 177 | + await self._server_task |
| 178 | + self._server_task = None |
| 179 | + |
| 180 | + self._started = False |
| 181 | + logger.info("Polos stopped") |
| 182 | + |
| 183 | + # ── Context manager ── |
| 184 | + |
| 185 | + async def __aenter__(self): |
| 186 | + await self.start() |
| 187 | + return self |
| 188 | + |
| 189 | + async def __aexit__(self, exc_type, exc_val, exc_tb): |
| 190 | + await self.stop() |
| 191 | + |
| 192 | + # ── Client interface ── |
| 193 | + # Delegate to internal PolosClient so agent.run(polos, ...) works. |
| 194 | + # agent.run() calls client._submit_workflow() and other methods internally. |
| 195 | + |
| 196 | + @property |
| 197 | + def project_id(self): |
| 198 | + return self._client.project_id |
| 199 | + |
| 200 | + @property |
| 201 | + def api_url(self): |
| 202 | + return self._client.api_url |
| 203 | + |
| 204 | + @property |
| 205 | + def api_key(self): |
| 206 | + return self._client.api_key |
| 207 | + |
| 208 | + @property |
| 209 | + def deployment_id(self): |
| 210 | + return self._client.deployment_id |
| 211 | + |
| 212 | + def _get_headers(self) -> dict[str, str]: |
| 213 | + return self._client._get_headers() |
| 214 | + |
| 215 | + async def _submit_workflow(self, *args, **kwargs) -> ExecutionHandle: |
| 216 | + return await self._client._submit_workflow(*args, **kwargs) |
| 217 | + |
| 218 | + async def _submit_workflows(self, *args, **kwargs) -> list[ExecutionHandle]: |
| 219 | + return await self._client._submit_workflows(*args, **kwargs) |
| 220 | + |
| 221 | + async def invoke( |
| 222 | + self, |
| 223 | + workflow_id: str, |
| 224 | + payload: Any = None, |
| 225 | + queue_name: str | None = None, |
| 226 | + queue_concurrency_limit: int | None = None, |
| 227 | + concurrency_key: str | None = None, |
| 228 | + session_id: str | None = None, |
| 229 | + user_id: str | None = None, |
| 230 | + initial_state: dict[str, Any] | None = None, |
| 231 | + run_timeout_seconds: int | None = None, |
| 232 | + ) -> ExecutionHandle: |
| 233 | + """Invoke a workflow and return an execution handle.""" |
| 234 | + return await self._client.invoke( |
| 235 | + workflow_id=workflow_id, |
| 236 | + payload=payload, |
| 237 | + queue_name=queue_name, |
| 238 | + queue_concurrency_limit=queue_concurrency_limit, |
| 239 | + concurrency_key=concurrency_key, |
| 240 | + session_id=session_id, |
| 241 | + user_id=user_id, |
| 242 | + initial_state=initial_state, |
| 243 | + run_timeout_seconds=run_timeout_seconds, |
| 244 | + ) |
| 245 | + |
| 246 | + async def batch_invoke( |
| 247 | + self, |
| 248 | + workflows: list[BatchWorkflowInput], |
| 249 | + session_id: str | None = None, |
| 250 | + user_id: str | None = None, |
| 251 | + ) -> list[ExecutionHandle]: |
| 252 | + """Invoke multiple workflows in batch.""" |
| 253 | + return await self._client.batch_invoke( |
| 254 | + workflows=workflows, |
| 255 | + session_id=session_id, |
| 256 | + user_id=user_id, |
| 257 | + ) |
| 258 | + |
| 259 | + async def resume( |
| 260 | + self, |
| 261 | + suspend_workflow_id: str, |
| 262 | + suspend_execution_id: str, |
| 263 | + suspend_step_key: str, |
| 264 | + data: Any, |
| 265 | + ) -> None: |
| 266 | + """Resume a suspended execution.""" |
| 267 | + return await self._client.resume( |
| 268 | + suspend_workflow_id=suspend_workflow_id, |
| 269 | + suspend_execution_id=suspend_execution_id, |
| 270 | + suspend_step_key=suspend_step_key, |
| 271 | + data=data, |
| 272 | + ) |
| 273 | + |
| 274 | + async def get_execution(self, execution_id: str) -> dict[str, Any]: |
| 275 | + """Get execution details.""" |
| 276 | + return await self._client.get_execution(execution_id) |
| 277 | + |
| 278 | + async def cancel_execution(self, execution_id: str) -> bool: |
| 279 | + """Cancel an execution.""" |
| 280 | + return await self._client.cancel_execution(execution_id) |
0 commit comments