Skip to content
19 changes: 9 additions & 10 deletions mlpstorage_py/checkpointing/storage_writers/minio_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,34 +323,33 @@ def _flush_part(self) -> None:
rate = written_gb / elapsed if elapsed > 0 else 0.0
print(f'\r[Writer] {written_gb:.2f} GB, {rate:.2f} GB/s ', end='', flush=True)

def write_chunk(self, buffer: bytes, size: int) -> int:
    """Write chunk, flushing parts as they fill up.

    The chunk is appended to the current part buffer; whenever the part
    buffer reaches ``self.part_size`` it is handed to ``self._flush_part()``
    and filling continues with the rest of the chunk.

    Args:
        buffer: Bytes containing data to write. Any bytes-like object
            (``bytes``, ``bytearray``, ``memoryview``) is accepted.
        size: Number of bytes to write from buffer. Values larger than
            the buffer are clamped to the bytes actually available;
            values <= 0 write nothing.

    Returns:
        Number of bytes written
    """
    # Slicing a bytes object copies; slicing a memoryview does not. Going
    # through a view keeps the per-part slices below zero-copy, so the only
    # copy left is the one part_buffer.write() makes itself.
    # NOTE(review): assumes part_buffer.write() copies its argument (as
    # io.BytesIO does) rather than retaining a reference -- confirm.
    with memoryview(buffer) as view:
        # Never account for more bytes than the buffer really holds,
        # otherwise part_buffer_size drifts from the buffered data.
        available = min(max(size, 0), len(view))
        offset = 0

        while offset < available:
            # Calculate how much we can add to current part
            remaining_in_part = self.part_size - self.part_buffer_size
            to_write = min(remaining_in_part, available - offset)

            # Add to part buffer
            self.part_buffer.write(view[offset:offset + to_write])
            self.part_buffer_size += to_write
            offset += to_write

            # Flush if part is full
            if self.part_buffer_size >= self.part_size:
                self._flush_part()

    self.total_bytes += available
    return available

