Skip to content

Commit c8f936b

Browse files
authored
fix(sessions): resolve async deadlock in multiplexed session manager (#1520)
This PR resolves a critical deadlock that can occur when a multiplexed session is acquired or maintained asynchronously. The bug occurs because DatabaseSessionsManager previously guarded the bodies of _get_multiplexed_session() and _maintain_multiplexed_session() with a synchronous threading.Lock. When a coroutine awaits the multiplexed session creation (return await ...) while holding that synchronous thread lock, any other coroutine that tries to acquire the lock blocks the thread running the asyncio event loop, so the coroutine holding the lock can never resume to release it. The fix creates the lock lazily as a CrossSync.Lock and acquires it with async with in the async code paths.
1 parent f822fd7 commit c8f936b

File tree

12 files changed

+113
-48
lines changed

12 files changed

+113
-48
lines changed

packages/google-cloud-spanner/google/cloud/spanner_v1/_async/database_sessions_manager.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def __init__(self, database, pool):
7171
self._pool = pool
7272
self._multiplexed_session: Optional[Session] = None
7373
self._multiplexed_session_thread: Optional[CrossSync.Task] = None
74-
# Use threading.Lock because this is accessed in a synchronous maintenance thread
75-
self._multiplexed_session_lock: threading.Lock = threading.Lock()
76-
self._multiplexed_session_terminate_event: CrossSync.Event = CrossSync.Event()
74+
self._init_lock = threading.Lock()
75+
self._multiplexed_session_lock: Optional[CrossSync.Lock] = None
76+
self._multiplexed_session_terminate_event: Optional[CrossSync.Event] = None
7777

7878
@CrossSync.convert
7979
async def get_session(self, transaction_type: TransactionType) -> Session:
@@ -119,7 +119,13 @@ async def _get_multiplexed_session(self) -> Session:
119119
120120
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
121121
:returns: a multiplexed session."""
122-
with CrossSync.rm_aio(self._multiplexed_session_lock):
122+
with self._init_lock:
123+
if self._multiplexed_session_lock is None:
124+
self._multiplexed_session_lock = CrossSync.Lock()
125+
if self._multiplexed_session_terminate_event is None:
126+
self._multiplexed_session_terminate_event = CrossSync.Event()
127+
128+
async with self._multiplexed_session_lock:
123129
if self._multiplexed_session is None:
124130
self._multiplexed_session = await self._build_multiplexed_session()
125131
self._multiplexed_session_thread = self._build_maintenance_thread()
@@ -193,7 +199,7 @@ async def _maintain_multiplexed_session(session_manager_ref) -> None:
193199
if time() - session_created_time < refresh_interval_seconds:
194200
await CrossSync.sleep(polling_interval_seconds)
195201
continue
196-
with manager._multiplexed_session_lock:
202+
async with manager._multiplexed_session_lock:
197203
await CrossSync.run_if_async(manager._multiplexed_session.delete)
198204
manager._multiplexed_session = (
199205
await manager._build_multiplexed_session()
@@ -220,7 +226,8 @@ def _getenv(cls, env_var_name: str) -> bool:
220226
@CrossSync.convert
221227
async def close(self) -> None:
222228
"""Closes the database session manager and stops all background tasks."""
223-
self._multiplexed_session_terminate_event.set()
229+
if self._multiplexed_session_terminate_event is not None:
230+
self._multiplexed_session_terminate_event.set()
224231
if self._multiplexed_session_thread is not None:
225232
if CrossSync.is_async:
226233
self._multiplexed_session_thread.cancel()

packages/google-cloud-spanner/google/cloud/spanner_v1/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def wrapped_method():
243243
max_commit_delay=max_commit_delay,
244244
request_options=request_options,
245245
)
246-
(call_metadata, error_augmenter) = database.with_error_augmentation(
246+
call_metadata, error_augmenter = database.with_error_augmentation(
247247
getattr(database, "_next_nth_request", 0), 1, metadata, span
248248
)
249249
commit_method = functools.partial(

packages/google-cloud-spanner/google/cloud/spanner_v1/database.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@
8282
trace_call,
8383
)
8484
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
85-
8685
from google.cloud.spanner_v1.table import Table
8786

8887
SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"
@@ -211,11 +210,9 @@ def __init__(
211210
def _resource_info(self):
212211
"""Resource information for metrics labels."""
213212
return {
214-
"project": (
215-
self._instance._client.project
216-
if self._instance and self._instance._client
217-
else None
218-
),
213+
"project": self._instance._client.project
214+
if self._instance and self._instance._client
215+
else None,
219216
"instance": self._instance.instance_id if self._instance else None,
220217
"database": self.database_id,
221218
}
@@ -533,7 +530,7 @@ def with_error_augmentation(
533530
tuple: (metadata_list, context_manager)"""
534531
if span is None:
535532
span = get_current_span()
536-
(metadata, request_id) = _metadata_with_request_id_and_req_id(
533+
metadata, request_id = _metadata_with_request_id_and_req_id(
537534
self._nth_client_id,
538535
self._channel_id,
539536
nth_request,
@@ -810,7 +807,7 @@ def execute_pdml():
810807
session = self._sessions_manager.get_session(transaction_type)
811808
try:
812809
add_span_event(span, "Starting BeginTransaction")
813-
(call_metadata, error_augmenter) = self.with_error_augmentation(
810+
call_metadata, error_augmenter = self.with_error_augmentation(
814811
self._next_nth_request, 1, metadata, span
815812
)
816813
with error_augmenter:

packages/google-cloud-spanner/google/cloud/spanner_v1/database_sessions_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,11 @@ def __init__(self, database, pool):
6969
self._pool = pool
7070
self._multiplexed_session: Optional[Session] = None
7171
self._multiplexed_session_thread: Optional[CrossSync._Sync_Impl.Task] = None
72-
self._multiplexed_session_lock: threading.Lock = threading.Lock()
73-
self._multiplexed_session_terminate_event: CrossSync._Sync_Impl.Event = (
74-
CrossSync._Sync_Impl.Event()
75-
)
72+
self._init_lock = threading.Lock()
73+
self._multiplexed_session_lock: Optional[CrossSync._Sync_Impl.Lock] = None
74+
self._multiplexed_session_terminate_event: Optional[
75+
CrossSync._Sync_Impl.Event
76+
] = None
7677

7778
def get_session(self, transaction_type: TransactionType) -> Session:
7879
"""Returns a session for the given transaction type from the database session manager.
@@ -115,6 +116,11 @@ def _get_multiplexed_session(self) -> Session:
115116
116117
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
117118
:returns: a multiplexed session."""
119+
with self._init_lock:
120+
if self._multiplexed_session_lock is None:
121+
self._multiplexed_session_lock = CrossSync._Sync_Impl.Lock()
122+
if self._multiplexed_session_terminate_event is None:
123+
self._multiplexed_session_terminate_event = CrossSync._Sync_Impl.Event()
118124
with self._multiplexed_session_lock:
119125
if self._multiplexed_session is None:
120126
self._multiplexed_session = self._build_multiplexed_session()
@@ -205,7 +211,8 @@ def _getenv(cls, env_var_name: str) -> bool:
205211

206212
def close(self) -> None:
207213
"""Closes the database session manager and stops all background tasks."""
208-
self._multiplexed_session_terminate_event.set()
214+
if self._multiplexed_session_terminate_event is not None:
215+
self._multiplexed_session_terminate_event.set()
209216
if self._multiplexed_session_thread is not None:
210217
self._multiplexed_session_thread.join()
211218
if self._multiplexed_session is not None:

packages/google-cloud-spanner/google/cloud/spanner_v1/instance.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,9 @@ def database(
479479
database_role=database_role,
480480
enable_drop_protection=enable_drop_protection,
481481
)
482-
db._pool.bind(db)
482+
res = db._pool.bind(db)
483+
if res is not None:
484+
res
483485
return db
484486

485487
def list_databases(self, page_size=None):

packages/google-cloud-spanner/google/cloud/spanner_v1/pool.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def _fill_pool(self):
304304
f"Creating {request.session_count} sessions",
305305
span_event_attributes,
306306
)
307-
(call_metadata, error_augmenter) = database.with_error_augmentation(
307+
call_metadata, error_augmenter = database.with_error_augmentation(
308308
database._next_nth_request, 1, metadata, span
309309
)
310310
with error_augmenter:
@@ -612,7 +612,7 @@ def bind(self, database):
612612
) as span, MetricsCapture(self._resource_info):
613613
returned_session_count = 0
614614
while returned_session_count < self.size:
615-
(call_metadata, error_augmenter) = database.with_error_augmentation(
615+
call_metadata, error_augmenter = database.with_error_augmentation(
616616
database._next_nth_request, 1, metadata, span
617617
)
618618
with error_augmenter:
@@ -654,7 +654,7 @@ def get(self, timeout=None):
654654
ping_after = None
655655
session = None
656656
try:
657-
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
657+
ping_after, session = CrossSync._Sync_Impl.queue_get(
658658
self._sessions, block=True, timeout=timeout
659659
)
660660
except CrossSync._Sync_Impl.QueueEmpty as e:
@@ -698,9 +698,7 @@ def clear(self):
698698
"""Delete all sessions in the pool."""
699699
while True:
700700
try:
701-
(_, session) = CrossSync._Sync_Impl.queue_get(
702-
self._sessions, block=False
703-
)
701+
_, session = CrossSync._Sync_Impl.queue_get(self._sessions, block=False)
704702
except CrossSync._Sync_Impl.QueueEmpty:
705703
break
706704
else:
@@ -713,7 +711,7 @@ def ping(self):
713711
or during the "idle" phase of an event loop."""
714712
while True:
715713
try:
716-
(ping_after, session) = CrossSync._Sync_Impl.queue_get(
714+
ping_after, session = CrossSync._Sync_Impl.queue_get(
717715
self._sessions, block=False
718716
)
719717
except CrossSync._Sync_Impl.QueueEmpty:

packages/google-cloud-spanner/google/cloud/spanner_v1/session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def create(self):
188188
observability_options=observability_options,
189189
metadata=metadata,
190190
) as span, MetricsCapture(self._resource_info):
191-
(call_metadata, error_augmenter) = database.with_error_augmentation(
191+
call_metadata, error_augmenter = database.with_error_augmentation(
192192
nth_request, 1, metadata, span
193193
)
194194
with error_augmenter:
@@ -232,7 +232,7 @@ def exists(self):
232232
observability_options=observability_options,
233233
metadata=metadata,
234234
) as span, MetricsCapture(self._resource_info):
235-
(call_metadata, error_augmenter) = database.with_error_augmentation(
235+
call_metadata, error_augmenter = database.with_error_augmentation(
236236
nth_request, 1, metadata, span
237237
)
238238
with error_augmenter:
@@ -283,7 +283,7 @@ def delete(self):
283283
observability_options=observability_options,
284284
metadata=metadata,
285285
) as span, MetricsCapture(self._resource_info):
286-
(call_metadata, error_augmenter) = database.with_error_augmentation(
286+
call_metadata, error_augmenter = database.with_error_augmentation(
287287
nth_request, 1, metadata, span
288288
)
289289
with error_augmenter:
@@ -300,7 +300,7 @@ def ping(self):
300300
metadata = _metadata_with_prefix(database.name)
301301
nth_request = database._next_nth_request
302302
with trace_call("CloudSpanner.Session.ping", self) as span:
303-
(call_metadata, error_augmenter) = database.with_error_augmentation(
303+
call_metadata, error_augmenter = database.with_error_augmentation(
304304
nth_request, 1, metadata, span
305305
)
306306
with error_augmenter:

packages/google-cloud-spanner/google/cloud/spanner_v1/snapshot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def execute_sql(
322322
raise ValueError("Transaction has not begun.")
323323
if params is not None:
324324
params_pb = Struct(
325-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
325+
fields={key: _make_value_pb(value) for key, value in params.items()}
326326
)
327327
else:
328328
params_pb = {}
@@ -513,7 +513,7 @@ def partition_query(
513513
raise ValueError("Cannot partition a single-use transaction.")
514514
if params is not None:
515515
params_pb = Struct(
516-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
516+
fields={key: _make_value_pb(value) for key, value in params.items()}
517517
)
518518
else:
519519
params_pb = Struct()
@@ -614,7 +614,7 @@ def wrapped_method():
614614
begin_transaction_request = BeginTransactionRequest(
615615
**begin_request_kwargs
616616
)
617-
(call_metadata, error_augmenter) = database.with_error_augmentation(
617+
call_metadata, error_augmenter = database.with_error_augmentation(
618618
nth_request, attempt.increment(), metadata, span
619619
)
620620
begin_transaction_method = functools.partial(

packages/google-cloud-spanner/google/cloud/spanner_v1/streamed.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _consume_next(self):
147147

148148
def __iter__(self):
149149
while True:
150-
(iter_rows, self._rows[:]) = (self._rows[:], ())
150+
iter_rows, self._rows[:] = (self._rows[:], ())
151151
while iter_rows:
152152
yield iter_rows.pop(0)
153153
if self._done:
@@ -230,7 +230,7 @@ def to_dict_list(self):
230230
rows.append(
231231
{
232232
column: value
233-
for (column, value) in zip(
233+
for column, value in zip(
234234
[column.name for column in self._metadata.row_type.fields], row
235235
)
236236
}
@@ -291,7 +291,7 @@ def _merge_array(lhs, rhs, type_):
291291
if element_type.code in _UNMERGEABLE_TYPES:
292292
lhs.list_value.values.extend(rhs.list_value.values)
293293
return lhs
294-
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
294+
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
295295
if not len(lhs) or not len(rhs):
296296
return Value(list_value=ListValue(values=lhs + rhs))
297297
first = rhs.pop(0)
@@ -316,7 +316,7 @@ def _merge_array(lhs, rhs, type_):
316316
def _merge_struct(lhs, rhs, type_):
317317
"""Helper for '_merge_by_type'."""
318318
fields = type_.struct_type.fields
319-
(lhs, rhs) = (list(lhs.list_value.values), list(rhs.list_value.values))
319+
lhs, rhs = (list(lhs.list_value.values), list(rhs.list_value.values))
320320
if not len(lhs) or not len(rhs):
321321
return Value(list_value=ListValue(values=lhs + rhs))
322322
candidate_type = fields[len(lhs) - 1].type_

packages/google-cloud-spanner/google/cloud/spanner_v1/transaction.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def rollback(self) -> None:
162162

163163
def wrapped_method(*args, **kwargs):
164164
attempt.increment()
165-
(call_metadata, error_augmenter) = database.with_error_augmentation(
165+
call_metadata, error_augmenter = database.with_error_augmentation(
166166
nth_request, attempt.value, metadata, span
167167
)
168168
rollback_method = functools.partial(
@@ -269,7 +269,7 @@ def wrapped_method(*args, **kwargs):
269269
is_multiplexed = getattr(self._session, "is_multiplexed", False)
270270
if is_multiplexed and self._precommit_token is not None:
271271
commit_request_args["precommit_token"] = self._precommit_token
272-
(call_metadata, error_augmenter) = database.with_error_augmentation(
272+
call_metadata, error_augmenter = database.with_error_augmentation(
273273
nth_request, attempt.value, metadata, span
274274
)
275275
commit_method = functools.partial(
@@ -300,7 +300,7 @@ def before_next_retry(nth_retry, delay_in_seconds):
300300
if commit_response_pb._pb.HasField("precommit_token"):
301301
add_span_event(span, commit_retry_event_name)
302302
nth_request = database._next_nth_request
303-
(call_metadata, error_augmenter) = database.with_error_augmentation(
303+
call_metadata, error_augmenter = database.with_error_augmentation(
304304
nth_request, 1, metadata, span
305305
)
306306
with error_augmenter:
@@ -338,7 +338,7 @@ def _make_params_pb(params, param_types):
338338
If ``params`` is None but ``param_types`` is not None."""
339339
if params:
340340
return Struct(
341-
fields={key: _make_value_pb(value) for (key, value) in params.items()}
341+
fields={key: _make_value_pb(value) for key, value in params.items()}
342342
)
343343
return {}
344344

@@ -417,7 +417,7 @@ def execute_update(
417417
metadata.append(
418418
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
419419
)
420-
(seqno, self._execute_sql_request_count) = (
420+
seqno, self._execute_sql_request_count = (
421421
self._execute_sql_request_count,
422422
self._execute_sql_request_count + 1,
423423
)
@@ -454,7 +454,7 @@ def execute_update(
454454

455455
def wrapped_method(*args, **kwargs):
456456
attempt.increment()
457-
(call_metadata, error_augmenter) = database.with_error_augmentation(
457+
call_metadata, error_augmenter = database.with_error_augmentation(
458458
nth_request, attempt.value, metadata
459459
)
460460
execute_sql_method = functools.partial(
@@ -544,7 +544,7 @@ def batch_update(
544544
if isinstance(statement, str):
545545
parsed.append(ExecuteBatchDmlRequest.Statement(sql=statement))
546546
else:
547-
(dml, params, param_types) = statement
547+
dml, params, param_types = statement
548548
params_pb = self._make_params_pb(params, param_types)
549549
parsed.append(
550550
ExecuteBatchDmlRequest.Statement(
@@ -556,7 +556,7 @@ def batch_update(
556556
metadata.append(
557557
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
558558
)
559-
(seqno, self._execute_sql_request_count) = (
559+
seqno, self._execute_sql_request_count = (
560560
self._execute_sql_request_count,
561561
self._execute_sql_request_count + 1,
562562
)
@@ -590,7 +590,7 @@ def batch_update(
590590

591591
def wrapped_method(*args, **kwargs):
592592
attempt.increment()
593-
(call_metadata, error_augmenter) = database.with_error_augmentation(
593+
call_metadata, error_augmenter = database.with_error_augmentation(
594594
nth_request, attempt.value, metadata
595595
)
596596
execute_batch_dml_method = functools.partial(

0 commit comments

Comments
 (0)