Skip to content

Commit 83e9d7c

Browse files
Bin1783, WillemJiang, and Copilot
authored
feat:Database connections use connection pools (bytedance#757)
* feat: Implement DeerFlow API server with chat streaming, Langgraph orchestration, and various content generation capabilities. * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 1f403a9 commit 83e9d7c

File tree

1 file changed

+163
-17
lines changed

1 file changed

+163
-17
lines changed

src/server/app.py

Lines changed: 163 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from psycopg_pool import AsyncConnectionPool
2121

2222
from src.config.configuration import get_recursion_limit
23-
from src.config.loader import get_bool_env, get_str_env
23+
from src.config.loader import get_bool_env, get_int_env, get_str_env
2424
from src.config.report_style import ReportStyle
2525
from src.config.tools import SELECTED_RAG_PROVIDER
2626
from src.graph.builder import build_graph_with_memory
@@ -73,10 +73,135 @@
7373

7474
INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"
7575

76+
# Process-wide checkpoint resources. Each one stays None unless it is
# initialized at application startup (only when persistence is configured).

# PostgreSQL: shared connection pool and the checkpointer built on top of it.
_pg_pool: Optional[AsyncConnectionPool] = None
_pg_checkpointer: Optional[AsyncPostgresSaver] = None

# MongoDB: shared client and the checkpointer built on top of it.
_mongo_client: Optional[Any] = None
_mongo_checkpointer: Optional[AsyncMongoDBSaver] = None
83+
84+
85+
from contextlib import asynccontextmanager
86+
87+
88+
async def _open_pg_checkpointer(checkpoint_url):
    """
    Open a PostgreSQL connection pool and set up a checkpointer on top of it.

    Pool sizing is read from PG_POOL_MIN_SIZE, PG_POOL_MAX_SIZE and
    PG_POOL_TIMEOUT (defaults 5, 20 and 60).

    Args:
        checkpoint_url: PostgreSQL connection string.

    Returns:
        A ``(pool, checkpointer)`` tuple; the pool is already open.

    Raises:
        RuntimeError: If the pool cannot be opened or the checkpointer
            cannot be set up. A partially opened pool is closed first.
    """
    pool_min_size = get_int_env("PG_POOL_MIN_SIZE", 5)
    pool_max_size = get_int_env("PG_POOL_MAX_SIZE", 20)
    pool_timeout = get_int_env("PG_POOL_TIMEOUT", 60)

    connection_kwargs = {
        "autocommit": True,
        "prepare_threshold": 0,
        "row_factory": dict_row,
    }

    logger.info(
        f"Initializing global PostgreSQL connection pool: "
        f"min_size={pool_min_size}, max_size={pool_max_size}, timeout={pool_timeout}s"
    )

    pool = None
    try:
        # open=False: the pool is opened explicitly on the next line; relying
        # on the constructor to open an async pool is deprecated.
        pool = AsyncConnectionPool(
            checkpoint_url,
            kwargs=connection_kwargs,
            min_size=pool_min_size,
            max_size=pool_max_size,
            timeout=pool_timeout,
            open=False,
        )
        await pool.open()

        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
    except Exception as e:
        logger.error(f"Failed to initialize PostgreSQL connection pool: {e}")
        # The pool may already be open (e.g. setup() failed): close it so its
        # connections are not leaked when startup is aborted.
        if pool is not None:
            try:
                await pool.close()
            except Exception as close_error:
                logger.warning(
                    f"Error while closing PostgreSQL connection pool after failed initialization: {close_error}"
                )
        raise RuntimeError(
            "Checkpoint persistence is explicitly configured with PostgreSQL, "
            "but initialization failed. Application will not start."
        ) from e

    logger.info("Global PostgreSQL connection pool initialized successfully")
    return pool, checkpointer


async def _open_mongo_checkpointer(checkpoint_url):
    """
    Create a pooled MongoDB client and set up a checkpointer on top of it.

    Pool sizing is read from MONGO_MAX_POOL_SIZE and MONGO_MIN_POOL_SIZE
    (defaults 20 and 5).

    Args:
        checkpoint_url: MongoDB connection string.

    Returns:
        A ``(client, checkpointer)`` tuple.

    Raises:
        RuntimeError: If the ``motor`` package is missing, or if the client
            or checkpointer cannot be initialized. A client that was already
            created is closed first.
    """
    # Keep the import in its own try block so that an ImportError raised later
    # (during client/checkpointer setup) is not misreported as "motor missing".
    try:
        from motor.motor_asyncio import AsyncIOMotorClient
    except ImportError as e:
        logger.error("motor package not installed. Please install it with: pip install motor")
        raise RuntimeError(
            "MongoDB checkpoint persistence is configured but the 'motor' package is not installed. Aborting startup."
        ) from e

    client = None
    try:
        # MongoDB connection pool settings
        mongo_max_pool_size = get_int_env("MONGO_MAX_POOL_SIZE", 20)
        mongo_min_pool_size = get_int_env("MONGO_MIN_POOL_SIZE", 5)

        logger.info(
            f"Initializing global MongoDB connection pool: "
            f"min_pool_size={mongo_min_pool_size}, max_pool_size={mongo_max_pool_size}"
        )

        client = AsyncIOMotorClient(
            checkpoint_url,
            maxPoolSize=mongo_max_pool_size,
            minPoolSize=mongo_min_pool_size,
        )

        # Create the MongoDB checkpointer using the shared client
        checkpointer = AsyncMongoDBSaver(client)
        await checkpointer.setup()
    except Exception as e:
        logger.error(f"Failed to initialize MongoDB connection pool: {e}")
        # Do not leave a half-initialized client behind when startup aborts.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.warning(
                    f"Error while closing MongoDB client after failed initialization: {close_error}"
                )
        raise RuntimeError(
            f"MongoDB checkpoint persistence is configured but could not be initialized: {e}"
        ) from e

    logger.info("Global MongoDB connection pool initialized successfully")
    return client, checkpointer


