Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ struct AicpuExecutor {
}

__attribute__((noinline, cold)) void log_stall_diagnostics(
int32_t thread_idx, int32_t task_count, int32_t idle_iterations, int32_t last_progress_count, void *sm_base
int32_t thread_idx, int32_t task_count, int32_t idle_iterations, int32_t last_progress_count
) {
int32_t c = completed_tasks_.load(std::memory_order_relaxed);
DEV_ALWAYS(
Expand All @@ -567,12 +567,12 @@ struct AicpuExecutor {
);
CoreTracker &tracker = core_trackers_[thread_idx];
PTO2SchedulerState *sched = &rt->scheduler;
PTO2SharedMemoryHeader *sm_header_diag = static_cast<PTO2SharedMemoryHeader *>(sm_base);
int32_t cnt_ready = 0, cnt_waiting = 0, cnt_inflight = 0;
for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
int32_t ring_task_count = sm_header_diag->rings[r].fc.current_task_index.load(std::memory_order_relaxed);
PTO2SharedMemoryRingHeader &ring = *sched->ring_sched_states[r].ring;
int32_t ring_task_count = ring.fc.current_task_index.load(std::memory_order_relaxed);
for (int32_t si = 0; si < ring_task_count; si++) {
PTO2TaskSlotState &slot_state = sched->get_slot_state(r, si);
PTO2TaskSlotState &slot_state = ring.get_slot_state_by_task_id(si);
PTO2TaskState st = slot_state.task_state.load(std::memory_order_relaxed);
int32_t rc = slot_state.fanin_refcount.load(std::memory_order_relaxed);
int32_t fi = slot_state.fanin_count;
Expand Down Expand Up @@ -1892,14 +1892,11 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa
CoreTracker &tracker = core_trackers_[thread_idx];
DEV_INFO("Thread %d: resolve_and_dispatch_pto2 entry", thread_idx);

void *sm_base = runtime->get_pto2_gm_sm_ptr();
if (!sm_base) {
DEV_ERROR("PTO2 dispatch: sm_base is null");
PTO2SharedMemoryHeader *header = rt->scheduler.sm_header;
if (!header) {
DEV_ERROR("PTO2 dispatch: header is null");
return -1;
}
DEV_INFO("Thread %d: sm_base=%p", thread_idx, sm_base);

PTO2SharedMemoryHeader *header = static_cast<PTO2SharedMemoryHeader *>(sm_base);
DEV_INFO(
"Thread %d: header=%p, task_desc_offset[0]=%lu, window_size=%lu", thread_idx, static_cast<void *>(header),
static_cast<uint64_t>(header->rings[0].task_descriptors_offset),
Expand Down Expand Up @@ -2140,7 +2137,7 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime *runtime, int32_t threa
}

if (thread_idx == 0 && task_count > 0 && idle_iterations % STALL_LOG_INTERVAL == 0) {
log_stall_diagnostics(thread_idx, task_count, idle_iterations, last_progress_count, sm_base);
log_stall_diagnostics(thread_idx, task_count, idle_iterations, last_progress_count);
}
if (idle_iterations > MAX_IDLE_ITERATIONS) {
return handle_timeout_exit(
Expand Down
119 changes: 49 additions & 70 deletions src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_orchestrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ static void *pto2_aligned_zalloc(size_t size, size_t alignment) {
static int32_t pto2_orch_mark_fatal(PTO2OrchestratorState *orch, int32_t error_code) {
always_assert(orch != nullptr);
orch->fatal = true;
if (error_code == PTO2_ERROR_NONE || orch->sm_handle == nullptr || orch->sm_handle->header == nullptr) {
if (error_code == PTO2_ERROR_NONE || orch->sm_header == nullptr) {
return PTO2_ERROR_NONE;
}

int32_t expected = PTO2_ERROR_NONE;
std::atomic<int32_t> &orch_error_code = orch->sm_handle->header->orch_error_code;
std::atomic<int32_t> &orch_error_code = orch->sm_header->orch_error_code;
if (orch_error_code.compare_exchange_strong(expected, error_code, std::memory_order_acq_rel)) {
return error_code;
}
Expand Down Expand Up @@ -170,10 +170,14 @@ void pto2_orch_report_fatal(PTO2OrchestratorState *orch, int32_t error_code, con
}

struct PTO2FaninBuilder {
PTO2TaskSlotState *inline_slots[PTO2_FANIN_INLINE_CAP];
PTO2FaninBuilder(PTO2FaninPool &spill_pool) :
count(0),
spill_start(0),
spill_pool(spill_pool) {}
int32_t count{0};
int32_t spill_start{0};
PTO2FaninPool *spill_pool{nullptr};
PTO2FaninPool &spill_pool;
PTO2TaskSlotState *inline_slots[PTO2_FANIN_INLINE_CAP];

template <typename Fn>
PTO2FaninForEachReturn<Fn> for_each(Fn &&fn) const {
Expand All @@ -197,9 +201,7 @@ struct PTO2FaninBuilder {
};

static bool pto2_append_fanin_or_fail(
PTO2OrchestratorState *orch, PTO2TaskId task_id, int32_t tensor_arg_index, TensorArgType ptype,
PTO2TaskSlotState *prod_state, PTO2FaninBuilder *fanin_builder, PTO2SchedulerState *sched, PTO2RingFlowControl &fc,
uint8_t ring_id, const char *reason
PTO2OrchestratorState *orch, PTO2TaskSlotState *prod_state, PTO2FaninBuilder *fanin_builder, uint8_t ring_id
) {
if (fanin_builder->contains(prod_state)) {
return true;
Expand All @@ -210,22 +212,8 @@ static bool pto2_append_fanin_or_fail(
return true;
}

if (sched == nullptr || fanin_builder->spill_pool == nullptr) {
LOG_ERROR("========================================");
LOG_ERROR("FATAL: Fanin Spill Builder Misconfigured!");
LOG_ERROR("========================================");
LOG_ERROR("Missing scheduler or fanin spill pool while appending dynamic fanin.");
LOG_ERROR(" task_id.raw: %" PRIu64, task_id.raw);
LOG_ERROR(" tensor_arg_index: %d", tensor_arg_index);
LOG_ERROR(" tensor_arg_type: %d", static_cast<int>(ptype));
LOG_ERROR(" reason: %s", reason);
LOG_ERROR("========================================");
pto2_orch_mark_fatal(orch, PTO2_ERROR_DEPENDENCY_OVERFLOW);
return false;
}

PTO2FaninPool &fanin_pool = *fanin_builder->spill_pool;
fanin_pool.ensure_space(*sched, fc, ring_id, 1);
PTO2FaninPool &fanin_pool = fanin_builder->spill_pool;
fanin_pool.ensure_space(orch->sm_header->rings[ring_id], 1);
int32_t spill_idx = fanin_pool.top;
PTO2FaninSpillEntry *entry = fanin_pool.alloc();
if (entry == nullptr) {
Expand Down Expand Up @@ -321,18 +309,16 @@ static bool pto2_prepare_task(
return false;
}

auto sched = orch->scheduler;
out->alloc_result = allocator.alloc(total_output_size);
if (out->alloc_result.failed()) {
pto2_orch_mark_fatal(orch, PTO2_ERROR_HEAP_RING_DEADLOCK);
return false;
}

auto &rs = sched->ring_sched_states[ring_id];
out->task_id = PTO2TaskId::make(ring_id, static_cast<uint32_t>(out->alloc_result.task_id));
out->slot_state = &rs.get_slot_state_by_slot(out->alloc_result.slot);
out->task = &orch->sm_handle->task_descriptors[ring_id][out->alloc_result.slot];
out->payload = &orch->sm_handle->task_payloads[ring_id][out->alloc_result.slot];
out->slot_state = &orch->sm_header->rings[ring_id].get_slot_state_by_slot(out->alloc_result.slot);
out->task = &orch->sm_header->rings[ring_id].task_descriptors[out->alloc_result.slot];
out->payload = &orch->sm_header->rings[ring_id].task_payloads[out->alloc_result.slot];

pto2_prefetch_payload(out->payload, args.tensor_count(), args.scalar_count());

Expand Down Expand Up @@ -360,25 +346,25 @@ static bool pto2_prepare_task(
// =============================================================================

bool pto2_orchestrator_init(
PTO2OrchestratorState *orch, PTO2SharedMemoryHandle *sm_handle, void *gm_heap, uint64_t heap_size,
PTO2OrchestratorState *orch, PTO2SharedMemoryHeader *sm_header, void *gm_heap, uint64_t heap_size,
int32_t dep_pool_capacity
) {
*orch = PTO2OrchestratorState{};

orch->sm_handle = sm_handle;
orch->sm_header = sm_header;
orch->gm_heap_base = gm_heap;
orch->gm_heap_size = heap_size * PTO2_MAX_RING_DEPTH;
orch->fatal = false;

// Initialize per-ring resources
for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
void *ring_heap_base = reinterpret_cast<char *>(gm_heap) + r * heap_size;
auto &fc = sm_handle->header->rings[r].fc;
auto &ring = sm_header->rings[r];

// Initialize unified task allocator
orch->rings[r].task_allocator.init(
sm_handle->task_descriptors[r], sm_handle->header->rings[r].task_window_size, &fc.current_task_index,
&fc.last_task_alive, ring_heap_base, heap_size, &sm_handle->header->orch_error_code
ring.task_descriptors, ring.task_window_size, &ring.fc.current_task_index, &ring.fc.last_task_alive,
ring_heap_base, heap_size, &sm_header->orch_error_code
);

size_t fanin_pool_bytes =
Expand All @@ -391,13 +377,13 @@ bool pto2_orchestrator_init(
}
return false;
}
orch->rings[r].fanin_pool.init(fanin_entries, dep_pool_capacity, &sm_handle->header->orch_error_code);
orch->rings[r].fanin_pool.init(fanin_entries, dep_pool_capacity, &sm_header->orch_error_code);
}

// Initialize TensorMap with per-ring task window sizes
int32_t task_window_sizes[PTO2_MAX_RING_DEPTH];
for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
task_window_sizes[r] = sm_handle->header->rings[r].task_window_size;
task_window_sizes[r] = sm_header->rings[r].task_window_size;
}
if (!orch->tensor_map.init_default(task_window_sizes)) {
for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
Expand Down Expand Up @@ -574,17 +560,14 @@ pto2_submit_mixed_task(PTO2OrchestratorState *orch, const MixedKernels &mixed_ke
}
uint8_t ring_id = prepared.task_id.ring();
PTO2SchedulerState *sched = orch->scheduler;
PTO2RingFlowControl &fc = orch->sm_handle->header->rings[ring_id].fc;
PTO2RingFlowControl &fc = orch->sm_header->rings[ring_id].fc;
PTO2TaskId task_id = prepared.task_id;
PTO2TaskSlotState &cur_slot_state = *prepared.slot_state;
PTO2TaskDescriptor &task = *prepared.task;
PTO2TaskPayload &payload = *prepared.payload;
result.set_task_id(task_id);

PTO2FaninBuilder fanin_builder;
fanin_builder.count = 0;
fanin_builder.spill_start = 0;
fanin_builder.spill_pool = &orch->rings[ring_id].fanin_pool;
PTO2FaninBuilder fanin_builder(orch->rings[ring_id].fanin_pool);

CYCLE_COUNT_LAP_RECORD(g_orch_alloc_cycle, AicpuPhaseId::ORCH_ALLOC, task_id.raw);

Expand Down Expand Up @@ -617,10 +600,8 @@ pto2_submit_mixed_task(PTO2OrchestratorState *orch, const MixedKernels &mixed_ke
PTO2TaskId owner = tensor->owner_task_id;
if (owner.is_valid()) {
PTO2TaskSlotState *prod_state =
&sched->ring_sched_states[owner.ring()].get_slot_state_by_task_id(owner.local());
if (!pto2_append_fanin_or_fail(
orch, task_id, i, ptype, prod_state, &fanin_builder, sched, fc, ring_id, "creator retention"
)) {
&orch->sm_header->rings[owner.ring()].get_slot_state_by_task_id(owner.local());
if (!pto2_append_fanin_or_fail(orch, prod_state, &fanin_builder, ring_id)) {
return result;
}
}
Expand All @@ -641,10 +622,8 @@ pto2_submit_mixed_task(PTO2OrchestratorState *orch, const MixedKernels &mixed_ke
auto overlap_status = lookup_result.entries[r].overlap_status;
auto prod_ring = entry.producer_task_id.ring();
auto prod_local = entry.producer_task_id.local();
PTO2TaskSlotState *prod_state = &sched->ring_sched_states[prod_ring].get_slot_state_by_task_id(prod_local);
if (!pto2_append_fanin_or_fail(
orch, task_id, i, ptype, prod_state, &fanin_builder, sched, fc, ring_id, "overlap lookup"
)) {
PTO2TaskSlotState *prod_state = &orch->sm_header->rings[prod_ring].get_slot_state_by_task_id(prod_local);
if (!pto2_append_fanin_or_fail(orch, prod_state, &fanin_builder, ring_id)) {
return result;
}
if (ptype == TensorArgType::INOUT && overlap_status == OverlapStatus::COVERED) {
Expand All @@ -669,7 +648,7 @@ pto2_submit_mixed_task(PTO2OrchestratorState *orch, const MixedKernels &mixed_ke

CYCLE_COUNT_LAP_RECORD(g_orch_insert_cycle, AicpuPhaseId::ORCH_INSERT, task_id.raw);

// === STEP 5: Batch-write to GM (single cache line burst) ===
// === STEP 5: Batch-write to GM (single cache line burst) + Record fanin metadata ===
// Deferred from allocation phase to avoid scattered GM writes that get
// evicted by TensorMap lookup/insert cache pressure.
__builtin_prefetch(&task, 1, 1);
Expand All @@ -680,35 +659,35 @@ pto2_submit_mixed_task(PTO2OrchestratorState *orch, const MixedKernels &mixed_ke
task.packed_buffer_base = prepared.alloc_result.packed_base;
task.packed_buffer_end = prepared.alloc_result.packed_end;

// Increment fanout_count on each producer (no lock — only orch writes this field).
// Prevents premature CONSUMED: scope_end's release_producer checks fanout_refcount == fanout_count.
pto2_for_each_fanin_storage(
fanin_builder.inline_slots, fanin_builder.count, fanin_builder.spill_start, fanin_builder.spill_pool,
[](PTO2TaskSlotState *producer) {
producer->fanout_count++;
}
);

int32_t inline_count = std::min(fanin_builder.count, PTO2_FANIN_INLINE_CAP);
// Store fanin metadata in payload for scheduler to iterate
payload.fanin_actual_count = fanin_builder.count;
payload.fanin_spill_start = fanin_builder.spill_start;
payload.fanin_spill_pool = &fanin_builder.spill_pool;
Comment thread
poursoul marked this conversation as resolved.
for (int i = 0; i < inline_count; i++) {
payload.fanin_inline_slot_states[i] = fanin_builder.inline_slots[i];
}

payload.init(args, result, prepared.alloc_result, layout);

CYCLE_COUNT_LAP_RECORD(g_orch_args_cycle, AicpuPhaseId::ORCH_PARAMS, task_id.raw);
#if PTO2_ORCH_PROFILING
g_orch_args_atomic_count += 2; // fanout_lock.store + fanout_count.store
#endif

// === STEP 6: Record fanin metadata + push to wiring queue ===
// === STEP 6: push to wiring queue ===
// Deferred wiring: orchestrator only stores dependency metadata and increments
// fanout_count. The actual fanout_head wiring (lock + dep_pool + early_finished)
// is handled asynchronously by scheduler thread 0 via the wiring queue.
int32_t fanin_count = fanin_builder.count;
int32_t inline_count = std::min(fanin_count, PTO2_FANIN_INLINE_CAP);
int32_t spill_count = fanin_count - inline_count;

// Store fanin metadata in payload for scheduler to iterate
payload.fanin_actual_count = fanin_count;
payload.fanin_spill_start = (spill_count > 0) ? fanin_builder.spill_start : 0;
payload.fanin_spill_pool = (spill_count > 0) ? fanin_builder.spill_pool : nullptr;
for (int i = 0; i < inline_count; i++) {
payload.fanin_inline_slot_states[i] = fanin_builder.inline_slots[i];
}

// Increment fanout_count on each producer (no lock — only orch writes this field).
// Prevents premature CONSUMED: scope_end's release_producer checks fanout_refcount == fanout_count.
pto2_for_each_fanin_slot_state(payload, [](PTO2TaskSlotState *producer) {
producer->fanout_count += 1;
});

// Push to global wiring queue — scheduler sets fanin_count, wires fanout, checks readiness
while (!sched->wiring.queue.push(&cur_slot_state)) {
SPIN_WAIT_HINT();
Expand Down Expand Up @@ -794,7 +773,7 @@ TaskOutputTensors pto2_alloc_tensors(PTO2OrchestratorState *orch, const Arg &arg
payload.init(args, outputs, prepared.alloc_result, layout);
payload.fanin_actual_count = 0;
payload.fanin_spill_start = 0;
payload.fanin_spill_pool = nullptr;
payload.fanin_spill_pool = &orch->rings[prepared.task_id.ring()].fanin_pool;

CYCLE_COUNT_LAP_RECORD(g_orch_args_cycle, AicpuPhaseId::ORCH_PARAMS, prepared.task_id.raw);

Expand Down Expand Up @@ -842,7 +821,7 @@ void pto2_orchestrator_done(PTO2OrchestratorState *orch) {
);
}
}
orch->sm_handle->header->orchestrator_done.store(1, std::memory_order_release);
orch->sm_header->orchestrator_done.store(1, std::memory_order_release);
#if !PTO2_ORCH_PROFILING && PTO2_PROFILING
g_orch_submit_idx = 0;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
*/
struct PTO2OrchestratorState {
// === SHARED MEMORY ACCESS ===
PTO2SharedMemoryHandle *sm_handle;
PTO2SharedMemoryHeader *sm_header;

// === PER-RING RESOURCES ===
PTO2RingSet rings[PTO2_MAX_RING_DEPTH];
Expand Down Expand Up @@ -120,13 +120,13 @@ struct PTO2OrchestratorState {
* Initialize orchestrator state
*
* @param orch Orchestrator state to initialize
* @param sm_handle Shared memory handle
* @param sm_header Shared memory header
* @param gm_heap GM heap memory for output buffers
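* @param heap_size Size in bytes of each ring's slice of the GM heap; gm_heap must span heap_size * PTO2_MAX_RING_DEPTH bytes
* @param dep_pool_capacity Capacity passed to each ring's fanin pool (defaults to PTO2_DEP_LIST_POOL_SIZE)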
* @return true on success
*/
bool pto2_orchestrator_init(
PTO2OrchestratorState *orch, PTO2SharedMemoryHandle *sm_handle, void *gm_heap, uint64_t heap_size,
PTO2OrchestratorState *orch, PTO2SharedMemoryHeader *sm_header, void *gm_heap, uint64_t heap_size,
int32_t dep_pool_capacity = PTO2_DEP_LIST_POOL_SIZE
);

Expand Down
Loading
Loading