diff --git a/src/db_extractor_full.py b/src/db_extractor_full.py index 9f11a98..f136fcf 100644 --- a/src/db_extractor_full.py +++ b/src/db_extractor_full.py @@ -1,5 +1,8 @@ import multiprocessing from multiprocessing.connection import Connection as multi_processing_connection +from multiprocessing.connection import wait +from multiprocessing import Process, Pipe +from socket import socket import boto3 import pg8000 import json @@ -9,22 +12,23 @@ from uuid import UUID import datetime import gc -from multiprocessing import Process, Pipe +import traceback from typing import List, Tuple # Storing current time, we will use this to update SSM when finished so that we # know for the next run which time to select from current_run_time = datetime.datetime.now() -# Define batch size fetching of records for json dumps -batch_size = 500000 +# How many SQL records a worker will fetch at a time +batch_size = 50000 # Lambda global connection for warm starts # This connection is only used to grab the table names -# When multiprocessing, the child processes each create their own +# When multiprocessing, the worker processes each create their own # connection connection = None +max_processes = multiprocessing.cpu_count() # Processor limit # Class to format json UUID's class UUIDEncoder(json.JSONEncoder): @@ -34,6 +38,26 @@ def default(self, obj): return obj.hex return json.JSONEncoder.default(self, obj) +def parallel_worker(worker_conn, batch, column_names, key_name): + # Handle process worker -> back to json mapping and + # respond it back via the pipe. We'll convert a batch + # of rows into JSON and then close the pipe with our response + try: + fragment = convert_batch_to_json(batch, column_names) + worker_conn.send(("fragment", fragment)) + except Exception as e: + tb = traceback.format_exc() + worker_conn.send(("error", f"ERROR: {e}\nKEY: {key_name}\nTrace: {tb}")) + finally: + worker_conn.close() + +def convert_batch_to_json(batch, column_names): + # Convert a batch of rows to comma delimited JSON fragments (No start/end brackets) + records = [] + for row in batch: + as_dict = map_row_to_columns(row, column_names) + records.append(json.dumps(as_dict, cls=UUIDEncoder, default=str)) + return ",".join(records).encode("utf-8") # Return as bytes # Helper function to batch fetch data and use multipart uploading with S3 def fetch_and_upload_cursor_results( @@ -52,75 +76,29 @@ def fetch_and_upload_cursor_results( buffer = io.BytesIO() buffer_size = 0 min_part_size = 50 * 1024 * 1024 # 50 MB - first_record = True # Track the first record for JSON formatting - + + batch_generator = fetch_batches(cursor) + + # Since a generator doesn't return a list, check the first case of if it is empty try: - # Write the opening bracket for the JSON array - buffer.write(b"[") - buffer_size += 1 - # Fetch the results of the cursor query - for batch in fetch_batches(cursor): - data_with_col_names = ( - map_row_to_columns(row, column_names) for row in batch - ) - for record in data_with_col_names: - if not first_record: - # Add comma before each record except the first - buffer.write(b",") - buffer_size += 1 - else: - # Set to False after processing the first record - first_record = False - - json_line = json.dumps(record, cls=UUIDEncoder, default=str) - json_line_bytes = json_line.encode("utf-8") - buffer.write(json_line_bytes) - buffer_size += len(json_line_bytes) - - # Upload part if buffer reaches the minimum part size of 5MB - if buffer_size >= min_part_size: - # This part is now finished being appended too - # Trigger upload, reset buffer, and proceed with the next one - buffer.seek(0) - response = s3_client.upload_part( - Bucket=bucket_name, - Key=key_name, - PartNumber=part_number, - UploadId=upload_id, - Body=buffer, - ) - parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) - part_number += 1 - buffer.close() - buffer = io.BytesIO() - buffer_size = 0 - - gc.collect() - - if first_record: - # No records were written, make a new buffer and write an empty array - buffer = io.BytesIO() - buffer.write(b"[]") - buffer_size = 2 - else: - # Write closing bracket to close the JSON array - buffer.write(b"]") - buffer_size += 1 - - # Upload the final part (Or only part if no records) - buffer.seek(0) - response = s3_client.upload_part( - Bucket=bucket_name, - Key=key_name, - PartNumber=part_number, - UploadId=upload_id, - Body=buffer, - ) - parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) - buffer.close() + first_batch = next(batch_generator) + except StopIteration: + # This exception WILL be thrown if no SQL records return + first_batch = None - if parts: - # Complete multipart upload + if first_batch is None: + # No records, make a new buffer and write an empty array for the dump + try: + buffer.write(b"[]") + buffer.seek(0) + response = s3_client.upload_part( + Bucket=bucket_name, + Key=key_name, + PartNumber=part_number, + UploadId=upload_id, + Body=buffer, + ) + parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) s3_client.complete_multipart_upload( Bucket=bucket_name, Key=key_name, @@ -130,25 +108,115 @@ def fetch_and_upload_cursor_results( print( f"Data Warehouse Lambda - INFO - DB Extract - Successfully wrote {bucket_name}/{key_name}" ) - else: - # Abort the multipart upload if no parts were uploaded + except Exception as e: + # Abort multipart upload in case of failure s3_client.abort_multipart_upload( Bucket=bucket_name, Key=key_name, UploadId=upload_id ) print( - f"Data Warehouse Lambda - INFO - No data to upload for table {key_name}" + f"Data Warehouse Lambda - ERROR - DB Extract - Error during multipart upload: {e}" ) + raise e + # Early return, mark process to no longer be alive + return + + # Create a manager -> worker process for each batch in parallel + # This makes a manager and worker combo for every batch that we will await later + workers: List[Tuple[Process, multi_processing_connection, int]] = [] + # Make sure we still process the first batch + # fb = first batch + fbManager, fbWorker = Pipe() + fbp = Process(target=parallel_worker, args=(fbWorker, first_batch, column_names, key_name)) + workers.append((fbp, fbManager, 1)) + fbp.start() + # Process remaining batches from the generator + for i, batch in enumerate(batch_generator): + manager, worker = Pipe() + p = Process(target=parallel_worker, args=(worker, batch, column_names, key_name)) + workers.append((p, manager, i+1)) # +1 because of first batch preceding this + p.start() + + # + # Start our JSON output document + # + # Begin with our bracket as we are + # going to format this manually + buffer.write(b"[") + buffer_size += 1 + + print(f'Data Warehouse Lambda - INFO - {key_name} has {len(workers)} workers') + + # Concurrently wait for the workers to be complete + while workers: + finished_workers: List[multi_processing_connection | socket | int] = wait([mgr for (proc, mgr, bIndex) in workers], timeout=None) + for finished_worker in finished_workers: + is_last_worker = len(workers) == 1 + for i, (proc, mgr, batch_index) in enumerate(workers): + if mgr is finished_worker: + if mgr.poll(): + msg_type, data = mgr.recv() + mgr.close() # Free thread + proc.join() # Make sure worker is done + workers.pop(i) # Clear the worker + + if msg_type == "error": + print(f"Data Warehouse Lambda - ERROR - Worker {batch_index} failed: {data}") + raise RuntimeError(f"Worker {batch_index} error: {data}") + print( + f"Data Warehouse Lambda - INFO - DB Extract - {key_name} worker for batch index {batch_index} polled successfully. {len(workers)} workers remain" + ) + # Add data fragment to buffer + buffer.write(data) + buffer_size += len(data) + + if not is_last_worker: + # If it's not the last worker, append `,` to support a final + # [${worker_1_json}, ${worker_2_json}, ${worker_3_json}] output + buffer.write(b",") + buffer_size += 1 + + # Upload part if buffer reaches the minimum part size or + # if this is the final worker + if buffer_size >= min_part_size or is_last_worker: + if is_last_worker: + # Last call, close it up! + buffer.write(b"]") + buffer_size += 1 + buffer.seek(0) + response = s3_client.upload_part( + Bucket=bucket_name, + Key=key_name, + PartNumber=part_number, + UploadId=upload_id, + Body=buffer, + ) + parts.append({"PartNumber": part_number, "ETag": response["ETag"]}) + part_number += 1 + buffer.close() + buffer = io.BytesIO() + buffer_size = 0 + + gc.collect() + + + + print( + f"Data Warehouse Lambda - INFO - DB Extract - {key_name} workers complete, announcing multipart completion" + ) - except Exception as e: - # Abort multipart upload in case of failure - s3_client.abort_multipart_upload( - Bucket=bucket_name, Key=key_name, UploadId=upload_id - ) - print( - f"Data Warehouse Lambda - ERROR - DB Extract - Error during multipart upload: {e}" - ) - raise e - + # Workers have completed and there is nothing left to upload to the multi-part + buffer.close() + del buffer + # Announce multipart upload completion + s3_client.complete_multipart_upload( + Bucket=bucket_name, + Key=key_name, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + print( + f"Data Warehouse Lambda - INFO - DB Extract - Successfully wrote {bucket_name}/{key_name}" + ) # Helper func to yield results rather than return # to boost processing efficiency @@ -164,24 +232,23 @@ def fetch_batches(cursor: pg8000.Cursor): def map_row_to_columns(row, column_names): return {column_names[i]: row[i] for i in range(len(column_names))} - def table_extractor( table_name, json_parameter_value, bucket_name, - child_conn: multi_processing_connection, + worker: multi_processing_connection, ): try: connection = db_conn.get_connection() if connection is None: error_msg = ( - f"Failed to connect to database during child process for {table_name}" + f"Failed to connect to database during worker process for {table_name}" ) print(f"Data Warehouse Lambda - ERROR - DB Extract - {error_msg}") - child_conn.send(error_msg) + worker.send(error_msg) return print( - f"Data Warehouse Lambda - INFO - Child extraction process created for table {table_name}" + f"Data Warehouse Lambda - INFO - Worker extraction process created for table {table_name}" ) s3_key = f"db_data/{str(json_parameter_value['data']['serialNumber'] + 1).zfill(6)}/{table_name}.json" cursor = connection.cursor() # type: ignore @@ -242,14 +309,14 @@ def table_extractor( + " does not match any criteria for data warehousing" ) - # Main child process complete - child_conn.send(f"Successfully processed {table_name}") + # Worker process complete + worker.send(f"Successfully processed {table_name}") except Exception as e: - error_msg = f"Error processing {table_name}: {e}" + error_msg = f"Error processing {table_name}: {repr(e)}" print(f"Data Warehouse Lambda - ERROR - DB Extract - {error_msg}") - child_conn.send(error_msg) + worker.send(error_msg) finally: - child_conn.close() + worker.close() def db_extractor(): @@ -323,46 +390,48 @@ def db_extractor(): # Create multi processes for the number of tables we have # for the number of processors we have - cursor.close() # Close the current cursor, we are done with it. Child processes make new ones + cursor.close() # Close the current cursor, we are done with it. Worker processes make new ones - # Process each table in a separate Process with Pipe - max_processes = multiprocessing.cpu_count() # Retrieve processor limit + # Process each table in a separate Process with Pipe table_index = iter(tables) - processes: List[Tuple[Process, multi_processing_connection]] = ( + workers: List[Tuple[Process, multi_processing_connection]] = ( [] - ) # Holds the individual processes in tuples with the parent pipes - + ) # Holds the individual processes in tuples with the parent/manager pipes + + print( + f"Data Warehouse Lambda - INFO - {max_processes} processors available for use" + ) + while True: - while len(processes) < max_processes: - try: - table_name = next(table_index) - except StopIteration: - # This exception is thrown when the next function can't find anything - break - parent_conn, child_conn = Pipe() - process = Process( - target=table_extractor, - args=(table_name, json_parameter_value, bucket_name, child_conn), - ) - processes.append((process, parent_conn)) - process.start() - if not processes: - # No remaining processes are active + try: + table_name = next(table_index) + except StopIteration: + # This exception is thrown when the next function can't find anything + break + manager, worker = Pipe() + process = Process( + target=table_extractor, + args=(table_name, json_parameter_value, bucket_name, worker), + ) + workers.append((process, manager)) + process.start() + + if not workers: + # No remaining workers have jobs, db export complete break - # Manage processes concurrently without sequential waiting - for i in range(len(processes) - 1, -1, -1): - process, pipe = processes[i] - if not process.is_alive(): - # Process has finished, poll it and receive its message - if pipe.poll(): - message = pipe.recv() - print(f"Data Warehouse Lambda - INFO - {message}") - pipe.close() - process.join() - processes.pop( - i - ) # Remove the finished process to free up processor capacity + for i in range(len(workers) - 1, -1, -1): + finished_workers: List[multi_processing_connection | socket | int] = wait([mgr for (proc, mgr) in workers], timeout=None) + for finished_worker in finished_workers: + # find the matching worker + for i, (proc, mgr) in enumerate(workers): + if mgr is finished_worker: + if mgr.poll(): + message = mgr.recv() + print(f"Data Warehouse Lambda - INFO - {message}") + mgr.close() + proc.join() + workers.pop(i) # Create an SSM client try: