Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: ignore the implementation details of message acknowledgements fr…
…om the cloud Pub/Sub Message

fixes: #311
  • Loading branch information
dpcollins-google committed Mar 7, 2022
commit 3ae1090b1c1a120e0d436452ceac69aff419064e
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from google.cloud.pubsublite.internal import fast_serialize
from google.cloud.pubsublite.types import FlowControlSettings
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.internal.wrapped_message import (
AckId,
WrappedMessage,
)
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler
from google.cloud.pubsublite.cloudpubsub.internal.single_subscriber import (
Expand All @@ -44,19 +48,6 @@ class _SizedMessage(NamedTuple):
size_bytes: int


class _AckId(NamedTuple):
generation: int
offset: int

def encode(self) -> str:
return fast_serialize.dump([self.generation, self.offset])

@staticmethod
def parse(payload: str) -> "_AckId": # pytype: disable=invalid-annotation
loaded = fast_serialize.load(payload)
return _AckId(generation=loaded[0], offset=loaded[1])


ResettableSubscriberFactory = Callable[[SubscriberResetHandler], Subscriber]


Expand All @@ -69,10 +60,10 @@ class SinglePartitionSingleSubscriber(
_nack_handler: NackHandler
_transformer: MessageTransformer

_queue: queue.Queue
_ack_generation_id: int
_messages_by_ack_id: Dict[str, _SizedMessage]
_looper_future: asyncio.Future
_messages_by_ack_id: Dict[AckId, _SizedMessage]

_loop: asyncio.BaseEventLoop

def __init__(
self,
Expand All @@ -89,7 +80,6 @@ def __init__(
self._nack_handler = nack_handler
self._transformer = transformer

self._queue = queue.Queue()
self._ack_generation_id = 0
self._messages_by_ack_id = {}

Expand All @@ -104,19 +94,33 @@ def _wrap_message(self, message: SequencedMessage.meta.pb) -> Message:
rewrapped._pb = message
cps_message = self._transformer.transform(rewrapped)
offset = message.cursor.offset
ack_id_str = _AckId(self._ack_generation_id, offset).encode()
ack_id = AckId(self._ack_generation_id, offset)
self._ack_set_tracker.track(offset)
self._messages_by_ack_id[ack_id_str] = _SizedMessage(
self._messages_by_ack_id[ack_id] = _SizedMessage(
cps_message, message.size_bytes
)
wrapped_message = Message(
cps_message._pb,
ack_id=ack_id_str,
delivery_attempt=0,
request_queue=self._queue,
wrapped_message = WrappedMessage(
pb=cps_message._pb,
ack_id=ack_id,
ack_handler=lambda id, ack: self._on_ack_threadsafe(id, ack),
)
return wrapped_message

def _on_ack_threadsafe(self, ack_id: AckId, should_ack: bool) -> None:
"""A function called when a message is acked, may happen from any thread."""
if should_ack:
return self._loop.call_soon_threadsafe(lambda: self._handle_ack(ack_id))
try:
sized_message = self._messages_by_ack_id[ack_id]
# Call the threadsafe version on ack since the callback may be called from another thread.
self._nack_handler.on_nack(
sized_message.message, lambda: self._on_ack_threadsafe(ack_id, True)
)
except Exception as e:
e2 = adapt_error(e)
do_fail = lambda: self.fail(e2)
return self._loop.call_soon_threadsafe(lambda: self.fail(e2))

async def read(self) -> List[Message]:
try:
latest_batch = await self.await_unless_failed(self._underlying.read())
Expand All @@ -126,40 +130,19 @@ async def read(self) -> List[Message]:
self.fail(e)
raise e

def _handle_ack(self, message: requests.AckRequest):
def _handle_ack(self, ack_id: AckId):
flow_control = FlowControlRequest()
flow_control._pb.allowed_messages = 1
flow_control._pb.allowed_bytes = self._messages_by_ack_id[
message.ack_id
].size_bytes
flow_control._pb.allowed_bytes = self._messages_by_ack_id[ack_id].size_bytes
self._underlying.allow_flow(flow_control)
del self._messages_by_ack_id[message.ack_id]
del self._messages_by_ack_id[ack_id]
# Always refill flow control tokens, but do not commit offsets from outdated generations.
ack_id = _AckId.parse(message.ack_id)
if ack_id.generation == self._ack_generation_id:
try:
self._ack_set_tracker.ack(ack_id.offset)
except GoogleAPICallError as e:
self.fail(e)

def _handle_nack(self, message: requests.NackRequest):
sized_message = self._messages_by_ack_id[message.ack_id]
try:
# Put the ack request back into the queue since the callback may be called from another thread.
self._nack_handler.on_nack(
sized_message.message,
lambda: self._queue.put(
requests.AckRequest(
ack_id=message.ack_id,
byte_size=0, # Ignored
time_to_ack=0, # Ignored
ordering_key="", # Ignored
)
),
)
except GoogleAPICallError as e:
self.fail(e)

async def _handle_queue_message(
self,
message: Union[
Expand All @@ -183,21 +166,10 @@ async def _handle_queue_message(
else:
self._handle_nack(message)

async def _looper(self):
while True:
try:
# This is not an asyncio.Queue, and therefore we cannot do `await self._queue.get()`.
# A blocking wait would block the event loop, this needs to be a queue.Queue for
# compatibility with the Cloud Pub/Sub Message's requirements.
queue_message = self._queue.get_nowait()
await self._handle_queue_message(queue_message)
except queue.Empty:
await asyncio.sleep(0.1)

async def __aenter__(self):
self._loop = asyncio.get_event_loop()
await self._ack_set_tracker.__aenter__()
await self._underlying.__aenter__()
self._looper_future = asyncio.ensure_future(self._looper())
self._underlying.allow_flow(
FlowControlRequest(
allowed_messages=self._flow_control_settings.messages_outstanding,
Expand All @@ -207,7 +179,5 @@ async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
self._looper_future.cancel()
await wait_ignore_cancelled(self._looper_future)
await self._underlying.__aexit__(exc_type, exc_value, traceback)
await self._ack_set_tracker.__aexit__(exc_type, exc_value, traceback)
66 changes: 66 additions & 0 deletions google/cloud/pubsublite/cloudpubsub/internal/wrapped_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from concurrent import futures
import logging
from typing import NamedTuple, Callable

from google.cloud.pubsub_v1.subscriber.message import Message
from google.pubsub_v1 import PubsubMessage
from google.cloud.pubsub_v1.subscriber.exceptions import AcknowledgeStatus

from google.cloud.pubsublite.internal import fast_serialize


class AckId(NamedTuple):
generation: int
offset: int

def encode(self) -> str:
return str(self.generation) + "," + str(self.offset)


_SUCCESS_FUTURE = futures.Future()
_SUCCESS_FUTURE.set_result(AcknowledgeStatus.SUCCESS)


class WrappedMessage(Message):
_id: AckId
_ack_handler: Callable[[AckId, bool], None]

def __init__(
self,
pb: PubsubMessage.meta.pb,
ack_id: AckId,
ack_handler: Callable[[AckId, bool], None],
):
super().__init__(pb, ack_id.encode(), 1, None)
self._id = ack_id
self._ack_handler = ack_handler

def ack(self):
self._ack_handler(self._id, True)

def ack_with_response(self) -> "futures.Future":
self._ack_handler(self._id, True)
return _SUCCESS_FUTURE

def nack(self):
self._ack_handler(self._id, False)

def nack_with_response(self) -> "futures.Future":
self._ack_handler(self._id, False)
return _SUCCESS_FUTURE

def drop(self):
logging.warning(
"Likely incorrect call to drop() on Pub/Sub Lite message. Pub/Sub Lite does not support redelivery in this way."
)

def modify_ack_deadline(self, seconds: int):
logging.warning(
"Likely incorrect call to modify_ack_deadline() on Pub/Sub Lite message. Pub/Sub Lite does not support redelivery in this way."
)

def modify_ack_deadline_with_response(self, seconds: int) -> "futures.Future":
logging.warning(
"Likely incorrect call to modify_ack_deadline_with_response() on Pub/Sub Lite message. Pub/Sub Lite does not support redelivery in this way."
)
return _SUCCESS_FUTURE
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.cloud.pubsublite.cloudpubsub.internal.ack_set_tracker import AckSetTracker
from google.cloud.pubsublite.cloudpubsub.internal.single_partition_subscriber import (
SinglePartitionSingleSubscriber,
_AckId,
AckId,
)
from google.cloud.pubsublite.cloudpubsub.message_transformer import MessageTransformer
from google.cloud.pubsublite.cloudpubsub.nack_handler import NackHandler
Expand All @@ -48,7 +48,7 @@ def mock_async_context_manager(cm):


def ack_id(generation, offset) -> str:
return _AckId(generation, offset).encode()
return AckId(generation, offset).encode()


@pytest.fixture()
Expand Down