From 34d84aaace13b68f1cd322236fe0a664abbf4b80 Mon Sep 17 00:00:00 2001 From: cameroncaci Date: Fri, 4 Apr 2025 13:28:24 -0400 Subject: [PATCH 1/3] add semaphore support to limit max amount of workers and move cursor to the database instead of the client --- src/db_extractor_full.py | 93 ++++++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 19 deletions(-) diff --git a/src/db_extractor_full.py b/src/db_extractor_full.py index f136fcf..bc4e7b4 100644 --- a/src/db_extractor_full.py +++ b/src/db_extractor_full.py @@ -2,6 +2,8 @@ from multiprocessing.connection import Connection as multi_processing_connection from multiprocessing.connection import wait from multiprocessing import Process, Pipe +from threading import Semaphore +import resource from socket import socket import boto3 import pg8000 @@ -20,7 +22,13 @@ current_run_time = datetime.datetime.now() # How many SQL records a worker will fetch at a time -batch_size = 50000 +batch_size = 25000 + +# Max amount of concurrent batch workers. +# Not workers in general, just the amount allowed to work the batch processing +# at the same time for a given table. +# So X amount for table Y and X amount for table Z +MAX_CONCURRENT_BATCH_WORKERS = 5 # Lambda global connection for warm starts # This connection is only used to grab the table names @@ -38,7 +46,7 @@ def default(self, obj): return obj.hex return json.JSONEncoder.default(self, obj) -def parallel_worker(worker_conn, batch, column_names, key_name): +def parallel_worker(worker_conn, batch, column_names, key_name, sema): # Handle process worker -> back to json mapping and # respond it back via the pipe. We'll convert a batch # of rows into JSON and then close the pipe with our response @@ -50,6 +58,7 @@ def parallel_worker(worker_conn, batch, column_names, key_name): worker_conn.send(("error", f"ERROR: {e}\nKEY: {key_name}\nTrace: {tb}")) finally: worker_conn.close() + sema.release() def convert_batch_to_json(batch, column_names): # Convert a batch of rows to comma delimited JSON fragments (No start/end brackets) @@ -63,6 +72,7 @@ def convert_batch_to_json(batch, column_names): def fetch_and_upload_cursor_results( cursor: pg8000.Cursor, bucket_name, key_name, column_names ): + sema = Semaphore(MAX_CONCURRENT_BATCH_WORKERS) s3_client = boto3.client("s3") # Initiate multipart upload multipart_upload = s3_client.create_multipart_upload( @@ -77,6 +87,12 @@ def fetch_and_upload_cursor_results( buffer_size = 0 min_part_size = 50 * 1024 * 1024 # 50 MB + usage = resource.getrusage(resource.RUSAGE_SELF) + mem_mb = usage.ru_maxrss / 1024 # KB to MB + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Creating batch generator for {bucket_name}/{key_name}\n" + f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" + ) batch_generator = fetch_batches(cursor) # Since a generator doesn't return a list, check the first case of if it is empty @@ -120,20 +136,41 @@ def fetch_and_upload_cursor_results( # Early return, mark process to no longer be alive return + # Get memory usage + usage = resource.getrusage(resource.RUSAGE_SELF) + mem_mb = usage.ru_maxrss / 1024 # KB to MB + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning worker generation for {bucket_name}/{key_name}\n" + f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" + ) # Create a manager -> worker process for each batch in parallel # This makes a manager and worker combo for every batch that we will await later workers: List[Tuple[Process, multi_processing_connection, int]] = [] # Make sure we still process the first batch # fb = first batch fbManager, fbWorker = Pipe() - fbp = Process(target=parallel_worker, args=(fbWorker, first_batch, column_names, key_name)) + fbp = Process(target=parallel_worker, args=(fbWorker, first_batch, column_names, key_name, sema)) workers.append((fbp, fbManager, 1)) + usage = resource.getrusage(resource.RUSAGE_SELF) + mem_mb = usage.ru_maxrss / 1024 # KB to MB + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning first batch for {bucket_name}/{key_name}\n" + f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" + ) fbp.start() # Process remaining batches from the generator for i, batch in enumerate(batch_generator): + usage = resource.getrusage(resource.RUSAGE_SELF) + mem_mb = usage.ru_maxrss / 1024 # KB to MB + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning batch {i} for {bucket_name}/{key_name}\n" + f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" + ) + + sema.acquire() manager, worker = Pipe() - p = Process(target=parallel_worker, args=(worker, batch, column_names, key_name)) - workers.append((p, manager, i+1)) # +1 because of first batch preceding this + p = Process(target=parallel_worker, args=(worker, batch, column_names, key_name, sema)) + workers.append((p, manager, i+1)) # +1 because of first batch preceding this p.start() # @@ -148,6 +185,12 @@ def fetch_and_upload_cursor_results( # Concurrently wait for the workers to be complete while workers: + usage = resource.getrusage(resource.RUSAGE_SELF) + mem_mb = usage.ru_maxrss / 1024 # KB to MB + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Waiting for workers for {bucket_name}/{key_name}\n" + f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" + ) finished_workers: List[multi_processing_connection | socket | int] = wait([mgr for (proc, mgr, bIndex) in workers], timeout=None) for finished_worker in finished_workers: is_last_worker = len(workers) == 1 @@ -222,7 +265,9 @@ def fetch_and_upload_cursor_results( # to boost processing efficiency def fetch_batches(cursor: pg8000.Cursor): while True: - batch = cursor.fetchmany(batch_size) + cursor.execute(f"FETCH FORWARD {batch_size} FROM data_cursor") + # Fetch data from client-side cursor that was provided by the server-side cursor + batch = cursor.fetchall() if not batch: # No more batch results break @@ -232,6 +277,7 @@ def fetch_batches(cursor: pg8000.Cursor): def map_row_to_columns(row, column_names): return {column_names[i]: row[i] for i in range(len(column_names))} + def table_extractor( table_name, json_parameter_value, @@ -269,6 +315,7 @@ def table_extractor( # Set the statement timeout to 600 seconds for this session cursor.execute("SET statement_timeout = '600s'") + cursor.execute("BEGIN READ ONLY;") # Handle if the table being iterated on does not have updated_at or created_at # Since we do not have timestamps to compare to, we must full dump the table without updated or created at @@ -278,30 +325,39 @@ def table_extractor( + str(table_name) ) # Tell the database to execute this query, we will ingest it in chunks - cursor.execute("SELECT * FROM " + str(table_name)) + # Create a server-side cursor + cursor.execute(f"DECLARE data_cursor CURSOR FOR SELECT * FROM {table_name}") # Fetch cursor results and upload to S3 fetch_and_upload_cursor_results(cursor, bucket_name, s3_key, column_names) + cursor.execute("CLOSE data_cursor") + cursor.execute("COMMIT;") # If we have created_at but no updated_at, we dump based only on created_at elif found_updated_at == False and found_created_at == True: last_run_time = json_parameter_value["data"]["lastRunTime"] cursor.execute( - "SELECT * FROM " - + str(table_name) - + " where (created_at > '" - + str(last_run_time) - + "') order by created_at;" + f""" + DECLARE data_cursor CURSOR FOR + SELECT * FROM {table_name} + WHERE created_at > %s + ORDER BY created_at + """, + (last_run_time,) ) fetch_and_upload_cursor_results(cursor, bucket_name, s3_key, column_names) + cursor.execute("CLOSE data_cursor") + cursor.execute("COMMIT;") # If we have created_at and updated_at, we dump based on both elif found_updated_at == True and found_created_at == True: last_run_time = json_parameter_value["data"]["lastRunTime"] - cursor.execute( - "SELECT * FROM " - + str(table_name) - + " where ((created_at > %s) OR (updated_at > %s)) order by created_at;", - (last_run_time, last_run_time), - ) + cursor.execute(f""" + DECLARE data_cursor CURSOR FOR + SELECT * FROM {table_name} + WHERE ((created_at > %s) OR (updated_at > %s)) + ORDER BY created_at + """, (last_run_time, last_run_time,)) fetch_and_upload_cursor_results(cursor, bucket_name, s3_key, column_names) + cursor.execute("CLOSE data_cursor") + cursor.execute("COMMIT;") else: print( "Data Warehouse Lambda - ERROR - DB Extract - " @@ -378,7 +434,6 @@ def db_extractor(): "archived_access_codes", "schema_migration", "audit_history_tableslist", - "audit_history", "v_locations", ] From aafc4520961332223a99450d14ba21c5561a021b Mon Sep 17 00:00:00 2001 From: cameroncaci Date: Fri, 4 Apr 2025 14:31:44 -0400 Subject: [PATCH 2/3] remove sema and put memory prints into a helper func --- src/db_extractor_full.py | 63 ++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/src/db_extractor_full.py b/src/db_extractor_full.py index bc4e7b4..2f5feb4 100644 --- a/src/db_extractor_full.py +++ b/src/db_extractor_full.py @@ -2,7 +2,6 @@ from multiprocessing.connection import Connection as multi_processing_connection from multiprocessing.connection import wait from multiprocessing import Process, Pipe -from threading import Semaphore import resource from socket import socket import boto3 @@ -45,8 +44,15 @@ def default(self, obj): # if the obj is uuid, we simply return the value of uuid return obj.hex return json.JSONEncoder.default(self, obj) + +def print_memory(): + usage = resource.getrusage(resource.RUSAGE_SELF) + mem_mb = usage.ru_maxrss / 1024 # KB to MB + print( + f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" + ) -def parallel_worker(worker_conn, batch, column_names, key_name, sema): +def parallel_worker(worker_conn, batch, column_names, key_name): # Handle process worker -> back to json mapping and # respond it back via the pipe. We'll convert a batch # of rows into JSON and then close the pipe with our response @@ -58,7 +64,6 @@ def parallel_worker(worker_conn, batch, column_names, key_name, sema): worker_conn.send(("error", f"ERROR: {e}\nKEY: {key_name}\nTrace: {tb}")) finally: worker_conn.close() - sema.release() def convert_batch_to_json(batch, column_names): # Convert a batch of rows to comma delimited JSON fragments (No start/end brackets) @@ -72,7 +77,6 @@ def convert_batch_to_json(batch, column_names): def fetch_and_upload_cursor_results( cursor: pg8000.Cursor, bucket_name, key_name, column_names ): - sema = Semaphore(MAX_CONCURRENT_BATCH_WORKERS) s3_client = boto3.client("s3") # Initiate multipart upload multipart_upload = s3_client.create_multipart_upload( @@ -87,12 +91,8 @@ def fetch_and_upload_cursor_results( buffer_size = 0 min_part_size = 50 * 1024 * 1024 # 50 MB - usage = resource.getrusage(resource.RUSAGE_SELF) - mem_mb = usage.ru_maxrss / 1024 # KB to MB - print( - f"Data Warehouse Lambda - DEBUG - DB Extract - Creating batch generator for {bucket_name}/{key_name}\n" - f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" - ) + print(f"Data Warehouse Lambda - DEBUG - DB Extract - Creating batch generator for {bucket_name}/{key_name}\n") + print_memory() batch_generator = fetch_batches(cursor) # Since a generator doesn't return a list, check the first case of if it is empty @@ -136,61 +136,56 @@ def fetch_and_upload_cursor_results( # Early return, mark process to no longer be alive return - # Get memory usage - usage = resource.getrusage(resource.RUSAGE_SELF) - mem_mb = usage.ru_maxrss / 1024 # KB to MB + # + # Start our JSON output document + # + # Begin with our bracket as we are + # going to format this manually + buffer.write(b"[") + buffer_size += 1 + print( f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning worker generation for {bucket_name}/{key_name}\n" - f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" ) + print_memory() # Create a manager -> worker process for each batch in parallel # This makes a manager and worker combo for every batch that we will await later workers: List[Tuple[Process, multi_processing_connection, int]] = [] # Make sure we still process the first batch # fb = first batch fbManager, fbWorker = Pipe() - fbp = Process(target=parallel_worker, args=(fbWorker, first_batch, column_names, key_name, sema)) + fbp = Process(target=parallel_worker, args=(fbWorker, first_batch, column_names, key_name)) workers.append((fbp, fbManager, 1)) - usage = resource.getrusage(resource.RUSAGE_SELF) - mem_mb = usage.ru_maxrss / 1024 # KB to MB print( f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning first batch for {bucket_name}/{key_name}\n" - f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" ) + print_memory() fbp.start() # Process remaining batches from the generator for i, batch in enumerate(batch_generator): - usage = resource.getrusage(resource.RUSAGE_SELF) - mem_mb = usage.ru_maxrss / 1024 # KB to MB print( f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning batch {i} for {bucket_name}/{key_name}\n" - f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" ) + print_memory() - sema.acquire() manager, worker = Pipe() - p = Process(target=parallel_worker, args=(worker, batch, column_names, key_name, sema)) + p = Process(target=parallel_worker, args=(worker, batch, column_names, key_name)) workers.append((p, manager, i+1)) # +1 because of first batch preceding this p.start() - - # - # Start our JSON output document - # - # Begin with our bracket as we are - # going to format this manually - buffer.write(b"[") - buffer_size += 1 + + # if len(workers) >= MAX_CONCURRENT_BATCH_WORKERS: + # Max workers reached, before creating an ew process we need to + # handle the finished ones + print(f'Data Warehouse Lambda - INFO - {key_name} has {len(workers)} workers') # Concurrently wait for the workers to be complete while workers: - usage = resource.getrusage(resource.RUSAGE_SELF) - mem_mb = usage.ru_maxrss / 1024 # KB to MB print( f"Data Warehouse Lambda - DEBUG - DB Extract - Waiting for workers for {bucket_name}/{key_name}\n" - f"Data Warehouse Lambda - DEBUG - Memory usage: {mem_mb:.2f} MB" ) + print_memory() finished_workers: List[multi_processing_connection | socket | int] = wait([mgr for (proc, mgr, bIndex) in workers], timeout=None) for finished_worker in finished_workers: is_last_worker = len(workers) == 1 From 0901af8a7fd9e33f8982d7eee899753a6fc94a81 Mon Sep 17 00:00:00 2001 From: cameroncaci Date: Mon, 7 Apr 2025 15:11:34 -0400 Subject: [PATCH 3/3] massive performance overhaul with worker tracking --- src/db_extractor_full.py | 362 ++++++++++++++++++++++++++------------- 1 file changed, 241 insertions(+), 121 deletions(-) diff --git a/src/db_extractor_full.py b/src/db_extractor_full.py index 2f5feb4..8e908dd 100644 --- a/src/db_extractor_full.py +++ b/src/db_extractor_full.py @@ -2,6 +2,7 @@ from multiprocessing.connection import Connection as multi_processing_connection from multiprocessing.connection import wait from multiprocessing import Process, Pipe +import multiprocessing.connection import resource from socket import socket import boto3 @@ -14,20 +15,20 @@ import datetime import gc import traceback -from typing import List, Tuple +from typing import Any, Iterator, List, Tuple # Storing current time, we will use this to update SSM when finished so that we # know for the next run which time to select from current_run_time = datetime.datetime.now() # How many SQL records a worker will fetch at a time -batch_size = 25000 +batch_size = 20000 # Max amount of concurrent batch workers. # Not workers in general, just the amount allowed to work the batch processing # at the same time for a given table. # So X amount for table Y and X amount for table Z -MAX_CONCURRENT_BATCH_WORKERS = 5 +MAX_CONCURRENT_BATCH_WORKERS = 4 # Lambda global connection for warm starts # This connection is only used to grab the table names @@ -73,6 +74,38 @@ def convert_batch_to_json(batch, column_names): records.append(json.dumps(as_dict, cls=UUIDEncoder, default=str)) return ",".join(records).encode("utf-8") # Return as bytes +def upload_empty_json(s3_client, bucket_name, key_name, part_number, upload_id, parts): + buffer = io.BytesIO() + try: + buffer.write(b"[]") + buffer.seek(0) + response = s3_client.upload_part( + Bucket=bucket_name, + Key=key_name, + PartNumber=part_number, + UploadId=upload_id, + Body=buffer, + ) + parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) + s3_client.complete_multipart_upload( + Bucket=bucket_name, + Key=key_name, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + print( + f"Data Warehouse Lambda - INFO - DB Extract - Successfully wrote {bucket_name}/{key_name}" + ) + except Exception as e: + # Abort multipart upload in case of failure + s3_client.abort_multipart_upload( + Bucket=bucket_name, Key=key_name, UploadId=upload_id + ) + print( + f"Data Warehouse Lambda - ERROR - DB Extract - Error during multipart upload: {e}" + ) + raise e + # Helper function to batch fetch data and use multipart uploading with S3 def fetch_and_upload_cursor_results( cursor: pg8000.Cursor, bucket_name, key_name, column_names @@ -93,48 +126,7 @@ def fetch_and_upload_cursor_results( print(f"Data Warehouse Lambda - DEBUG - DB Extract - Creating batch generator for {bucket_name}/{key_name}\n") print_memory() - batch_generator = fetch_batches(cursor) - - # Since a generator doesn't return a list, check the first case of if it is empty - try: - first_batch = next(batch_generator) - except StopIteration: - # This exception WILL be thrown if no SQL records return - first_batch = None - - if first_batch is None: - # No records, make a new buffer and write an empty array for the dump - try: - buffer.write(b"[]") - buffer.seek(0) - response = s3_client.upload_part( - Bucket=bucket_name, - Key=key_name, - PartNumber=part_number, - UploadId=upload_id, - Body=buffer, - ) - parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) - s3_client.complete_multipart_upload( - Bucket=bucket_name, - Key=key_name, - UploadId=upload_id, - MultipartUpload={"Parts": parts}, - ) - print( - f"Data Warehouse Lambda - INFO - DB Extract - Successfully wrote {bucket_name}/{key_name}" - ) - except Exception as e: - # Abort multipart upload in case of failure - s3_client.abort_multipart_upload( - Bucket=bucket_name, Key=key_name, UploadId=upload_id - ) - print( - f"Data Warehouse Lambda - ERROR - DB Extract - Error during multipart upload: {e}" - ) - raise e - # Early return, mark process to no longer be alive - return + batch_generator = lookahead(fetch_batches(cursor)) # # Start our JSON output document @@ -150,94 +142,206 @@ def fetch_and_upload_cursor_results( print_memory() # Create a manager -> worker process for each batch in parallel # This makes a manager and worker combo for every batch that we will await later - workers: List[Tuple[Process, multi_processing_connection, int]] = [] - # Make sure we still process the first batch - # fb = first batch - fbManager, fbWorker = Pipe() - fbp = Process(target=parallel_worker, args=(fbWorker, first_batch, column_names, key_name)) - workers.append((fbp, fbManager, 1)) - print( - f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning first batch for {bucket_name}/{key_name}\n" - ) - print_memory() - fbp.start() + workers: List[Tuple[Process, multi_processing_connection, int, bool]] = [] + # Process remaining batches from the generator - for i, batch in enumerate(batch_generator): + did_batches_run = False + all_workers_have_finished = False + total_finished_workers: List[Tuple[Process, multi_processing_connection, int, bool]] = [] + for i, (batch, is_last_generator) in enumerate(batch_generator): + did_batches_run = True print( f"Data Warehouse Lambda - DEBUG - DB Extract - Beginning batch {i} for {bucket_name}/{key_name}\n" ) print_memory() + # Queue the worker manager, worker = Pipe() p = Process(target=parallel_worker, args=(worker, batch, column_names, key_name)) - workers.append((p, manager, i+1)) # +1 because of first batch preceding this + workers.append((p, manager, i, is_last_generator)) p.start() - # if len(workers) >= MAX_CONCURRENT_BATCH_WORKERS: - # Max workers reached, before creating an ew process we need to - # handle the finished ones + # Prep to extract the last batch worker from the queue + # This is done so that we can conditionally make sure + # that we close the JSON array properly `]` in the + # multi-part upload + last_batch_worker = None - - print(f'Data Warehouse Lambda - INFO - {key_name} has {len(workers)} workers') - - # Concurrently wait for the workers to be complete - while workers: - print( - f"Data Warehouse Lambda - DEBUG - DB Extract - Waiting for workers for {bucket_name}/{key_name}\n" - ) - print_memory() - finished_workers: List[multi_processing_connection | socket | int] = wait([mgr for (proc, mgr, bIndex) in workers], timeout=None) - for finished_worker in finished_workers: - is_last_worker = len(workers) == 1 - for i, (proc, mgr, batch_index) in enumerate(workers): - if mgr is finished_worker: - if mgr.poll(): - msg_type, data = mgr.recv() - mgr.close() # Free thread - proc.join() # Make sure worker is done - workers.pop(i) # Clear the worker - - if msg_type == "error": - print(f"Data Warehouse Lambda - ERROR - Worker {batch_index} failed: {data}") - raise RuntimeError(f"Worker {batch_index} error: {data}") + # If we hit the max amount of workers or it is the last generator + # then we want to infinitely loop until either a worker is freed + # or the loop is completely exited because the last generator is present + while len(workers) >= MAX_CONCURRENT_BATCH_WORKERS or (is_last_generator and not all_workers_have_finished): + # Max workers reached or this is the end of the line for the generator, + # before creating a new process we need to + # clear out the current workers + if len(workers) >= MAX_CONCURRENT_BATCH_WORKERS: + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Max concurrent batch workers reached for {bucket_name}/{key_name}\nbeginning upload\nHave all workers finished? {all_workers_have_finished}" + ) + if is_last_generator: + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Last generator reached for {bucket_name}/{key_name}\nbeginning upload\nHave all workers finished? {all_workers_have_finished}" + ) + + # Grab any finished worker managers + current_finished_worker_managers = wait( + [mgr for (_, mgr, _, _) in workers], timeout=5 + ) + + # Map the manager to the worker + current_finished_workers: List[Tuple[Process, multi_processing_connection, int, bool]] = [ + (proc, mgr, bIndex, is_final_batch) + for (proc, mgr, bIndex, is_final_batch) in workers + if mgr in current_finished_worker_managers + ] + + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Found this many current finished worker managers {len(current_finished_worker_managers)} which translated to {len(current_finished_workers)} current finished workers {bucket_name}/{key_name}" + ) + + # Add them to the total + for finished_worker in current_finished_workers: + # For each finished worker, find their respective `worker` entry + # and once found, add it to the total_finished_workers + for worker in workers: + proc, mgr, batch_index, is_final_batch = worker + if mgr is finished_worker[1] and worker not in total_finished_workers: + total_finished_workers.append(worker) + break + + print( + f"With {len(current_finished_workers)} current finished workers, I now have {len(total_finished_workers)} total finished workers for {key_name}" + ) + + # Iterate over each finished worker, we will poll and upload each of them + # making sure to only process the final batch worker last if present + for finished_worker in total_finished_workers: + proc, mgr, batch_index, is_final_batch = finished_worker + if last_batch_worker is None: + print( + f"No last batch worker found for {key_name}" + ) + # Last batch worker hasn't been found yet + # Check if this worker is the last batch worker + if is_final_batch: + print( + f"Found the last batch worker found for {key_name}" + ) + # This is the one + last_batch_worker = finished_worker + if last_batch_worker is not None: + print( + f"if last_batch_worker is not None for {key_name} with a len of {len(workers)} workers" + ) + # We know that if this worker is present, all remaining workers are present + # We should only process the final worker if it is the last one + # (Yes this if condition can be joined int one, but this is more readable) + if len(workers) > 1: + # There are unpopped workers present, we should not process the + # final batch worker until it is the last one print( - f"Data Warehouse Lambda - INFO - DB Extract - {key_name} worker for batch index {batch_index} polled successfully. {len(workers)} workers remain" + f"Data Warehouse Lambda - DEBUG - DB Extract - {bucket_name}/{key_name}\n attempted to poll the last worker while more remain, skipping" ) - # Add data fragment to buffer - buffer.write(data) - buffer_size += len(data) + # Instead of immediately continuing the loop, wait until the next worker is complete before continuing + # This is a way of preserving CPU threads from being hogged by an infinite loop - if not is_last_worker: - # If it's not the last worker, append `,` to support a final - # [${worker_1_json}, ${worker_2_json}, ${worker_3_json}] output - buffer.write(b",") - buffer_size += 1 - - # Upload part if buffer reaches the minimum part size or - # if this is the final worker - if buffer_size >= min_part_size or is_last_worker: - if is_last_worker: - # Last call, close it up! - buffer.write(b"]") - buffer_size += 1 - buffer.seek(0) - response = s3_client.upload_part( - Bucket=bucket_name, - Key=key_name, - PartNumber=part_number, - UploadId=upload_id, - Body=buffer, + # Only wait for a new worker if there are any pending+ + # This is a good condition because it will only be run + # if the final batch worker has been found, meaning + # the len of total finished workers must be equal to len of workers + # themselves + if len(total_finished_workers) != len(workers): + # Wait for the other worker that is not the last batch worker + wait( + [tfwMgr for (_, tfwMgr, _, _) in workers if tfwMgr is not last_batch_worker[1]], + timeout=5 ) - parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) - part_number += 1 - buffer.close() - buffer = io.BytesIO() - buffer_size = 0 - - gc.collect() + + # Check if this worker is the last batch worker while len workers > 1 + # If len workers > 1 and this is the last batch worker, skip this worker and process the other one + if mgr == last_batch_worker[1]: + # This worker's manager is the same manager as the last batch worker, skip it + continue + + # Poll the manager, receive their message, upload and pop + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - Seeing if manager has a message for {bucket_name}/{key_name}" + ) + if mgr.poll(): + print( + f"Data Warehouse Lambda - DEBUG - DB Extract - polling manager with a len of {len(total_finished_workers)} total finished worker and len {len(workers)} workers for {bucket_name}/{key_name}" + ) + msg_type, data = mgr.recv() + mgr.close() # Free thread + proc.join() # Make sure worker is done + # Remove this worker from the workers list + total_finished_workers.remove(finished_worker) + + # Find the corresponding worker and pop it + for j, w in enumerate(workers): + # If the managers are the same, pop it + if w[1] is mgr: + workers.pop(j) + break + print( + f"attempted to remove finished worker from total finished workers, new len of total finished workers is {len(total_finished_workers)} as well as len {len(workers)} workers for {key_name}\nIs final batch: {is_final_batch}" + ) + is_final_worker = is_final_batch and len(workers) == 0 # Track if this is the last one + if is_final_batch and len(workers) > 1: + # This condition should be impossible as it would + # allow the processing of the final batch worker + # even though there are other workers to be processed first + raise RuntimeError(f"Worker {key_name} unexpected error: the final batch worker is attempting to be polled even though they are not the last worker") + if msg_type == "error": + print(f"Data Warehouse Lambda - ERROR - Worker {batch_index} failed: {data}") + raise RuntimeError(f"Worker {batch_index} error: {data}") + print( + f"Data Warehouse Lambda - INFO - DB Extract - {key_name} worker for batch index {batch_index} polled successfully. {len(workers)} workers remain" + ) + # Add data fragment to buffer + buffer.write(data) + buffer_size += len(data) + + if not is_final_worker: + # If it's not the last worker, append `,` to support a final + # [${worker_1_json}, ${worker_2_json}, ${worker_3_json}] output + buffer.write(b",") + buffer_size += 1 + + # Upload part if buffer reaches the minimum part size or + # if this is the final worker + if buffer_size >= min_part_size or is_final_worker: + print( + f"Triggering upload with buffer size {buffer_size} for {key_name}" + ) + if is_final_worker: + all_workers_have_finished = True + print(f"Final worker for {key_name} for batch index {batch_index} reached, {len(workers)} workers remain") + # Last call, close it up! + buffer.write(b"]") + buffer_size += 1 + buffer.seek(0) + response = s3_client.upload_part( + Bucket=bucket_name, + Key=key_name, + PartNumber=part_number, + UploadId=upload_id, + Body=buffer, + ) + parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) + part_number += 1 + buffer = io.BytesIO() + buffer_size = 0 + gc.collect() + if not did_batches_run: + # The for loop won't execute if the cursor returns no rows + upload_empty_json(s3_client, bucket_name, key_name, part_number, upload_id, parts) + # Early return + return + print( f"Data Warehouse Lambda - INFO - DB Extract - {key_name} workers complete, announcing multipart completion" ) @@ -246,16 +350,32 @@ def fetch_and_upload_cursor_results( buffer.close() del buffer # Announce multipart upload completion - s3_client.complete_multipart_upload( - Bucket=bucket_name, - Key=key_name, - UploadId=upload_id, - MultipartUpload={"Parts": parts}, - ) + try: + s3_client.complete_multipart_upload( + Bucket=bucket_name, + Key=key_name, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + except Exception as e: + raise RuntimeError(f"Multipart upload completion error for {key_name}: {e}") print( f"Data Warehouse Lambda - INFO - DB Extract - Successfully wrote {bucket_name}/{key_name}" ) +# helper when fetching batches to look ahead in the generator +# letting us know when we've reached the last entry in the +# generator. This assists in json formatting +def lookahead(gen: Iterator) -> Iterator[Tuple[Any, bool]]: + try: + prev = next(gen) + except StopIteration: + return + for val in gen: + yield prev, False + prev = val + yield prev, True # End of generator reached + # Helper func to yield results rather than return # to boost processing efficiency def fetch_batches(cursor: pg8000.Cursor):