@asynccontextmanager
async def lifespan(app):
    """
    Application lifecycle manager.

    - Startup: initialize the global checkpoint connection pool when
      LANGGRAPH_CHECKPOINT_SAVER is enabled and LANGGRAPH_CHECKPOINT_DB_URL
      is set to a ``postgresql://`` or ``mongodb://`` URL.
    - Shutdown: close the global connection pools and reset the globals,
      also when the application exits with an exception.

    Args:
        app: The application instance this lifespan is attached to (unused).

    Raises:
        RuntimeError: If checkpoint persistence is configured but the
            selected backend could not be initialized.
    """
    global _pg_pool, _pg_checkpointer, _mongo_client, _mongo_checkpointer

    # ========== STARTUP ==========
    # NOTE(review): this only checks that an event loop is running; no asyncio
    # exception handler is actually installed here, despite the log message.
    try:
        asyncio.get_running_loop()
    except RuntimeError as e:
        logger.warning(f"Could not register asyncio exception handler: {e}")

    # Initialize global connection pool based on configuration
    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
    checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")

    if not checkpoint_saver or not checkpoint_url:
        logger.info("Checkpoint saver not configured, skipping connection pool initialization")
    elif checkpoint_url.startswith("postgresql://"):
        _pg_pool, _pg_checkpointer = await _open_pg_checkpointer(checkpoint_url)
    elif checkpoint_url.startswith("mongodb://"):
        _mongo_client, _mongo_checkpointer = await _open_mongo_checkpointer(checkpoint_url)
    else:
        # The URL itself is not logged because it may contain credentials.
        logger.warning(
            "Checkpoint saver is enabled but LANGGRAPH_CHECKPOINT_DB_URL has an unsupported scheme; "
            "no global connection pool was initialized"
        )

    try:
        # ========== YIELD - Application runs here ==========
        yield
    finally:
        # ========== SHUTDOWN ==========
        # Detach the globals first so nothing can pick up a pool that is
        # being closed, then release whatever was opened at startup.
        pg_pool, mongo_client = _pg_pool, _mongo_client
        _pg_pool = _pg_checkpointer = None
        _mongo_client = _mongo_checkpointer = None

        try:
            # Close PostgreSQL connection pool
            if pg_pool is not None:
                logger.info("Closing global PostgreSQL connection pool")
                await pg_pool.close()
                logger.info("Global PostgreSQL connection pool closed")
        finally:
            # Close MongoDB connection (runs even if the pool close failed)
            if mongo_client is not None:
                logger.info("Closing global MongoDB connection")
                mongo_client.close()
                logger.info("Global MongoDB connection closed")
