
Commit 1a2a6f7

Improve reorg handling with conservative invalidation and auto-restart
- Use prev_range.start for the invalidation range to ensure all potentially affected data is invalidated (addresses Ford review comment)
- Remove _pending_batch logic; caller now handles stream restart
- Clear prev_ranges_by_network for affected networks on reorg detection
- Add auto-restart loop in query_and_load_streaming to transparently handle reorgs without requiring user intervention
- Close the old stream before restarting to prevent resource leaks
- Update tests to reflect the new invalidation behavior
1 parent cdb1f28 commit 1a2a6f7

File tree

- src/amp/client.py
- src/amp/streaming/reorg.py
- tests/unit/test_streaming_types.py

3 files changed: +66 -50 lines changed


src/amp/client.py

Lines changed: 55 additions & 35 deletions
@@ -791,42 +791,62 @@ def query_and_load_streaming(
             self.logger.warning(f'Failed to load checkpoint, starting from beginning: {e}')
 
         try:
-            # Execute streaming query with Flight SQL
-            # Create a CommandStatementQuery message
-            command_query = FlightSql_pb2.CommandStatementQuery()
-            command_query.query = query
-
-            # Add resume watermark if provided
-            if resume_watermark:
-                # TODO: Add watermark to query metadata when Flight SQL supports it
-                self.logger.info(f'Resuming stream from watermark: {resume_watermark}')
-
-            # Wrap the CommandStatementQuery in an Any type
-            any_command = Any()
-            any_command.Pack(command_query)
-            cmd = any_command.SerializeToString()
-
-            self.logger.info('Establishing Flight SQL connection...')
-            flight_descriptor = flight.FlightDescriptor.for_command(cmd)
-            info = self.conn.get_flight_info(flight_descriptor)
-            reader = self.conn.do_get(info.endpoints[0].ticket)
-
-            # Create streaming iterator
-            stream_iterator = StreamingResultIterator(reader)
-            self.logger.info('Stream connection established, waiting for data...')
-
-            # Optionally wrap with reorg detection
-            if with_reorg_detection:
-                stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark)
-                self.logger.info('Reorg detection enabled for streaming query')
-
-            # Start continuous loading with checkpoint support
             with loader_instance:
-                self.logger.info(f'Starting continuous load to {destination}. Press Ctrl+C to stop.')
-                # Pass connection_name for checkpoint saving
-                yield from loader_instance.load_stream_continuous(
-                    stream_iterator, destination, connection_name=connection_name, **load_config.__dict__
-                )
+                while True:
+                    # Execute streaming query with Flight SQL
+                    # Create a CommandStatementQuery message
+                    command_query = FlightSql_pb2.CommandStatementQuery()
+                    command_query.query = query
+
+                    # Add resume watermark if provided
+                    if resume_watermark:
+                        # TODO: Add watermark to query metadata when Flight SQL supports it
+                        self.logger.info(f'Resuming stream from watermark: {resume_watermark}')
+
+                    # Wrap the CommandStatementQuery in an Any type
+                    any_command = Any()
+                    any_command.Pack(command_query)
+                    cmd = any_command.SerializeToString()
+
+                    self.logger.info('Establishing Flight SQL connection...')
+                    flight_descriptor = flight.FlightDescriptor.for_command(cmd)
+                    info = self.conn.get_flight_info(flight_descriptor)
+                    reader = self.conn.do_get(info.endpoints[0].ticket)
+
+                    # Create streaming iterator
+                    stream_iterator = StreamingResultIterator(reader)
+                    self.logger.info('Stream connection established, waiting for data...')
+
+                    # Optionally wrap with reorg detection
+                    if with_reorg_detection:
+                        stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark)
+                        self.logger.info('Reorg detection enabled for streaming query')
+
+                    # Start continuous loading with checkpoint support
+                    self.logger.info(f'Starting continuous load to {destination}. Press Ctrl+C to stop.')
+
+                    reorg_result = None
+                    # Pass connection_name for checkpoint saving
+                    for result in loader_instance.load_stream_continuous(
+                        stream_iterator, destination, connection_name=connection_name, **load_config.__dict__
+                    ):
+                        yield result
+                        # Break on reorg to restart stream
+                        if result.is_reorg:
+                            reorg_result = result
+                            break
+
+                    # Check if we need to restart due to reorg
+                    if reorg_result:
+                        # Close the old stream before restarting
+                        if hasattr(stream_iterator, 'close'):
+                            stream_iterator.close()
+                        self.logger.info('Reorg detected, restarting stream with new resume position...')
+                        resume_watermark = loader_instance.state_store.get_resume_position(connection_name, destination)
+                        continue
+
+                    # Normal exit - stream completed
+                    break
 
         except Exception as e:
             self.logger.error(f'Streaming query failed: {e}')
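
For context, a rough sketch of how a caller might consume the generator after this change. Only query_and_load_streaming, destination, connection_name, with_reorg_detection, and result.is_reorg come from the diff above; the query string, destination value, and the wrapper function are illustrative assumptions. The point is that reorg recovery now happens inside the generator, so the caller no longer has to tear down and restart the stream itself:

    def run_stream(client):
        # `client` is an already-configured amp client (construction not shown here).
        for result in client.query_and_load_streaming(
            query='SELECT * FROM blocks',      # example query (assumption)
            destination='blocks_table',        # example destination (assumption)
            connection_name='mainnet-loader',  # example connection name (assumption)
            with_reorg_detection=True,
        ):
            if result.is_reorg:
                # The stream is restarted automatically from the stored resume
                # position; the caller only observes the reorg batch.
                print('reorg observed; affected ranges were invalidated')
            else:
                print('batch loaded')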

src/amp/streaming/reorg.py

Lines changed: 7 additions & 11 deletions
@@ -75,20 +75,16 @@ def __next__(self) -> ResponseBatch:
         for range in batch.metadata.ranges:
             self.prev_ranges_by_network[range.network] = range
 
-        # If we detected a reorg, yield the reorg notification first
+        # If we detected a reorg, return reorg batch
+        # Caller decides whether to stop/restart or continue
         if invalidation_ranges:
             self.logger.info(f'Reorg detected with {len(invalidation_ranges)} invalidation ranges')
-            # Store the batch to yield after the reorg
-            self._pending_batch = batch
+            # Clear memory for affected networks so restart works correctly
+            for inv_range in invalidation_ranges:
+                if inv_range.network in self.prev_ranges_by_network:
+                    del self.prev_ranges_by_network[inv_range.network]
             return ResponseBatch.reorg_batch(invalidation_ranges)
 
-        # Check if we have a pending batch from a previous reorg detection
-        # REVIEW: I think we should remove this
-        if hasattr(self, '_pending_batch'):
-            pending = self._pending_batch
-            delattr(self, '_pending_batch')
-            return pending
-
         # Normal case - just return the data batch
         return batch
 
@@ -144,7 +140,7 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
             if is_reorg:
                 invalidation = BlockRange(
                     network=current_range.network,
-                    start=current_range.start,
+                    start=prev_range.start,
                     end=max(current_range.end, prev_range.end),
                     hash=prev_range.hash,
                 )
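
To make the "conservative invalidation" concrete: _detect_reorg now anchors the invalidation range at prev_range.start rather than current_range.start, so everything loaded from the previously seen range is covered. A minimal standalone sketch of that computation (a local namedtuple stands in for the library's BlockRange; the numbers are illustrative):

    from collections import namedtuple

    # Stand-in for BlockRange; fields mirror the constructor call in the diff above.
    Range = namedtuple('Range', ['network', 'start', 'end', 'hash'])

    def invalidation_range(prev, current):
        # Conservative: start from the previously seen range's start,
        # not from the start of the reorged range.
        return Range(
            network=current.network,
            start=prev.start,                  # was current.start before this commit
            end=max(current.end, prev.end),
            hash=prev.hash,
        )

    prev = Range('ethereum', 100, 200, '0xabc')
    current = Range('ethereum', 180, 280, '0xdef')  # overlaps and rewrites part of prev
    print(invalidation_range(prev, current))        # start=100, end=280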

tests/unit/test_streaming_types.py

Lines changed: 4 additions & 4 deletions
@@ -474,7 +474,7 @@ class MockIterator:
 
         assert len(invalidations) == 1
         assert invalidations[0].network == 'ethereum'
-        assert invalidations[0].start == 180
+        assert invalidations[0].start == 100  # prev_range.start
         assert invalidations[0].end == 280  # max(280, 200)
 
     def test_detect_reorg_multiple_networks(self):
@@ -504,12 +504,12 @@ class MockIterator:
 
         # Check ethereum reorg
         eth_inv = next(inv for inv in invalidations if inv.network == 'ethereum')
-        assert eth_inv.start == 150
+        assert eth_inv.start == 100  # prev_range.start
         assert eth_inv.end == 250
 
         # Check polygon reorg
         poly_inv = next(inv for inv in invalidations if inv.network == 'polygon')
-        assert poly_inv.start == 140
+        assert poly_inv.start == 50  # prev_range.start
         assert poly_inv.end == 240
 
     def test_detect_reorg_same_range_no_reorg(self):
@@ -546,7 +546,7 @@ class MockIterator:
         invalidations = stream._detect_reorg(current_ranges)
 
         assert len(invalidations) == 1
-        assert invalidations[0].start == 250
+        assert invalidations[0].start == 100  # prev_range.start
         assert invalidations[0].end == 300  # max(280, 300)
 
     def test_is_duplicate_batch_all_same(self):
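
The updated expectations follow directly from the new rule: the invalidation now starts at prev_range.start, while the end is still max(current.end, prev.end). A quick sketch of the arithmetic for the first test, assuming (as the old and new assertions together imply) a previous range of 100-200 and a current range of 180-280:

    prev_start, prev_end = 100, 200    # previously streamed range (implied by the assertions)
    curr_start, curr_end = 180, 280    # reorged range (old assertion expected start == 180)

    start = prev_start                 # new behaviour: 100
    end = max(curr_end, prev_end)      # 280, unchanged

    assert (start, end) == (100, 280)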