Expand Down
11 changes: 5 additions & 6 deletions mlpstorage_py/checkpointing/storage_writers/s3torch_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,18 +197,17 @@ def __init__(
print(f"[S3TorchWriter] region={region}, endpoint={endpoint or 'AWS S3'}")
print(f"[S3TorchWriter] (multipart auto-managed by s3torchconnector)")

def write_chunk(self, buffer: memoryview, size: int) -> int:
def write_chunk(self, buffer: bytes, size: int) -> int:
"""Write chunk directly to S3 (streaming).

Args:
buffer: Memory buffer containing data to write
buffer: Bytes containing data to write
size: Number of bytes to write from buffer

Returns:
Number of bytes written
"""
data = bytes(buffer[:size])
self.writer.write(data) # Stream directly to S3
self.writer.write(buffer) # Stream directly to S3
self.total_bytes += size
elapsed = time.monotonic() - self._start_time
written_gb = self.total_bytes / 1e9
Expand Down
51 changes: 35 additions & 16 deletions mlpstorage_py/checkpointing/streaming_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def save(
print(f"Use dgen-py: {self.use_dgen}")
print("=" * 80)

start_time = time.time()
setup_start = time.time()

# Create buffer pool
buffers, buffer_names = self._create_buffer_pool()

Expand Down Expand Up @@ -205,7 +205,9 @@ def save(
)
writer_proc.start()
print(f"\n[Main] Writer process started (PID={writer_proc.pid})")

pipeline_start = time.time()
setup_time = pipeline_start - setup_start

try:
# Producer loop
print(f"[Main] Starting producer at {time.perf_counter():.3f}s")
Expand Down Expand Up @@ -249,7 +251,7 @@ def save(
if 'error' in stats:
raise RuntimeError(f"Writer process error: {stats['error']}")

return self._format_results(stats, gen_time, time.time() - start_time, total_size_bytes)
return self._format_results(stats, gen_time, time.time() - pipeline_start, total_size_bytes, setup_time)

def _create_buffer_pool(self):
"""Create shared memory buffer pool."""
Expand Down Expand Up @@ -392,6 +394,7 @@ def _writer_process(buffer_names, chunk_size, filepath, total_size,
total_io_time = 0.0
chunks_written = 0
_write_error = None # Error from write loop, if any
io_wall_start = time.perf_counter()

try:
while written < total_size:
Expand All @@ -402,9 +405,12 @@ def _writer_process(buffer_names, chunk_size, filepath, total_size,
buffer_idx, nbytes = item
shm = buffers[buffer_idx]

# Copy buffer outside timed window to avoid memcpy inflation
chunk_bytes = bytes(shm.buf[:nbytes])

# Time ONLY the I/O operation
io_start = time.perf_counter()
bytes_written = writer.write_chunk(shm.buf, nbytes)
bytes_written = writer.write_chunk(chunk_bytes, nbytes)
total_io_time += time.perf_counter() - io_start

written += bytes_written
Expand Down Expand Up @@ -434,6 +440,8 @@ def _writer_process(buffer_names, chunk_size, filepath, total_size,
if _write_error is None:
_write_error = f"close() failed: {e}"

io_wall_end = time.perf_counter()

# Force cleanup of storage-library resources.
try:
del writer
Expand All @@ -444,6 +452,7 @@ def _writer_process(buffer_names, chunk_size, filepath, total_size,
# Build result dict — single put to stats_queue.
result = {
'io_time': total_io_time,
'io_wall_time': io_wall_end - io_wall_start,
'close_time': close_time,
'total_bytes': written,
'chunks_written': chunks_written,
Expand Down Expand Up @@ -471,20 +480,26 @@ def _writer_process(buffer_names, chunk_size, filepath, total_size,
sys.stdout.flush()
os._exit(exit_code)

def _format_results(self, stats, gen_time, total_time, total_size_bytes):
def _format_results(self, stats, gen_time, total_time, total_size_bytes, setup_time=0.0):
"""Format results for return."""
gen_throughput = (total_size_bytes / (1024**3)) / gen_time
io_throughput = (stats['total_bytes'] / (1024**3)) / stats['io_time']

actual_bytes_gb = stats['total_bytes'] / (1024**3)
gen_throughput = actual_bytes_gb / gen_time
io_throughput = actual_bytes_gb / stats.get('io_wall_time', stats['io_time'])

if stats['total_bytes'] != total_size_bytes:
print(f"[Warning] Bytes written ({stats['total_bytes']}) != requested ({total_size_bytes}); throughput ratio uses actual bytes for both numerators.")

# Calculate improved metrics
throughput_ratio = gen_throughput / io_throughput
pipeline_overhead = ((total_time - max(gen_time, stats['io_time'])) / total_time) * 100
bottleneck = "I/O" if stats['io_time'] > gen_time else "Generation"

results = {
'gen_time': gen_time,
'io_time': stats['io_time'],
'io_accumulated_time': stats['io_time'],
'io_wall_time': stats.get('io_wall_time', stats['io_time']),
'close_time': stats.get('close_time', 0.0),
'setup_time': setup_time,
'total_time': total_time,
'total_bytes': stats['total_bytes'],
'chunks': stats['chunks_written'],
Expand All @@ -495,13 +510,15 @@ def _format_results(self, stats, gen_time, total_time, total_size_bytes):
'bottleneck': bottleneck,
'backend_stats': stats.get('backend_stats', {})
}

print("\n" + "=" * 80)
print("RESULTS")
print("=" * 80)
print(f"Setup: {results['setup_time']:.4f}s (buffer pool + fork overhead)")
print(f"Generation: {results['gen_time']:.4f}s @ {results['gen_throughput_gbps']:.2f} GB/s")
print(f"I/O: {results['io_time']:.4f}s @ {results['io_throughput_gbps']:.2f} GB/s")
print(f" - write: {results['io_time'] - results['close_time']:.4f}s")
print(f"I/O: {results['io_wall_time']:.4f}s (wall) @ {results['io_throughput_gbps']:.2f} GB/s")
print(f" - accumulated: {results['io_accumulated_time']:.4f}s (sum of per-chunk timers)")
print(f" - write: {results['io_accumulated_time'] - results['close_time']:.4f}s")
print(f" - close: {results['close_time']:.4f}s (fsync/finalize)")
print(f"Total: {results['total_time']:.4f}s")
print(f"")
Expand All @@ -510,7 +527,7 @@ def _format_results(self, stats, gen_time, total_time, total_size_bytes):
print(f"Bottleneck: {results['bottleneck']}")
print(f"Chunks: {results['chunks']}")
print("=" * 80)

return results

def load(
Expand Down Expand Up @@ -699,11 +716,13 @@ def _read_block(reader, block_start, block_end, worker_id):
io_time = max(t for _, t, _ in results)
chunks = sum(c for _, _, c in results)
finally:
all_backend_stats = []
for r in readers:
try:
backend_stats = r.close()
all_backend_stats.append(r.close())
except Exception:
pass
backend_stats = all_backend_stats

total_time = time.time() - wall_start
io_gbps = (total_read / 1024**3) / io_time if io_time > 0 else 0.0
Expand Down
7 changes: 4 additions & 3 deletions mlpstorage_py/cluster_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,9 +1537,10 @@ def collect(self) -> Dict[str, Any]:
)

if result.returncode != 0:
self.logger.warning(
f"MPI collection returned non-zero exit code: "
f"{result.returncode}\nstderr: {result.stderr}"
raise RuntimeError(
f"MPI cluster collection failed (exit code {result.returncode}). "
f"Partial data from {len(collected_data)} hosts was collected but cannot be trusted. "
f"stderr: {result.stderr}"
)

self.logger.info(
Expand Down
10 changes: 7 additions & 3 deletions mlpstorage_py/rules_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,9 +1729,13 @@ def calculate_training_data_size(args, cluster_information, dataset_params, read

# Required Minimum Dataset size is 5x the total client memory
dataset_size_bytes = 5 * total_mem_bytes
file_size_bytes = dataset_params['num_samples_per_file'] * dataset_params['record_length_bytes']

min_num_files_by_bytes = dataset_size_bytes // file_size_bytes
record_length_bytes = dataset_params.get('record_length_bytes')
if not record_length_bytes:
logger.warning('record_length_bytes missing from dataset params (parquet?); skipping byte-based dataset size check')
file_size_bytes = 0
else:
file_size_bytes = dataset_params['num_samples_per_file'] * record_length_bytes
min_num_files_by_bytes = (dataset_size_bytes // file_size_bytes) if file_size_bytes else 0
num_samples_by_bytes = min_num_files_by_bytes * dataset_params['num_samples_per_file']
min_samples = 500 * num_processes * reader_params['batch_size']
min_num_files_by_samples = min_samples // dataset_params['num_samples_per_file']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def closed_mpi_processes(self):
if num_processes != 8:
self.log.error(
"CLOSED submission with model %s in subset mode requires %d processes, got %d",
model_key,
model_name,
8,
num_processes
)
Expand Down
14 changes: 5 additions & 9 deletions mlpstorage_py/submission_checker/checks/training_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,17 @@ def recalculate_dataset_size(self):

# From summary
num_accelerators = summary.get("num_accelerators", 1)
num_hosts = summary.get("num_hosts", 1)
host_memory_gb = summary.get("host_memory_GB", [0])[0]

total_host_memory = sum(summary.get("host_memory_GB", [0]))

if record_length == 0:
self.log.error("Record length is 0, cannot calculate dataset size")
valid = False
continue

# Calculate min samples from steps per epoch
num_steps_per_epoch = max(MIN_STEPS_PER_EPOCH,
num_files_train * num_samples_per_file // (batch_size * num_accelerators))
min_samples_steps = num_steps_per_epoch * batch_size * num_accelerators

min_samples_steps = MIN_STEPS_PER_EPOCH * batch_size * num_accelerators

# Calculate min samples from host memory
total_host_memory = num_hosts * host_memory_gb
min_samples_memory = (total_host_memory * HOST_MEMORY_MULTIPLIER *
1024 * 1024 * 1024 / record_length)

Expand Down
Loading