198+
199+
76200
# ASGI application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="DeerFlow API",
    description="API for Deer",
    version="0.1.0",
)
81206

82207
# Add CORS middleware
@@ -612,23 +737,33 @@ async def _astream_workflow_generator(
612737
f"url_configured={bool(checkpoint_url)}"
613738
)
614739

615-
# Handle checkpointer if configured
616-
connection_kwargs = {
617-
"autocommit": True,
618-
"row_factory": "dict_row",
619-
"prepare_threshold": 0,
620-
}
740+
# Handle checkpointer if configured - prefer global connection pools
621741
if checkpoint_saver and checkpoint_url != "":
622-
if checkpoint_url.startswith("postgresql://"):
623-
logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
624-
logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
742+
# Try to use global PostgreSQL checkpointer first
743+
if checkpoint_url.startswith("postgresql://") and _pg_checkpointer:
744+
logger.info(f"[{safe_thread_id}] Using global PostgreSQL connection pool")
745+
graph.checkpointer = _pg_checkpointer
746+
graph.store = in_memory_store
747+
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
748+
async for event in _stream_graph_events(
749+
graph, workflow_input, workflow_config, thread_id
750+
):
751+
yield event
752+
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
753+
754+
# Fallback to per-request PostgreSQL connection if global pool not available
755+
elif checkpoint_url.startswith("postgresql://"):
756+
logger.info(f"[{safe_thread_id}] Global pool unavailable, creating per-request PostgreSQL connection")
757+
connection_kwargs = {
758+
"autocommit": True,
759+
"row_factory": "dict_row",
760+
"prepare_threshold": 0,
761+
}
625762
async with AsyncConnectionPool(
626763
checkpoint_url, kwargs=connection_kwargs
627764
) as conn:
628-
logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
629765
checkpointer = AsyncPostgresSaver(conn)
630766
await checkpointer.setup()
631-
logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
632767
graph.checkpointer = checkpointer
633768
graph.store = in_memory_store
634769
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
@@ -638,13 +773,24 @@ async def _astream_workflow_generator(
638773
yield event
639774
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
640775

641-
if checkpoint_url.startswith("mongodb://"):
642-
logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
643-
logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
776+
# Try to use global MongoDB checkpointer first
777+
elif checkpoint_url.startswith("mongodb://") and _mongo_checkpointer:
778+
logger.info(f"[{safe_thread_id}] Using global MongoDB connection pool")
779+
graph.checkpointer = _mongo_checkpointer
780+
graph.store = in_memory_store
781+
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
782+
async for event in _stream_graph_events(
783+
graph, workflow_input, workflow_config, thread_id
784+
):
785+
yield event
786+
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
787+
788+
# Fallback to per-request MongoDB connection if global pool not available
789+
elif checkpoint_url.startswith("mongodb://"):
790+
logger.info(f"[{safe_thread_id}] Global pool unavailable, creating per-request MongoDB connection")
644791
async with AsyncMongoDBSaver.from_conn_string(
645792
checkpoint_url
646793
) as checkpointer:
647-
logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
648794
graph.checkpointer = checkpointer
649795
graph.store = in_memory_store
650796
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
@@ -655,7 +801,7 @@ async def _astream_workflow_generator(
655801
logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
656802
else:
657803
logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
658-
# Use graph without MongoDB checkpointer
804+
# Use graph without checkpointer
659805
logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
660806
async for event in _stream_graph_events(
661807
graph, workflow_input, workflow_config, thread_id

0 commit comments

Comments
 (0)