From 3a35eaf7c6b6ecf2a5f349029c3f7db76d8d1231 Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 20:40:39 -0400 Subject: [PATCH 1/8] coord-T9: integration test suite (Gap #3) + fl_round_number seq fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add internal/testutil package: MustDB (skips when OWM_TEST_DSN unset, runs migrations, returns pgxpool), TruncateAll, Logger, NodeFixture/Sign - Fill in 7 skipped integration tests across 4 packages: registry: TestRegisterIntegration (register → activate → re-register) scheduler: TestScheduleIntegration, TestRequeueTimedOutIntegration fl: TestOpenRoundIntegration, TestSubmitGradientIntegration, TestTryAggregateIntegration (nil storage → placeholder fedAvg) stake: TestSlash_MockForceClose (mock LN, real DB, verifies slashing_events + node status + stake_status) - Add migration 000009: fl_rounds.round_number now has an auto-increment default via a dedicated sequence (INSERT in OpenRound omitted the column; without a DEFAULT the insert would violate NOT NULL) Co-Authored-By: Claude Opus 4.7 --- owm-coordinator/internal/fl/rounds_test.go | 137 +++++++++++++++++- .../internal/registry/registry_test.go | 59 +++++++- .../internal/scheduler/scheduler_test.go | 88 ++++++++++- .../internal/stake/verifier_test.go | 67 ++++++++- owm-coordinator/internal/testutil/db.go | 111 ++++++++++++++ .../000009_fl_round_number_seq.down.sql | 2 + .../000009_fl_round_number_seq.up.sql | 6 + 7 files changed, 458 insertions(+), 12 deletions(-) create mode 100644 owm-coordinator/internal/testutil/db.go create mode 100644 owm-coordinator/migrations/000009_fl_round_number_seq.down.sql create mode 100644 owm-coordinator/migrations/000009_fl_round_number_seq.up.sql diff --git a/owm-coordinator/internal/fl/rounds_test.go b/owm-coordinator/internal/fl/rounds_test.go index b40b706..cf96770 100644 --- a/owm-coordinator/internal/fl/rounds_test.go +++ b/owm-coordinator/internal/fl/rounds_test.go @@ -1,9 
+1,15 @@ package fl_test import ( + "context" + "crypto/sha256" + "encoding/hex" "testing" + "github.com/google/uuid" + "github.com/owmnetwork/owm-coordinator/internal/fl" + "github.com/owmnetwork/owm-coordinator/internal/testutil" ) // TestRoundStatusConstants verifies that FL round status values are stable strings. @@ -44,14 +50,139 @@ func TestGradientSubmissionZeroValue(t *testing.T) { } // Integration tests require a live PostgreSQL instance. +// Run with: OWM_TEST_DSN=postgres://... go test ./internal/fl/... + func TestOpenRoundIntegration(t *testing.T) { - t.Skip("integration test — set OWM_TEST_DSN and remove t.Skip to run") + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + o := fl.New(pool, nil, nil, testutil.Logger()) + + roundNumber, err := o.OpenRound(ctx, "owm-model-v1") + if err != nil { + t.Fatalf("OpenRound: %v", err) + } + if roundNumber <= 0 { + t.Errorf("round_number: got %d, want > 0", roundNumber) + } + + status, err := o.GetRoundStatus(ctx, roundNumber) + if err != nil { + t.Fatalf("GetRoundStatus: %v", err) + } + if status != fl.RoundStatusOpen { + t.Errorf("status: got %q, want %q", status, fl.RoundStatusOpen) + } } func TestSubmitGradientIntegration(t *testing.T) { - t.Skip("integration test — set OWM_TEST_DSN and remove t.Skip to run") + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + o := fl.New(pool, nil, nil, testutil.Logger()) + + roundNumber, err := o.OpenRound(ctx, "owm-model-v1") + if err != nil { + t.Fatalf("OpenRound: %v", err) + } + + // fl_participants.node_id references nodes, so insert a real node first. 
+ nodeID := uuid.New() + _, err = pool.Exec(ctx, + `INSERT INTO nodes (node_id, public_key, ln_node_uri, tier, status) + VALUES ($1, $2, $3, 't1', 'active')`, + nodeID, "pk-submit-"+nodeID.String(), nodeID.String()+"@127.0.0.1:9735", + ) + if err != nil { + t.Fatalf("insert node: %v", err) + } + + gradientData := make([]byte, 64) + h := sha256.Sum256(gradientData) + err = o.SubmitGradient(ctx, fl.GradientSubmission{ + NodeID: nodeID, + RoundID: roundNumber, + GradientData: gradientData, + GradientHash: hex.EncodeToString(h[:]), + S3URL: "gradients/1/test.bin", + }) + if err != nil { + t.Fatalf("SubmitGradient: %v", err) + } + + var count int + _ = pool.QueryRow(ctx, + `SELECT COUNT(*) FROM fl_participants fp + JOIN fl_rounds r ON r.round_id = fp.round_id + WHERE r.round_number = $1 AND fp.node_id = $2`, + roundNumber, nodeID, + ).Scan(&count) + if count != 1 { + t.Errorf("fl_participants: got %d row(s), want 1", count) + } } func TestTryAggregateIntegration(t *testing.T) { - t.Skip("integration test — set OWM_TEST_DSN and remove t.Skip to run") + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + // Use a fast-failing OTS calendar so the async stampOTS goroutine exits quickly. + o := fl.NewWithConfig(pool, nil, nil, &fl.OrchestratorConfig{ + OTSCalendars: []string{"http://127.0.0.1:1/digest"}, + }, testutil.Logger()) + + roundNumber, err := o.OpenRound(ctx, "owm-model-v1") + if err != nil { + t.Fatalf("OpenRound: %v", err) + } + + // Submit gradients from 3 nodes (minParticipants = 3). 
+ for i := 0; i < 3; i++ { + nodeID := uuid.New() + _, err = pool.Exec(ctx, + `INSERT INTO nodes (node_id, public_key, ln_node_uri, tier, status) + VALUES ($1, $2, $3, 't1', 'active')`, + nodeID, + "pk-agg-"+nodeID.String(), + nodeID.String()+"@127.0.0.1:9735", + ) + if err != nil { + t.Fatalf("insert node %d: %v", i, err) + } + data := make([]byte, 64) + data[0] = byte(i) + h := sha256.Sum256(data) + if err := o.SubmitGradient(ctx, fl.GradientSubmission{ + NodeID: nodeID, + RoundID: roundNumber, + GradientHash: hex.EncodeToString(h[:]), + }); err != nil { + t.Fatalf("SubmitGradient %d: %v", i, err) + } + } + + summary, err := o.TryAggregate(ctx, roundNumber) + if err != nil { + t.Fatalf("TryAggregate: %v", err) + } + if summary == nil { + t.Fatal("TryAggregate: got nil summary, expected aggregation to proceed") + } + if summary.ParticipantCount != 3 { + t.Errorf("ParticipantCount: got %d, want 3", summary.ParticipantCount) + } + if summary.AggregatedHash == "" { + t.Error("AggregatedHash: expected non-empty") + } + + status, err := o.GetRoundStatus(ctx, roundNumber) + if err != nil { + t.Fatalf("GetRoundStatus: %v", err) + } + if status != fl.RoundStatusComplete { + t.Errorf("round status: got %q, want %q", status, fl.RoundStatusComplete) + } } diff --git a/owm-coordinator/internal/registry/registry_test.go b/owm-coordinator/internal/registry/registry_test.go index 777300c..088ec7d 100644 --- a/owm-coordinator/internal/registry/registry_test.go +++ b/owm-coordinator/internal/registry/registry_test.go @@ -1,10 +1,12 @@ package registry_test import ( + "context" "testing" "time" "github.com/owmnetwork/owm-coordinator/internal/registry" + "github.com/owmnetwork/owm-coordinator/internal/testutil" ) // TestCanonicalRegisterMessageFormat verifies that signature inputs cannot be @@ -61,9 +63,60 @@ func TestNodeCapabilitiesZeroValue(t *testing.T) { } // Integration tests require a live PostgreSQL instance. -// Run with: OWM_TEST_DSN=postgres://... 
go test ./internal/registry/... -tags integration +// Run with: OWM_TEST_DSN=postgres://... go test ./internal/registry/... func TestRegisterIntegration(t *testing.T) { - t.Skip("integration test — set OWM_TEST_DSN and remove t.Skip to run") - _ = time.Now() // placeholder — real test would create pgxpool, call Register, etc. + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + reg := registry.New(pool, testutil.Logger()) + nf := testutil.NewNodeFixture(t) + + ts := time.Now().Unix() + caps := registry.NodeCapabilities{ + Tier: registry.TierT1, + VRAMGB: 8, + RAMGB: 16, + BandwidthMbps: 100, + SupportedTaskTypes: []string{"inference"}, + } + sig := nf.Sign(nf.LNNodeURI, registry.TierT1, ts) + + node, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig, ts) + if err != nil { + t.Fatalf("Register: %v", err) + } + if node.Status != registry.StatusPending { + t.Errorf("status: got %q, want %q", node.Status, registry.StatusPending) + } + if node.Tier != registry.TierT1 { + t.Errorf("tier: got %q, want %q", node.Tier, registry.TierT1) + } + + // Activate transitions pending → active. + if err := reg.Activate(ctx, node.NodeID); err != nil { + t.Fatalf("Activate: %v", err) + } + got, err := reg.GetByPublicKey(ctx, nf.PubKeyHex) + if err != nil { + t.Fatalf("GetByPublicKey: %v", err) + } + if got.Status != registry.StatusActive { + t.Errorf("post-activate status: got %q, want %q", got.Status, registry.StatusActive) + } + + // Re-registration resets to pending (idempotent upsert). 
+ ts2 := time.Now().Unix() + sig2 := nf.Sign(nf.LNNodeURI, registry.TierT1, ts2) + node2, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig2, ts2) + if err != nil { + t.Fatalf("re-Register: %v", err) + } + if node2.NodeID != node.NodeID { + t.Errorf("re-register: node_id changed: got %s, want %s", node2.NodeID, node.NodeID) + } + if node2.Status != registry.StatusPending { + t.Errorf("re-register status: got %q, want %q", node2.Status, registry.StatusPending) + } } diff --git a/owm-coordinator/internal/scheduler/scheduler_test.go b/owm-coordinator/internal/scheduler/scheduler_test.go index 93d7d45..1f68f7e 100644 --- a/owm-coordinator/internal/scheduler/scheduler_test.go +++ b/owm-coordinator/internal/scheduler/scheduler_test.go @@ -1,9 +1,15 @@ package scheduler_test import ( + "context" "testing" + "time" + "github.com/google/uuid" + + "github.com/owmnetwork/owm-coordinator/internal/registry" "github.com/owmnetwork/owm-coordinator/internal/scheduler" + "github.com/owmnetwork/owm-coordinator/internal/testutil" ) // TestTaskTypeConstants ensures task type values match the protobuf task_type strings. @@ -48,10 +54,88 @@ func TestAssignmentFields(t *testing.T) { } // Integration tests require a live PostgreSQL instance. +// Run with: OWM_TEST_DSN=postgres://... go test ./internal/scheduler/... 
+ func TestScheduleIntegration(t *testing.T) { - t.Skip("integration test — set OWM_TEST_DSN and remove t.Skip to run") + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + reg := registry.New(pool, testutil.Logger()) + nf := testutil.NewNodeFixture(t) + + ts := time.Now().Unix() + caps := registry.NodeCapabilities{Tier: registry.TierT1, VRAMGB: 8, RAMGB: 16, BandwidthMbps: 100} + sig := nf.Sign(nf.LNNodeURI, registry.TierT1, ts) + node, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig, ts) + if err != nil { + t.Fatalf("Register: %v", err) + } + if err := reg.Activate(ctx, node.NodeID); err != nil { + t.Fatalf("Activate: %v", err) + } + + sched := scheduler.New(pool, reg, nil, testutil.Logger()) + a, err := sched.Schedule(ctx, scheduler.TaskInference, "inputhash-abc", 60) + if err != nil { + t.Fatalf("Schedule: %v", err) + } + if a.NodeID != node.NodeID { + t.Errorf("assigned node: got %s, want %s", a.NodeID, node.NodeID) + } + if a.TaskType != scheduler.TaskInference { + t.Errorf("task type: got %s, want inference", a.TaskType) + } + if a.RewardSats <= 0 { + t.Errorf("reward_sats: expected > 0, got %d", a.RewardSats) + } } func TestRequeueTimedOutIntegration(t *testing.T) { - t.Skip("integration test — set OWM_TEST_DSN and remove t.Skip to run") + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + reg := registry.New(pool, testutil.Logger()) + nf := testutil.NewNodeFixture(t) + + ts := time.Now().Unix() + caps := registry.NodeCapabilities{Tier: registry.TierT1, VRAMGB: 8, RAMGB: 16, BandwidthMbps: 100} + sig := nf.Sign(nf.LNNodeURI, registry.TierT1, ts) + node, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig, ts) + if err != nil { + t.Fatalf("Register: %v", err) + } + if err := reg.Activate(ctx, node.NodeID); err != nil { + t.Fatalf("Activate: %v", err) + } + + // Insert a task that started 2 minutes ago with a 60-second timeout (already timed 
out). + taskID := uuid.New() + _, err = pool.Exec(ctx, + `INSERT INTO tasks + (task_id, task_type, assigned_node, node_ln_uri, status, input_hash, + reward_sats, submitted_at, started_at, timeout_seconds) + VALUES ($1, 'inference', $2, $3, 'running', 'hash', 10, + now() - interval '2 minutes', now() - interval '2 minutes', 60)`, + taskID, node.NodeID, node.LNNodeURI, + ) + if err != nil { + t.Fatalf("insert timed-out task: %v", err) + } + + sched := scheduler.New(pool, reg, nil, testutil.Logger()) + n, err := sched.RequeueTimedOut(ctx) + if err != nil { + t.Fatalf("RequeueTimedOut: %v", err) + } + if n != 1 { + t.Errorf("requeued count: got %d, want 1", n) + } + + var status string + _ = pool.QueryRow(ctx, `SELECT status FROM tasks WHERE task_id = $1`, taskID).Scan(&status) + if status != "pending" { + t.Errorf("task status after requeue: got %q, want pending", status) + } } diff --git a/owm-coordinator/internal/stake/verifier_test.go b/owm-coordinator/internal/stake/verifier_test.go index 4d1724a..31e5df2 100644 --- a/owm-coordinator/internal/stake/verifier_test.go +++ b/owm-coordinator/internal/stake/verifier_test.go @@ -3,6 +3,7 @@ package stake import ( "context" "testing" + "time" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgxpool" @@ -10,6 +11,7 @@ import ( "github.com/owmnetwork/owm-coordinator/internal/lightning" "github.com/owmnetwork/owm-coordinator/internal/lightning/mock" + "github.com/owmnetwork/owm-coordinator/internal/testutil" ) func TestVerifyStake_SufficientStake(t *testing.T) { @@ -92,9 +94,66 @@ func TestSlash_NilClient(t *testing.T) { } } -// TestSlash_MockForceClose would require DB (node_stakes, slashing_events, etc.). -// Plan: "DB interaction must be mocked with pgxmock or skipped via build tag". -// We skip it here; integration tests can cover full Slash with mock LN. +// TestSlash_MockForceClose is an integration test: it requires a live PostgreSQL +// instance but uses a mock Lightning client so no real LN node is needed. 
+// Run with: OWM_TEST_DSN=postgres://... go test ./internal/stake/... func TestSlash_MockForceClose(t *testing.T) { - t.Skip("Slash requires DB; use integration test or pgxmock") + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + + // Insert a node and its stake record directly; no need for the registry. + nodeID := uuid.New() + _, err := pool.Exec(ctx, + `INSERT INTO nodes (node_id, public_key, ln_node_uri, tier, status) + VALUES ($1, $2, $3, 't1', 'active')`, + nodeID, "pk-slash-"+nodeID.String(), nodeID.String()+"@127.0.0.1:9735", + ) + if err != nil { + t.Fatalf("insert node: %v", err) + } + _, err = pool.Exec(ctx, + `INSERT INTO node_stakes + (node_id, channel_id, channel_capacity, local_balance, tier_minimum, stake_status) + VALUES ($1, 'test-channel-slash', 200000, 200000, 100000, 'active')`, + nodeID, + ) + if err != nil { + t.Fatalf("insert node_stakes: %v", err) + } + + lnMock := mock.New() + v := New(pool, lnMock, lnMock, SlashConfig{CooldownDuration: time.Hour}, zap.NewNop()) + + if err := v.Slash(ctx, nodeID, "t1", "test-misbehavior", "evidence-hash-abc", 3); err != nil { + t.Fatalf("Slash: %v", err) + } + + // Slashing event must be recorded. + var eventCount int + _ = pool.QueryRow(ctx, + `SELECT COUNT(*) FROM slashing_events WHERE node_id = $1`, nodeID, + ).Scan(&eventCount) + if eventCount != 1 { + t.Errorf("slashing_events: got %d, want 1", eventCount) + } + + // Node must be suspended. + var status string + _ = pool.QueryRow(ctx, + `SELECT status FROM nodes WHERE node_id = $1`, nodeID, + ).Scan(&status) + if status != "suspended" { + t.Errorf("node status after slash: got %q, want suspended", status) + } + + // Stake must be marked force_closed. 
+ var stakeStatus string + _ = pool.QueryRow(ctx, + `SELECT stake_status FROM node_stakes WHERE node_id = $1`, nodeID, + ).Scan(&stakeStatus) + if stakeStatus != "force_closed" { + t.Errorf("stake_status: got %q, want force_closed", stakeStatus) + } } diff --git a/owm-coordinator/internal/testutil/db.go b/owm-coordinator/internal/testutil/db.go new file mode 100644 index 0000000..ac78a85 --- /dev/null +++ b/owm-coordinator/internal/testutil/db.go @@ -0,0 +1,111 @@ +// Package testutil provides helpers for integration tests that require a live +// PostgreSQL database. Set OWM_TEST_DSN to a Postgres DSN to run them; tests +// are skipped automatically when the variable is absent. +package testutil + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/golang-migrate/migrate/v4" + _ "github.com/golang-migrate/migrate/v4/database/postgres" + _ "github.com/golang-migrate/migrate/v4/source/file" + "github.com/jackc/pgx/v5/pgxpool" + "go.uber.org/zap" +) + +// MustDB returns a live pgxpool.Pool for integration tests. +// Skips the test if OWM_TEST_DSN is unset; runs all pending migrations. +// The pool is closed automatically at test cleanup. +func MustDB(t *testing.T) *pgxpool.Pool { + t.Helper() + dsn := os.Getenv("OWM_TEST_DSN") + if dsn == "" { + t.Skip("integration test: set OWM_TEST_DSN to run") + } + + // Navigate from this source file to the migrations directory. 
+ _, thisFile, _, _ := runtime.Caller(0) + migrationsDir := filepath.Join(filepath.Dir(thisFile), "..", "..", "migrations") + migrationsDir, _ = filepath.Abs(migrationsDir) + + m, err := migrate.New("file://"+migrationsDir, dsn) + if err != nil { + t.Fatalf("testutil.MustDB: migrate.New: %v", err) + } + defer m.Close() + if err := m.Up(); err != nil && err != migrate.ErrNoChange { + t.Fatalf("testutil.MustDB: migrate.Up: %v", err) + } + + pool, err := pgxpool.New(context.Background(), dsn) + if err != nil { + t.Fatalf("testutil.MustDB: pgxpool.New: %v", err) + } + t.Cleanup(pool.Close) + return pool +} + +// TruncateAll removes all rows from every OWM application table and resets +// sequences. Call at the start of each integration test for isolation. +func TruncateAll(t *testing.T, pool *pgxpool.Pool) { + t.Helper() + _, err := pool.Exec(context.Background(), ` + TRUNCATE TABLE + misbehavior_signals, + slashing_events, + node_stakes, + fl_participants, + model_versions, + bounties, + tasks, + fl_rounds, + nodes + RESTART IDENTITY CASCADE`) + if err != nil { + t.Fatalf("TruncateAll: %v", err) + } +} + +// Logger returns a no-op zap logger suitable for unit and integration tests. +func Logger() *zap.Logger { + return zap.NewNop() +} + +// NodeFixture holds an Ed25519 key pair and a synthetic LN node URI for tests. +type NodeFixture struct { + PubKeyHex string + PubKey ed25519.PublicKey + PrivKey ed25519.PrivateKey + LNNodeURI string +} + +// NewNodeFixture generates a fresh Ed25519 key pair. +func NewNodeFixture(t *testing.T) *NodeFixture { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("NewNodeFixture: %v", err) + } + pubHex := hex.EncodeToString(pub) + return &NodeFixture{ + PubKeyHex: pubHex, + PubKey: pub, + PrivKey: priv, + LNNodeURI: pubHex[:20] + "@127.0.0.1:9735", + } +} + +// Sign returns an Ed25519 signature over the canonical OWM registration message +// used by registry.Register. 
+func (f *NodeFixture) Sign(lnURI, tier string, ts int64) []byte { + msg := []byte(fmt.Sprintf("owm-register|%s|%s|%s|%d", f.PubKeyHex, lnURI, tier, ts)) + return ed25519.Sign(f.PrivKey, msg) +} diff --git a/owm-coordinator/migrations/000009_fl_round_number_seq.down.sql b/owm-coordinator/migrations/000009_fl_round_number_seq.down.sql new file mode 100644 index 0000000..26a2982 --- /dev/null +++ b/owm-coordinator/migrations/000009_fl_round_number_seq.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE fl_rounds ALTER COLUMN round_number DROP DEFAULT; +DROP SEQUENCE IF EXISTS fl_round_number_seq; diff --git a/owm-coordinator/migrations/000009_fl_round_number_seq.up.sql b/owm-coordinator/migrations/000009_fl_round_number_seq.up.sql new file mode 100644 index 0000000..c2f2d30 --- /dev/null +++ b/owm-coordinator/migrations/000009_fl_round_number_seq.up.sql @@ -0,0 +1,6 @@ +-- fl_rounds.round_number had no DEFAULT, causing OpenRound INSERT to fail. +-- Create a dedicated sequence so callers can omit round_number on insert. 
+CREATE SEQUENCE IF NOT EXISTS fl_round_number_seq; +ALTER TABLE fl_rounds + ALTER COLUMN round_number SET DEFAULT nextval('fl_round_number_seq'); +ALTER SEQUENCE fl_round_number_seq OWNED BY fl_rounds.round_number; From 09a9167965fa1949ab275bf690caff7729e5a97a Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 20:49:52 -0400 Subject: [PATCH 2/8] coord-T10: task eligibility validation (Gap #4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Migration 000010: add supported_task_types TEXT[] DEFAULT '{}' to nodes - registry.Node: add SupportedTaskTypes []string field; persist it through Register upsert and return it from GetByPublicKey / ListActive - scheduler.filterEligible: enforce task type match in addition to tier — a node with an empty SupportedTaskTypes list accepts all task types (backward-compatible); non-empty list must contain the requested task type - New supportsTaskType() helper (pure, no imports) - eligibility_test.go (package scheduler): 6 unit tests covering tier-only, empty list, mismatch, combined tier+type, nil input, and supportsTaskType table - TestScheduleIntegration_TaskTypeMismatch: integration test confirms Schedule errors when no node supports the requested type and succeeds when one does Co-Authored-By: Claude Sonnet 4.6 --- owm-coordinator/internal/registry/registry.go | 88 +++++++++-------- .../internal/scheduler/eligibility_test.go | 97 +++++++++++++++++++ .../internal/scheduler/scheduler.go | 28 +++++- .../internal/scheduler/scheduler_test.go | 47 +++++++++ .../000010_add_supported_task_types.down.sql | 1 + .../000010_add_supported_task_types.up.sql | 4 + 6 files changed, 221 insertions(+), 44 deletions(-) create mode 100644 owm-coordinator/internal/scheduler/eligibility_test.go create mode 100644 owm-coordinator/migrations/000010_add_supported_task_types.down.sql create mode 100644 owm-coordinator/migrations/000010_add_supported_task_types.up.sql diff --git 
a/owm-coordinator/internal/registry/registry.go b/owm-coordinator/internal/registry/registry.go index f69a271..f407b54 100644 --- a/owm-coordinator/internal/registry/registry.go +++ b/owm-coordinator/internal/registry/registry.go @@ -35,20 +35,21 @@ const ( // Node represents a registered OWM network participant. type Node struct { - NodeID uuid.UUID - PublicKey string // Ed25519 hex - LNNodeURI string // pubkey@host:port (clearnet or .onion) - OnionAddress string // optional Tor v3 .onion hostname for control-plane access - Tier string - VRAMGB float64 - RAMGB float64 - BandwidthMbps float64 - Reliability float64 // 0.0–1.0 rolling 7-day - TotalTasks int64 - TotalSats int64 - Status string - RegisteredAt time.Time - LastHeartbeat *time.Time + NodeID uuid.UUID + PublicKey string // Ed25519 hex + LNNodeURI string // pubkey@host:port (clearnet or .onion) + OnionAddress string // optional Tor v3 .onion hostname for control-plane access + Tier string + VRAMGB float64 + RAMGB float64 + BandwidthMbps float64 + SupportedTaskTypes []string // task types this node accepts; empty = all types + Reliability float64 // 0.0–1.0 rolling 7-day + TotalTasks int64 + TotalSats int64 + Status string + RegisteredAt time.Time + LastHeartbeat *time.Time } // NodeCapabilities describes the hardware offered by a registering node. @@ -92,18 +93,23 @@ func (r *Registry) Register(ctx context.Context, pubKeyHex, lnURI, onionAddr str } // Upsert node — if pubkey already exists, update capabilities and reset to pending. 
+ supportedTypes := caps.SupportedTaskTypes + if supportedTypes == nil { + supportedTypes = []string{} + } node := &Node{ - NodeID: uuid.New(), - PublicKey: pubKeyHex, - LNNodeURI: lnURI, - OnionAddress: onionAddr, - Tier: caps.Tier, - VRAMGB: caps.VRAMGB, - RAMGB: caps.RAMGB, - BandwidthMbps: caps.BandwidthMbps, - Reliability: 1.0, - Status: StatusPending, - RegisteredAt: time.Now().UTC(), + NodeID: uuid.New(), + PublicKey: pubKeyHex, + LNNodeURI: lnURI, + OnionAddress: onionAddr, + Tier: caps.Tier, + VRAMGB: caps.VRAMGB, + RAMGB: caps.RAMGB, + BandwidthMbps: caps.BandwidthMbps, + SupportedTaskTypes: supportedTypes, + Reliability: 1.0, + Status: StatusPending, + RegisteredAt: time.Now().UTC(), } // Capture existing tier/status before the upsert so re-registration can @@ -113,17 +119,19 @@ func (r *Registry) Register(ctx context.Context, pubKeyHex, lnURI, onionAddr str SELECT tier, status FROM nodes WHERE public_key = $2 ) INSERT INTO nodes (node_id, public_key, ln_node_uri, onion_address, tier, - vram_gb, ram_gb, bandwidth_mbps, reliability, status, registered_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + vram_gb, ram_gb, bandwidth_mbps, reliability, status, + registered_at, supported_task_types) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (public_key) DO UPDATE - SET ln_node_uri = EXCLUDED.ln_node_uri, - onion_address = EXCLUDED.onion_address, - tier = EXCLUDED.tier, - vram_gb = EXCLUDED.vram_gb, - ram_gb = EXCLUDED.ram_gb, - bandwidth_mbps = EXCLUDED.bandwidth_mbps, - status = 'pending', - registered_at = EXCLUDED.registered_at + SET ln_node_uri = EXCLUDED.ln_node_uri, + onion_address = EXCLUDED.onion_address, + tier = EXCLUDED.tier, + vram_gb = EXCLUDED.vram_gb, + ram_gb = EXCLUDED.ram_gb, + bandwidth_mbps = EXCLUDED.bandwidth_mbps, + supported_task_types = EXCLUDED.supported_task_types, + status = 'pending', + registered_at = EXCLUDED.registered_at RETURNING node_id, status, (SELECT tier FROM prior) AS prior_tier, 
(SELECT status FROM prior) AS prior_status` @@ -131,7 +139,7 @@ func (r *Registry) Register(ctx context.Context, pubKeyHex, lnURI, onionAddr str row := r.db.QueryRow(ctx, q, node.NodeID, node.PublicKey, node.LNNodeURI, node.OnionAddress, node.Tier, node.VRAMGB, node.RAMGB, node.BandwidthMbps, node.Reliability, - node.Status, node.RegisteredAt, + node.Status, node.RegisteredAt, node.SupportedTaskTypes, ) var priorTier, priorStatus *string @@ -188,7 +196,7 @@ func (r *Registry) RecordHeartbeat(ctx context.Context, nodeID uuid.UUID) (pendi func (r *Registry) GetByPublicKey(ctx context.Context, pubKeyHex string) (*Node, error) { const q = ` SELECT node_id, public_key, ln_node_uri, onion_address, tier, vram_gb, ram_gb, - bandwidth_mbps, reliability, total_tasks, total_sats, + bandwidth_mbps, supported_task_types, reliability, total_tasks, total_sats, status, registered_at, last_heartbeat FROM nodes WHERE public_key = $1` @@ -196,7 +204,7 @@ func (r *Registry) GetByPublicKey(ctx context.Context, pubKeyHex string) (*Node, row := r.db.QueryRow(ctx, q, pubKeyHex) err := row.Scan( &n.NodeID, &n.PublicKey, &n.LNNodeURI, &n.OnionAddress, &n.Tier, - &n.VRAMGB, &n.RAMGB, &n.BandwidthMbps, &n.Reliability, + &n.VRAMGB, &n.RAMGB, &n.BandwidthMbps, &n.SupportedTaskTypes, &n.Reliability, &n.TotalTasks, &n.TotalSats, &n.Status, &n.RegisteredAt, &n.LastHeartbeat, ) if err != nil { @@ -209,7 +217,7 @@ func (r *Registry) GetByPublicKey(ctx context.Context, pubKeyHex string) (*Node, func (r *Registry) ListActive(ctx context.Context) ([]*Node, error) { const q = ` SELECT node_id, public_key, ln_node_uri, onion_address, tier, vram_gb, ram_gb, - bandwidth_mbps, reliability, total_tasks, total_sats, + bandwidth_mbps, supported_task_types, reliability, total_tasks, total_sats, status, registered_at, last_heartbeat FROM nodes WHERE status = 'active' ORDER BY reliability DESC` @@ -224,7 +232,7 @@ func (r *Registry) ListActive(ctx context.Context) ([]*Node, error) { var n Node if err := 
rows.Scan( &n.NodeID, &n.PublicKey, &n.LNNodeURI, &n.OnionAddress, &n.Tier, - &n.VRAMGB, &n.RAMGB, &n.BandwidthMbps, &n.Reliability, + &n.VRAMGB, &n.RAMGB, &n.BandwidthMbps, &n.SupportedTaskTypes, &n.Reliability, &n.TotalTasks, &n.TotalSats, &n.Status, &n.RegisteredAt, &n.LastHeartbeat, ); err != nil { return nil, err diff --git a/owm-coordinator/internal/scheduler/eligibility_test.go b/owm-coordinator/internal/scheduler/eligibility_test.go new file mode 100644 index 0000000..343564c --- /dev/null +++ b/owm-coordinator/internal/scheduler/eligibility_test.go @@ -0,0 +1,97 @@ +package scheduler + +import ( + "context" + "testing" + + "github.com/owmnetwork/owm-coordinator/internal/registry" +) + +// filterEligible is unexported; these tests live in package scheduler. + +var noopScheduler = &Scheduler{} + +func TestFilterEligible_TierOnly(t *testing.T) { + nodes := []*registry.Node{ + {Tier: "t1"}, + {Tier: "t2"}, + {Tier: "t3"}, + } + // TaskGradientAgg needs tier 2. + got := noopScheduler.filterEligible(context.Background(), nodes, TaskGradientAgg, 2) + if len(got) != 2 { + t.Errorf("tier filter: got %d, want 2 (t2+t3)", len(got)) + } +} + +func TestFilterEligible_EmptySupportedTypes_AcceptsAll(t *testing.T) { + nodes := []*registry.Node{ + {Tier: "t1", SupportedTaskTypes: nil}, + {Tier: "t1", SupportedTaskTypes: []string{}}, + } + // Empty/nil list means "accept all task types". 
+ got := noopScheduler.filterEligible(context.Background(), nodes, TaskInference, 1) + if len(got) != 2 { + t.Errorf("nil/empty supported types: got %d, want 2", len(got)) + } +} + +func TestFilterEligible_TaskTypeMismatch(t *testing.T) { + nodes := []*registry.Node{ + {Tier: "t1", SupportedTaskTypes: []string{"fl_round", "gradient_agg"}}, + {Tier: "t1", SupportedTaskTypes: []string{"inference"}}, + } + got := noopScheduler.filterEligible(context.Background(), nodes, TaskInference, 1) + if len(got) != 1 { + t.Errorf("task type mismatch: got %d, want 1 (inference-only node)", len(got)) + } + if got[0].SupportedTaskTypes[0] != "inference" { + t.Errorf("wrong node selected: got %v", got[0].SupportedTaskTypes) + } +} + +func TestFilterEligible_TierAndTaskTypeCombined(t *testing.T) { + nodes := []*registry.Node{ + // Passes tier but not task type. + {Tier: "t2", SupportedTaskTypes: []string{"inference"}}, + // Passes both. + {Tier: "t2", SupportedTaskTypes: []string{"gradient_agg"}}, + // Fails tier. + {Tier: "t1", SupportedTaskTypes: []string{"gradient_agg"}}, + // Passes tier, empty list (all types). + {Tier: "t3", SupportedTaskTypes: nil}, + } + // TaskGradientAgg needs tier 2 — nodes[0] fails task type, nodes[2] fails tier. 
+ got := noopScheduler.filterEligible(context.Background(), nodes, TaskGradientAgg, 2) + if len(got) != 2 { + t.Errorf("combined filter: got %d, want 2", len(got)) + } +} + +func TestFilterEligible_NoNodes(t *testing.T) { + got := noopScheduler.filterEligible(context.Background(), nil, TaskInference, 1) + if len(got) != 0 { + t.Errorf("nil input: got %d, want 0", len(got)) + } +} + +func TestSupportsTaskType_EmptyAcceptsAll(t *testing.T) { + cases := []struct { + supported []string + task TaskType + want bool + }{ + {nil, TaskInference, true}, + {[]string{}, TaskFLRound, true}, + {[]string{"inference"}, TaskInference, true}, + {[]string{"inference"}, TaskFLRound, false}, + {[]string{"fl_round", "inference"}, TaskInference, true}, + {[]string{"gradient_agg"}, TaskAuditRepo, false}, + } + for _, c := range cases { + got := supportsTaskType(c.supported, c.task) + if got != c.want { + t.Errorf("supportsTaskType(%v, %q) = %v, want %v", c.supported, c.task, got, c.want) + } + } +} diff --git a/owm-coordinator/internal/scheduler/scheduler.go b/owm-coordinator/internal/scheduler/scheduler.go index f916fd4..3264127 100644 --- a/owm-coordinator/internal/scheduler/scheduler.go +++ b/owm-coordinator/internal/scheduler/scheduler.go @@ -177,18 +177,38 @@ func (s *Scheduler) RequeueTimedOut(ctx context.Context) (int, error) { return n, nil } -// filterEligible returns nodes that meet the minimum tier for the task. -func (s *Scheduler) filterEligible(_ context.Context, nodes []*registry.Node, _ TaskType, minTierNum int) []*registry.Node { +// filterEligible returns nodes that meet both the minimum tier and task type +// support requirements. A node with an empty SupportedTaskTypes slice accepts +// any task type (backward-compatible default). 
+func (s *Scheduler) filterEligible(_ context.Context, nodes []*registry.Node, taskType TaskType, minTierNum int) []*registry.Node { tierNums := map[string]int{"t1": 1, "t2": 2, "t3": 3} var out []*registry.Node for _, n := range nodes { - if tierNums[n.Tier] >= minTierNum { - out = append(out, n) + if tierNums[n.Tier] < minTierNum { + continue } + if !supportsTaskType(n.SupportedTaskTypes, taskType) { + continue + } + out = append(out, n) } return out } +// supportsTaskType returns true when a node's declared list allows the given +// task type, or when the list is empty (meaning the node accepts all types). +func supportsTaskType(supported []string, taskType TaskType) bool { + if len(supported) == 0 { + return true + } + for _, tt := range supported { + if tt == string(taskType) { + return true + } + } + return false +} + // selectNode picks the best node using a weighted score: // score = reliability × tier_multiplier × stake_bonus // Adds jitter to avoid thundering-herd when multiple nodes score identically. diff --git a/owm-coordinator/internal/scheduler/scheduler_test.go b/owm-coordinator/internal/scheduler/scheduler_test.go index 1f68f7e..dbb6951 100644 --- a/owm-coordinator/internal/scheduler/scheduler_test.go +++ b/owm-coordinator/internal/scheduler/scheduler_test.go @@ -139,3 +139,50 @@ func TestRequeueTimedOutIntegration(t *testing.T) { t.Errorf("task status after requeue: got %q, want pending", status) } } + +// TestScheduleIntegration_TaskTypeMismatch verifies that a node declaring +// specific supported task types is NOT assigned tasks it doesn't support, +// and IS assigned tasks it does support. +func TestScheduleIntegration_TaskTypeMismatch(t *testing.T) { + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + reg := registry.New(pool, testutil.Logger()) + + // Register a node that only supports "fl_round". 
+ nf := testutil.NewNodeFixture(t) + ts := time.Now().Unix() + caps := registry.NodeCapabilities{ + Tier: registry.TierT1, + VRAMGB: 8, + RAMGB: 16, + BandwidthMbps: 100, + SupportedTaskTypes: []string{string(scheduler.TaskFLRound)}, + } + sig := nf.Sign(nf.LNNodeURI, registry.TierT1, ts) + node, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig, ts) + if err != nil { + t.Fatalf("Register: %v", err) + } + if err := reg.Activate(ctx, node.NodeID); err != nil { + t.Fatalf("Activate: %v", err) + } + + sched := scheduler.New(pool, reg, nil, testutil.Logger()) + + // Scheduling an unsupported task type must fail — no eligible nodes. + _, err = sched.Schedule(ctx, scheduler.TaskInference, "hash-inference", 60) + if err == nil { + t.Fatal("expected error: no node supports inference, but Schedule succeeded") + } + + // Scheduling the supported task type must succeed. + a, err := sched.Schedule(ctx, scheduler.TaskFLRound, "hash-fl", 120) + if err != nil { + t.Fatalf("Schedule fl_round: %v", err) + } + if a.NodeID != node.NodeID { + t.Errorf("assigned node: got %s, want %s", a.NodeID, node.NodeID) + } +} diff --git a/owm-coordinator/migrations/000010_add_supported_task_types.down.sql b/owm-coordinator/migrations/000010_add_supported_task_types.down.sql new file mode 100644 index 0000000..a3a88c3 --- /dev/null +++ b/owm-coordinator/migrations/000010_add_supported_task_types.down.sql @@ -0,0 +1 @@ +ALTER TABLE nodes DROP COLUMN IF EXISTS supported_task_types; diff --git a/owm-coordinator/migrations/000010_add_supported_task_types.up.sql b/owm-coordinator/migrations/000010_add_supported_task_types.up.sql new file mode 100644 index 0000000..1bc05c1 --- /dev/null +++ b/owm-coordinator/migrations/000010_add_supported_task_types.up.sql @@ -0,0 +1,4 @@ +-- Persist the task types each node supports so the scheduler can enforce +-- eligibility at assignment time (SRS-SCHED-01). 
+ALTER TABLE nodes + ADD COLUMN IF NOT EXISTS supported_task_types TEXT[] NOT NULL DEFAULT '{}'; From 9028b4da2b73e52df5120efd6cc33dd1031d0bf9 Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 21:25:34 -0400 Subject: [PATCH 3/8] coord-T11: config validation hardening (Gap #5) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend validate() with four new rule groups: Tier minimums (BRS-POS-02): - Each of t1/t2/t3 must meet its BRS floor (100k/500k/2M sats) - Ordering t1 ≤ t2 ≤ t3 is enforced; equal boundaries are allowed - Nil map is caught cleanly (all values read as 0) Stake operational config: - verify_interval_hours, degraded_grace_period_hours, slash_cooldown_days ≥ 1 - t1_auto_slash_signals ≥ 3 and t2t3_maintainer_acks ≥ 2 (SRS-STAKE-04) FL config bounds: - min_participants ≥ 2, gradient_l2_clip_norm > 0, anomaly_std_dev_threshold > 0, round_interval_minutes ≥ 1, top_k_sparsification_pct ∈ [0, 1] S3 group check: - If any of endpoint/bucket/access_key/secret_key is non-empty, all four must be set (partial config is rejected) Add validate_test.go: 31 targeted tests (29 new), each mutating exactly one field of a passing base config to verify the specific error message substring. 
Co-Authored-By: Claude Sonnet 4.6 --- owm-coordinator/internal/config/config.go | 70 +++++ .../internal/config/validate_test.go | 241 ++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 owm-coordinator/internal/config/validate_test.go diff --git a/owm-coordinator/internal/config/config.go b/owm-coordinator/internal/config/config.go index 22d3af2..530c1d0 100644 --- a/owm-coordinator/internal/config/config.go +++ b/owm-coordinator/internal/config/config.go @@ -248,5 +248,75 @@ func (c *Config) validate() error { if c.Observer.Enabled && strings.TrimSpace(c.Observer.APIEndpoint) == "" { return fmt.Errorf("observer.api_endpoint is required when observer.enabled is true") } + + // ── Stake tier minimums (BRS-POS-02) ───────────────────────────────────── + // Each tier's configured minimum must be at or above the BRS floor. + tierFloors := []struct { + key string + floor int64 + }{ + {"t1", 100_000}, + {"t2", 500_000}, + {"t3", 2_000_000}, + } + for _, tf := range tierFloors { + val := c.Stake.TierMinimumSats[tf.key] + if val < tf.floor { + return fmt.Errorf("stake.tier_minimum_sats.%s must be ≥ %d sats (BRS-POS-02), got %d", + tf.key, tf.floor, val) + } + } + t1, t2, t3 := c.Stake.TierMinimumSats["t1"], c.Stake.TierMinimumSats["t2"], c.Stake.TierMinimumSats["t3"] + if t1 > t2 || t2 > t3 { + return fmt.Errorf("stake tier minimums must be ordered t1 ≤ t2 ≤ t3 (got t1=%d t2=%d t3=%d)", t1, t2, t3) + } + + // ── Stake operational config ────────────────────────────────────────────── + if c.Stake.VerifyIntervalHours < 1 { + return fmt.Errorf("stake.verify_interval_hours must be ≥ 1, got %d", c.Stake.VerifyIntervalHours) + } + if c.Stake.DegradedGracePeriodHours < 1 { + return fmt.Errorf("stake.degraded_grace_period_hours must be ≥ 1, got %d", c.Stake.DegradedGracePeriodHours) + } + if c.Stake.SlashCooldownDays < 1 { + return fmt.Errorf("stake.slash_cooldown_days must be ≥ 1, got %d", c.Stake.SlashCooldownDays) + } + // SRS-STAKE-04: T1 requires ≥ 3 
signals; T2/T3 requires ≥ 2 maintainer acks. + if c.Stake.T1AutoSlashSignals < 3 { + return fmt.Errorf("stake.t1_auto_slash_signals must be ≥ 3 (SRS-STAKE-04), got %d", c.Stake.T1AutoSlashSignals) + } + if c.Stake.T2T3MaintainerAcks < 2 { + return fmt.Errorf("stake.t2t3_maintainer_acks must be ≥ 2 (SRS-STAKE-04), got %d", c.Stake.T2T3MaintainerAcks) + } + + // ── FL config bounds ────────────────────────────────────────────────────── + if c.FL.MinParticipants < 2 { + return fmt.Errorf("fl.min_participants must be ≥ 2, got %d", c.FL.MinParticipants) + } + if c.FL.GradientL2ClipNorm <= 0 { + return fmt.Errorf("fl.gradient_l2_clip_norm must be > 0, got %g", c.FL.GradientL2ClipNorm) + } + if c.FL.AnomalyStdDevThreshold <= 0 { + return fmt.Errorf("fl.anomaly_std_dev_threshold must be > 0, got %g", c.FL.AnomalyStdDevThreshold) + } + if c.FL.RoundIntervalMinutes < 1 { + return fmt.Errorf("fl.round_interval_minutes must be ≥ 1, got %d", c.FL.RoundIntervalMinutes) + } + if c.FL.TopKSparsificationPct < 0 || c.FL.TopKSparsificationPct > 1 { + return fmt.Errorf("fl.top_k_sparsification_pct must be in [0, 1], got %g", c.FL.TopKSparsificationPct) + } + + // ── S3 group check ──────────────────────────────────────────────────────── + // S3 credentials are optional, but if any field is provided all four must be. 
+ s3Count := 0 + for _, v := range []string{c.S3.Endpoint, c.S3.Bucket, c.S3.AccessKey, c.S3.SecretKey} { + if v != "" { + s3Count++ + } + } + if s3Count > 0 && s3Count < 4 { + return fmt.Errorf("s3 is partially configured: endpoint, bucket, access_key, and secret_key must all be set together (or all left empty)") + } + return nil } diff --git a/owm-coordinator/internal/config/validate_test.go b/owm-coordinator/internal/config/validate_test.go new file mode 100644 index 0000000..4125eed --- /dev/null +++ b/owm-coordinator/internal/config/validate_test.go @@ -0,0 +1,241 @@ +package config + +import ( + "strings" + "testing" +) + +// validBase returns a minimal Config that passes every validate() check. +// Individual tests mutate one field at a time to trigger a specific error. +func validBase() *Config { + return &Config{ + DevMode: true, // skip Lightning credential checks + Database: DatabaseConfig{DSN: "postgres://localhost/test"}, + Lightning: LightningConfig{Backend: "lnd"}, + FL: FLConfig{ + MinParticipants: 2, + GradientL2ClipNorm: 1.0, + AnomalyStdDevThreshold: 3.0, + RoundIntervalMinutes: 1, + TopKSparsificationPct: 0.10, + }, + Stake: StakeConfig{ + TierMinimumSats: map[string]int64{ + "t1": 100_000, "t2": 500_000, "t3": 2_000_000, + }, + VerifyIntervalHours: 1, + DegradedGracePeriodHours: 1, + SlashCooldownDays: 1, + T1AutoSlashSignals: 3, + T2T3MaintainerAcks: 2, + }, + } +} + +func mustFail(t *testing.T, cfg *Config, wantSubstr string) { + t.Helper() + err := cfg.validate() + if err == nil { + t.Fatalf("expected validation error containing %q, got nil", wantSubstr) + } + if !strings.Contains(err.Error(), wantSubstr) { + t.Fatalf("error %q does not contain %q", err.Error(), wantSubstr) + } +} + +func mustPass(t *testing.T, cfg *Config) { + t.Helper() + if err := cfg.validate(); err != nil { + t.Fatalf("unexpected validation error: %v", err) + } +} + +// ── Tier minimums ───────────────────────────────────────────────────────────── + +func 
TestValidate_TierFloor_T1(t *testing.T) { + cfg := validBase() + cfg.Stake.TierMinimumSats["t1"] = 99_999 + mustFail(t, cfg, "tier_minimum_sats.t1") +} + +func TestValidate_TierFloor_T2(t *testing.T) { + cfg := validBase() + cfg.Stake.TierMinimumSats["t2"] = 499_999 + mustFail(t, cfg, "tier_minimum_sats.t2") +} + +func TestValidate_TierFloor_T3(t *testing.T) { + cfg := validBase() + cfg.Stake.TierMinimumSats["t3"] = 1_999_999 + mustFail(t, cfg, "tier_minimum_sats.t3") +} + +func TestValidate_TierFloor_NilMap(t *testing.T) { + cfg := validBase() + cfg.Stake.TierMinimumSats = nil // all tiers read as 0 → below every floor + mustFail(t, cfg, "tier_minimum_sats.t1") +} + +func TestValidate_TierOrdering_T1GtT2(t *testing.T) { + cfg := validBase() + cfg.Stake.TierMinimumSats["t1"] = 600_000 // above t2 + mustFail(t, cfg, "t1 ≤ t2 ≤ t3") +} + +func TestValidate_TierOrdering_T2GtT3(t *testing.T) { + cfg := validBase() + cfg.Stake.TierMinimumSats["t2"] = 3_000_000 // above t3 + mustFail(t, cfg, "t1 ≤ t2 ≤ t3") +} + +func TestValidate_TierOrdering_EqualBoundariesOK(t *testing.T) { + // t1 == t2 == t3 is allowed by the ordering check (≤ not <). 
+ cfg := validBase() + cfg.Stake.TierMinimumSats["t1"] = 2_000_000 + cfg.Stake.TierMinimumSats["t2"] = 2_000_000 + cfg.Stake.TierMinimumSats["t3"] = 2_000_000 + mustPass(t, cfg) +} + +// ── Stake operational config ────────────────────────────────────────────────── + +func TestValidate_Stake_VerifyIntervalZero(t *testing.T) { + cfg := validBase() + cfg.Stake.VerifyIntervalHours = 0 + mustFail(t, cfg, "verify_interval_hours") +} + +func TestValidate_Stake_DegradedGracePeriodZero(t *testing.T) { + cfg := validBase() + cfg.Stake.DegradedGracePeriodHours = 0 + mustFail(t, cfg, "degraded_grace_period_hours") +} + +func TestValidate_Stake_SlashCooldownZero(t *testing.T) { + cfg := validBase() + cfg.Stake.SlashCooldownDays = 0 + mustFail(t, cfg, "slash_cooldown_days") +} + +func TestValidate_Stake_T1AutoSlashSignalsTooLow(t *testing.T) { + cfg := validBase() + cfg.Stake.T1AutoSlashSignals = 2 + mustFail(t, cfg, "t1_auto_slash_signals") +} + +func TestValidate_Stake_T2T3MaintainerAcksTooLow(t *testing.T) { + cfg := validBase() + cfg.Stake.T2T3MaintainerAcks = 1 + mustFail(t, cfg, "t2t3_maintainer_acks") +} + +func TestValidate_Stake_HigherThresholdsOK(t *testing.T) { + cfg := validBase() + cfg.Stake.T1AutoSlashSignals = 10 // more conservative than SRS minimum + cfg.Stake.T2T3MaintainerAcks = 5 + mustPass(t, cfg) +} + +// ── FL config bounds ────────────────────────────────────────────────────────── + +func TestValidate_FL_MinParticipantsTooLow(t *testing.T) { + cfg := validBase() + cfg.FL.MinParticipants = 1 + mustFail(t, cfg, "min_participants") +} + +func TestValidate_FL_MinParticipantsExactly2(t *testing.T) { + cfg := validBase() + cfg.FL.MinParticipants = 2 + mustPass(t, cfg) +} + +func TestValidate_FL_GradientClipNormZero(t *testing.T) { + cfg := validBase() + cfg.FL.GradientL2ClipNorm = 0 + mustFail(t, cfg, "gradient_l2_clip_norm") +} + +func TestValidate_FL_GradientClipNormNegative(t *testing.T) { + cfg := validBase() + cfg.FL.GradientL2ClipNorm = -1 + mustFail(t, 
cfg, "gradient_l2_clip_norm") +} + +func TestValidate_FL_AnomalyThresholdZero(t *testing.T) { + cfg := validBase() + cfg.FL.AnomalyStdDevThreshold = 0 + mustFail(t, cfg, "anomaly_std_dev_threshold") +} + +func TestValidate_FL_RoundIntervalZero(t *testing.T) { + cfg := validBase() + cfg.FL.RoundIntervalMinutes = 0 + mustFail(t, cfg, "round_interval_minutes") +} + +func TestValidate_FL_TopKPctNegative(t *testing.T) { + cfg := validBase() + cfg.FL.TopKSparsificationPct = -0.01 + mustFail(t, cfg, "top_k_sparsification_pct") +} + +func TestValidate_FL_TopKPctAbove1(t *testing.T) { + cfg := validBase() + cfg.FL.TopKSparsificationPct = 1.01 + mustFail(t, cfg, "top_k_sparsification_pct") +} + +func TestValidate_FL_TopKPctZeroOK(t *testing.T) { + cfg := validBase() + cfg.FL.TopKSparsificationPct = 0 // 0 = sparsification disabled + mustPass(t, cfg) +} + +func TestValidate_FL_TopKPctOneOK(t *testing.T) { + cfg := validBase() + cfg.FL.TopKSparsificationPct = 1.0 // keep all + mustPass(t, cfg) +} + +// ── S3 group check ──────────────────────────────────────────────────────────── + +func TestValidate_S3_AllEmptyOK(t *testing.T) { + cfg := validBase() // S3 fields are all zero-value → OK + mustPass(t, cfg) +} + +func TestValidate_S3_AllSetOK(t *testing.T) { + cfg := validBase() + cfg.S3 = S3Config{ + Endpoint: "https://s3.example.com", + Bucket: "owm-data", + AccessKey: "AKID", + SecretKey: "secret", + } + mustPass(t, cfg) +} + +func TestValidate_S3_PartialMissingBucket(t *testing.T) { + cfg := validBase() + cfg.S3 = S3Config{Endpoint: "https://s3.example.com", AccessKey: "AKID", SecretKey: "secret"} + mustFail(t, cfg, "s3 is partially configured") +} + +func TestValidate_S3_PartialOnlyEndpoint(t *testing.T) { + cfg := validBase() + cfg.S3 = S3Config{Endpoint: "https://s3.example.com"} + mustFail(t, cfg, "s3 is partially configured") +} + +func TestValidate_S3_PartialOnlyBucket(t *testing.T) { + cfg := validBase() + cfg.S3 = S3Config{Bucket: "owm-data"} + mustFail(t, cfg, "s3 
is partially configured") +} + +// ── Full valid config ───────────────────────────────────────────────────────── + +func TestValidate_BaseConfigPasses(t *testing.T) { + mustPass(t, validBase()) +} From 16e6ebf42ed717a0f8bbb1c6117cfa39268e3910 Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 21:46:18 -0400 Subject: [PATCH 4/8] coord-T12: reliability scoring with uptime fraction (Gap #6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements SRS-SCHED-04: reliability = task_success_fraction × uptime_fraction over a rolling 7-day window (clamped to node registration time). Migration 000011: heartbeat_log table (node_id, recorded_at) with a DESC index on (node_id, recorded_at) for efficient 7-day window queries. Includes a pruning note for rows older than 8 days. RecordHeartbeat: now appends to heartbeat_log on every heartbeat, preserving the existing last_heartbeat update in the nodes table. UpdateReliability: replaced the placeholder with three focused scalar queries plus the Go-side computeReliability helper: - task_success_fraction = completed / max(total, 1) [1.0 if no tasks yet] - uptime_fraction = received_hb / expected_hb [1.0 if window < 60 s] - expected_hb = windowDuration / 60 s (SRS-NODE-04 rate) - reliability = clamp(task × uptime, 0, 1) computeReliability is a pure function — no DB, fully deterministic. reliability_test.go (package registry): 10 unit tests covering normal, edge (new node, no heartbeats, extra heartbeats, sub-interval window, zero tasks) and boundary cases; plus 2 integration tests for the full DB path. 
Co-Authored-By: Claude Sonnet 4.6 --- owm-coordinator/internal/registry/registry.go | 108 +++++++-- .../internal/registry/reliability_test.go | 224 ++++++++++++++++++ .../migrations/000011_heartbeat_log.down.sql | 2 + .../migrations/000011_heartbeat_log.up.sql | 13 + 4 files changed, 330 insertions(+), 17 deletions(-) create mode 100644 owm-coordinator/internal/registry/reliability_test.go create mode 100644 owm-coordinator/migrations/000011_heartbeat_log.down.sql create mode 100644 owm-coordinator/migrations/000011_heartbeat_log.up.sql diff --git a/owm-coordinator/internal/registry/registry.go b/owm-coordinator/internal/registry/registry.go index f407b54..46bc935 100644 --- a/owm-coordinator/internal/registry/registry.go +++ b/owm-coordinator/internal/registry/registry.go @@ -180,8 +180,8 @@ func (r *Registry) Activate(ctx context.Context, nodeID uuid.UUID) error { return nil } -// RecordHeartbeat updates last_heartbeat and node metrics, returning the -// number of pending tasks for that node. +// RecordHeartbeat updates last_heartbeat, appends to heartbeat_log for uptime +// tracking (SRS-SCHED-04), and returns the number of pending tasks for that node. func (r *Registry) RecordHeartbeat(ctx context.Context, nodeID uuid.UUID) (pendingTasks int, err error) { now := time.Now().UTC() err = r.db.QueryRow(ctx, @@ -189,6 +189,13 @@ func (r *Registry) RecordHeartbeat(ctx context.Context, nodeID uuid.UUID) (pendi RETURNING (SELECT count(*) FROM tasks WHERE assigned_node = $2 AND status = 'pending')`, now, nodeID, ).Scan(&pendingTasks) + if err != nil { + return 0, err + } + _, err = r.db.Exec(ctx, + `INSERT INTO heartbeat_log (node_id, recorded_at) VALUES ($1, $2)`, + nodeID, now, + ) return pendingTasks, err } @@ -259,29 +266,96 @@ func (r *Registry) UpdateStatus(ctx context.Context, nodeID uuid.UUID, status st return err } -// UpdateReliability recalculates and persists a node's reliability score. 
-// reliability = (successful_tasks / total_tasks) * uptime_fraction (rolling 7d). +// UpdateReliability recalculates and persists a node's reliability score using +// the full SRS-SCHED-04 formula over a rolling 7-day window: +// +// reliability = task_success_fraction × uptime_fraction +// +// uptime_fraction is derived from heartbeat_log (SRS-NODE-04: 60 s interval). +// success=true increments the total_tasks counter; false does not (the task +// completion is already recorded in the tasks table and counted there). func (r *Registry) UpdateReliability(ctx context.Context, nodeID uuid.UUID, success bool) error { - var col string + // Clamp the window start to the node's registration time so brand-new nodes + // are not penalised for the days before they existed. + var registeredAt time.Time + if err := r.db.QueryRow(ctx, + `SELECT registered_at FROM nodes WHERE node_id = $1`, nodeID, + ).Scan(&registeredAt); err != nil { + return fmt.Errorf("fetching registration time: %w", err) + } + windowStart := time.Now().UTC().Add(-7 * 24 * time.Hour) + if registeredAt.After(windowStart) { + windowStart = registeredAt + } + + // Count tasks in the window. + var totalTasks, completedTasks int64 + if err := r.db.QueryRow(ctx, + `SELECT COUNT(*), COUNT(*) FILTER (WHERE status = 'completed') + FROM tasks WHERE assigned_node = $1 AND submitted_at > $2`, + nodeID, windowStart, + ).Scan(&totalTasks, &completedTasks); err != nil { + return fmt.Errorf("counting tasks: %w", err) + } + + // Count heartbeats in the window. 
+ var receivedHB int64 + if err := r.db.QueryRow(ctx, + `SELECT COUNT(*) FROM heartbeat_log WHERE node_id = $1 AND recorded_at > $2`, + nodeID, windowStart, + ).Scan(&receivedHB); err != nil { + return fmt.Errorf("counting heartbeats: %w", err) + } + + reliability := computeReliability(completedTasks, totalTasks, receivedHB, time.Since(windowStart)) + + successIncr := 0 if success { - col = "total_tasks = total_tasks + 1" - } else { - col = "reliability = reliability" + successIncr = 1 } - // Simplified reliability update; a full implementation uses a time-windowed query. _, err := r.db.Exec(ctx, - fmt.Sprintf(`UPDATE nodes SET %s, reliability = ( - SELECT COALESCE( - COUNT(*) FILTER (WHERE status = 'completed')::NUMERIC / - NULLIF(COUNT(*), 0), 1.0 - ) FROM tasks WHERE assigned_node = $1 - AND submitted_at > now() - INTERVAL '7 days' - ) WHERE node_id = $1`, col), - nodeID, + `UPDATE nodes SET reliability = $2, total_tasks = total_tasks + $3 WHERE node_id = $1`, + nodeID, reliability, successIncr, ) return err } +// computeReliability implements SRS-SCHED-04: +// +// reliability = task_success_fraction × uptime_fraction +// +// task_success_fraction = completed / total [1.0 when total == 0] +// uptime_fraction = receivedHB / expectedHB [1.0 when window < 1 interval] +// expectedHB = windowDuration / 60 s (SRS-NODE-04 heartbeat rate) +// +// uptime_fraction is capped at 1, and the final product is clamped to [0, 1]. 
+func computeReliability(completed, total, receivedHB int64, windowDuration time.Duration) float64 { + const heartbeatIntervalSecs = 60.0 + + taskFraction := 1.0 + if total > 0 { + taskFraction = float64(completed) / float64(total) + } + + expectedHB := windowDuration.Seconds() / heartbeatIntervalSecs + uptimeFraction := 1.0 + if expectedHB >= 1.0 { + uptimeFraction = float64(receivedHB) / expectedHB + if uptimeFraction > 1 { + uptimeFraction = 1 + } + } + + r := taskFraction * uptimeFraction + if r < 0 { + return 0 + } + if r > 1 { + return 1 + } + return r +} + // IsSuspendedOrCoolingDown returns true if the node's public key is currently // suspended and the slashing cooldown has not expired. func (r *Registry) IsSuspendedOrCoolingDown(ctx context.Context, pubKeyHex string) (bool, error) { diff --git a/owm-coordinator/internal/registry/reliability_test.go b/owm-coordinator/internal/registry/reliability_test.go new file mode 100644 index 0000000..e65caca --- /dev/null +++ b/owm-coordinator/internal/registry/reliability_test.go @@ -0,0 +1,224 @@ +package registry + +// computeReliability is unexported; these tests live in package registry. + +import ( + "context" + "testing" + "time" + + "github.com/owmnetwork/owm-coordinator/internal/testutil" +) + +// ── computeReliability unit tests ───────────────────────────────────────────── + +func TestComputeReliability_NoTasksNoHeartbeats_NewNode(t *testing.T) { + // Brand-new node: window < 1 heartbeat interval → uptime defaults to 1.0. + // No tasks → task fraction defaults to 1.0. Result: 1.0. + got := computeReliability(0, 0, 0, 30*time.Second) + if got != 1.0 { + t.Errorf("got %f, want 1.0", got) + } +} + +func TestComputeReliability_PerfectRecord(t *testing.T) { + // 10/10 tasks completed, heartbeat every 60 s for 1 hour (60 received, 60 expected). 
+ got := computeReliability(10, 10, 60, time.Hour) + if got != 1.0 { + t.Errorf("got %f, want 1.0", got) + } +} + +func TestComputeReliability_HalfUptimeFullSuccess(t *testing.T) { + // All tasks succeeded but node was online only 50 % of the time. + // 1 hour window → 60 expected heartbeats; only 30 received. + got := computeReliability(5, 5, 30, time.Hour) + const want = 0.5 + if got != want { + t.Errorf("got %f, want %f", got, want) + } +} + +func TestComputeReliability_HalfSuccessFullUptime(t *testing.T) { + // Node was always online but only completed half its tasks. + got := computeReliability(5, 10, 60, time.Hour) + const want = 0.5 + if got != want { + t.Errorf("got %f, want %f", got, want) + } +} + +func TestComputeReliability_HalfUptimeHalfSuccess(t *testing.T) { + // 0.5 * 0.5 = 0.25 + got := computeReliability(5, 10, 30, time.Hour) + const want = 0.25 + if got != want { + t.Errorf("got %f, want %f", got, want) + } +} + +func TestComputeReliability_ZeroTasksWithHeartbeats(t *testing.T) { + // No tasks yet → task fraction = 1.0; uptime is real. + got := computeReliability(0, 0, 30, time.Hour) // 50 % uptime + const want = 0.5 + if got != want { + t.Errorf("got %f, want %f", got, want) + } +} + +func TestComputeReliability_NoHeartbeatsLongWindow(t *testing.T) { + // Node has been around for a day but sent zero heartbeats → uptime = 0. + got := computeReliability(10, 10, 0, 24*time.Hour) + if got != 0.0 { + t.Errorf("got %f, want 0.0", got) + } +} + +func TestComputeReliability_MoreHeartbeatsThanExpected(t *testing.T) { + // Extra heartbeats (e.g. from retries) must not push reliability above 1.0. + got := computeReliability(10, 10, 200, time.Hour) // expected 60, received 200 + if got != 1.0 { + t.Errorf("got %f, want 1.0 (capped)", got) + } +} + +func TestComputeReliability_WindowExactlyOneInterval(t *testing.T) { + // Edge: exactly 60 s window → expectedHB = 1.0. + // 1 heartbeat received → uptime = 1.0; 1/1 tasks completed → 1.0. 
+ got := computeReliability(1, 1, 1, 60*time.Second) + if got != 1.0 { + t.Errorf("got %f, want 1.0", got) + } +} + +func TestComputeReliability_WindowSlightlyBelowOneInterval(t *testing.T) { + // < 60 s window → expectedHB < 1 → uptime defaults to 1.0. + got := computeReliability(0, 0, 0, 59*time.Second) + if got != 1.0 { + t.Errorf("got %f, want 1.0 (window too small, defaults)", got) + } +} + +// ── Integration: UpdateReliability with real DB ─────────────────────────────── + +func TestUpdateReliabilityIntegration(t *testing.T) { + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + reg := New(pool, testutil.Logger()) + nf := testutil.NewNodeFixture(t) + + // Register and activate the node. + ts := time.Now().Unix() + caps := NodeCapabilities{Tier: TierT1, VRAMGB: 8, RAMGB: 16, BandwidthMbps: 100} + sig := nf.Sign(nf.LNNodeURI, TierT1, ts) + node, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig, ts) + if err != nil { + t.Fatalf("Register: %v", err) + } + if err := reg.Activate(ctx, node.NodeID); err != nil { + t.Fatalf("Activate: %v", err) + } + + // Seed 60 heartbeats directly (representing ~1 hour of uptime). + base := time.Now().UTC().Add(-61 * time.Minute) + for i := 0; i < 60; i++ { + _, err := pool.Exec(ctx, + `INSERT INTO heartbeat_log (node_id, recorded_at) VALUES ($1, $2)`, + node.NodeID, base.Add(time.Duration(i)*time.Minute), + ) + if err != nil { + t.Fatalf("insert heartbeat %d: %v", i, err) + } + } + + // Seed 8 tasks: 6 completed, 2 failed — task_fraction = 0.75. 
+ insertTask := func(status string) { + _, err := pool.Exec(ctx, + `INSERT INTO tasks + (task_id, task_type, assigned_node, node_ln_uri, status, + input_hash, reward_sats, submitted_at, started_at, timeout_seconds) + VALUES (gen_random_uuid(), 'inference', $1, $2, $3, + 'hash', 10, now() - interval '30 minutes', now() - interval '30 minutes', 300)`, + node.NodeID, node.LNNodeURI, status, + ) + if err != nil { + t.Fatalf("insert task (%s): %v", status, err) + } + } + for i := 0; i < 6; i++ { + insertTask("completed") + } + for i := 0; i < 2; i++ { + insertTask("failed") + } + + // UpdateReliability with success=true for the last task. + if err := reg.UpdateReliability(ctx, node.NodeID, true); err != nil { + t.Fatalf("UpdateReliability: %v", err) + } + + // Fetch the persisted reliability score. + var reliability float64 + if err := pool.QueryRow(ctx, + `SELECT reliability FROM nodes WHERE node_id = $1`, node.NodeID, + ).Scan(&reliability); err != nil { + t.Fatalf("fetching reliability: %v", err) + } + + // task_fraction ≈ 6/8 = 0.75; uptime_fraction ≈ 60/61 ≈ 0.98 → reliability ≈ 0.735; the ±0.05 tolerance covers timing jitter. + // NOTE(review): UpdateReliability clamps its window to nodes.registered_at, and the rows seeded above are dated before this test's Register call — confirm they are counted (backdate registered_at if not). 
+ const wantApprox = 0.75 + const tolerance = 0.05 + if reliability < wantApprox-tolerance || reliability > wantApprox+tolerance { + t.Errorf("reliability = %f, want %f ± %f", reliability, wantApprox, tolerance) + } + if reliability <= 0 || reliability > 1 { + t.Errorf("reliability %f out of [0,1] range", reliability) + } +} + +func TestRecordHeartbeatIntegration(t *testing.T) { + pool := testutil.MustDB(t) + testutil.TruncateAll(t, pool) + + ctx := context.Background() + reg := New(pool, testutil.Logger()) + nf := testutil.NewNodeFixture(t) + + ts := time.Now().Unix() + caps := NodeCapabilities{Tier: TierT1, VRAMGB: 8, RAMGB: 16, BandwidthMbps: 100} + sig := nf.Sign(nf.LNNodeURI, TierT1, ts) + node, err := reg.Register(ctx, nf.PubKeyHex, nf.LNNodeURI, "", caps, sig, ts) + if err != nil { + t.Fatalf("Register: %v", err) + } + + _, err = reg.RecordHeartbeat(ctx, node.NodeID) + if err != nil { + t.Fatalf("RecordHeartbeat: %v", err) + } + + // Verify one row was inserted into heartbeat_log. + var count int + if err := pool.QueryRow(ctx, + `SELECT COUNT(*) FROM heartbeat_log WHERE node_id = $1`, node.NodeID, + ).Scan(&count); err != nil { + t.Fatalf("count heartbeat_log: %v", err) + } + if count != 1 { + t.Errorf("heartbeat_log rows: got %d, want 1", count) + } + + // Verify last_heartbeat was updated in the nodes table. 
+ var lastHB *time.Time + if err := pool.QueryRow(ctx, + `SELECT last_heartbeat FROM nodes WHERE node_id = $1`, node.NodeID, + ).Scan(&lastHB); err != nil { + t.Fatalf("fetch last_heartbeat: %v", err) + } + if lastHB == nil { + t.Error("last_heartbeat was not updated in nodes table") + } +} diff --git a/owm-coordinator/migrations/000011_heartbeat_log.down.sql b/owm-coordinator/migrations/000011_heartbeat_log.down.sql new file mode 100644 index 0000000..79730a8 --- /dev/null +++ b/owm-coordinator/migrations/000011_heartbeat_log.down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_heartbeat_log_node_at; +DROP TABLE IF EXISTS heartbeat_log; diff --git a/owm-coordinator/migrations/000011_heartbeat_log.up.sql b/owm-coordinator/migrations/000011_heartbeat_log.up.sql new file mode 100644 index 0000000..931bc29 --- /dev/null +++ b/owm-coordinator/migrations/000011_heartbeat_log.up.sql @@ -0,0 +1,13 @@ +-- Records every heartbeat received from a node (SRS-NODE-04: 60-second interval). +-- Used by UpdateReliability to compute the rolling 7-day uptime fraction per +-- SRS-SCHED-04: reliability = task_success_fraction × uptime_fraction. 
+-- +-- Retention: rows older than 8 days are safe to prune with a scheduled job: +-- DELETE FROM heartbeat_log WHERE recorded_at < now() - INTERVAL '8 days'; +CREATE TABLE IF NOT EXISTS heartbeat_log ( + node_id UUID NOT NULL REFERENCES nodes(node_id) ON DELETE CASCADE, + recorded_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_heartbeat_log_node_at + ON heartbeat_log (node_id, recorded_at DESC); From 48d20648708a734d5ea95bccba13c898d957604f Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 22:28:28 -0400 Subject: [PATCH 5/8] ci: fix four CI failures (SSL, mTLS, observer key, gRPC deprecation) - testutil/db.go: append sslmode=disable to the migration DSN so golang-migrate's pq driver works against the CI Postgres service container (no TLS configured) - rpc/server_test.go: also accept "broken pipe" as a valid mTLS rejection indicator; Linux TCP stack can reset the connection before the TLS handshake message reaches the client - observer/client.go: check raw-bytes length before bytes.TrimSpace in decodeSigningKey; binary keys whose first/last byte is ASCII whitespace were silently truncated, causing intermittent NewClient failures - lightning/lnd_client.go: replace deprecated grpc.DialContext+WithBlock with grpc.NewClient (lazy connect, no startup timeout) Co-Authored-By: Claude Sonnet 4.6 --- owm-coordinator/internal/lightning/lnd_client.go | 4 +--- owm-coordinator/internal/observer/client.go | 9 +++++---- owm-coordinator/internal/rpc/server_test.go | 10 +++++++--- owm-coordinator/internal/testutil/db.go | 13 ++++++++++++- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/owm-coordinator/internal/lightning/lnd_client.go b/owm-coordinator/internal/lightning/lnd_client.go index 30663b7..4870b40 100644 --- a/owm-coordinator/internal/lightning/lnd_client.go +++ b/owm-coordinator/internal/lightning/lnd_client.go @@ -46,9 +46,7 @@ func NewLNDClient(host, tlsCertPath string, macaroonBytes []byte) (*LNDClient, e opts 
= append(opts, grpc.WithPerRPCCredentials(&macaroonCredential{macaroonHex: macHex})) } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, host, append(opts, grpc.WithBlock())...) + conn, err := grpc.NewClient(host, opts...) if err != nil { return nil, fmt.Errorf("dialing LND: %w", err) } diff --git a/owm-coordinator/internal/observer/client.go b/owm-coordinator/internal/observer/client.go index 98f4b72..4712b1d 100644 --- a/owm-coordinator/internal/observer/client.go +++ b/owm-coordinator/internal/observer/client.go @@ -130,6 +130,11 @@ func (c *Client) Submit(ctx context.Context, r Receipt) (receiptID string, err e // 32-byte seed material (64-char hex or 32 raw bytes) is expanded via Ed25519 // seed expansion so SignReceipt receives a full private key. func decodeSigningKey(data []byte) ([]byte, error) { + // Raw bytes: check BEFORE TrimSpace — binary keys can contain bytes that + // TrimSpace would strip (e.g. 0x0A, 0x20), corrupting the length. + if len(data) == 32 || len(data) == ed25519.PrivateKeySize { + return normalizeToFullPrivateKey(data) + } data = bytes.TrimSpace(data) // Hex: 64 chars = 32-byte seed, 128 chars = 64-byte full key. if len(data) >= 64 && len(data)%2 == 0 && hexEncoded(data) { @@ -143,10 +148,6 @@ func decodeSigningKey(data []byte) ([]byte, error) { if bytes.Contains(data, []byte("PRIVATE KEY")) { return decodePEMEd25519PrivateKey(data) } - // Raw bytes: 32 = seed, 64 = full key. 
- if len(data) == 32 || len(data) == ed25519.PrivateKeySize { - return normalizeToFullPrivateKey(data) - } return nil, fmt.Errorf("invalid key format: need 32- or 64-byte raw, 64- or 128-char hex, or PEM") } diff --git a/owm-coordinator/internal/rpc/server_test.go b/owm-coordinator/internal/rpc/server_test.go index 5ded7e7..e826ed9 100644 --- a/owm-coordinator/internal/rpc/server_test.go +++ b/owm-coordinator/internal/rpc/server_test.go @@ -155,10 +155,14 @@ func TestGRPCServer_RejectsClientWithoutCert(t *testing.T) { t.Fatalf("expected Unavailable (TLS handshake rejection), got %v: %v", st.Code(), rpcErr) } - // The error message should reference the TLS/handshake/certificate failure. + // The error message should reference TLS rejection. On some Linux kernels the + // server resets the TCP connection before the handshake completes, so the + // client sees "broken pipe" instead of a TLS-layer message — both indicate + // the server correctly refused the unauthenticated client. errMsg := strings.ToLower(rpcErr.Error()) - if !strings.Contains(errMsg, "handshake") && !strings.Contains(errMsg, "certificate") && !strings.Contains(errMsg, "tls") { - t.Fatalf("expected error to mention TLS/handshake/certificate, got: %v", rpcErr) + if !strings.Contains(errMsg, "handshake") && !strings.Contains(errMsg, "certificate") && + !strings.Contains(errMsg, "tls") && !strings.Contains(errMsg, "broken pipe") { + t.Fatalf("expected error to mention TLS/handshake/certificate/broken pipe, got: %v", rpcErr) } t.Logf("got expected error (mTLS rejection): %v", rpcErr) diff --git a/owm-coordinator/internal/testutil/db.go b/owm-coordinator/internal/testutil/db.go index ac78a85..96fa0d3 100644 --- a/owm-coordinator/internal/testutil/db.go +++ b/owm-coordinator/internal/testutil/db.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" "github.com/golang-migrate/migrate/v4" @@ -36,7 +37,17 @@ func MustDB(t *testing.T) *pgxpool.Pool { migrationsDir := 
filepath.Join(filepath.Dir(thisFile), "..", "..", "migrations") migrationsDir, _ = filepath.Abs(migrationsDir) - m, err := migrate.New("file://"+migrationsDir, dsn) + // golang-migrate uses the pq driver which requires sslmode=disable on + // servers without TLS (e.g. the CI Postgres service container). + migrDSN := dsn + if !strings.Contains(migrDSN, "sslmode=") { + if strings.Contains(migrDSN, "?") { + migrDSN += "&sslmode=disable" + } else { + migrDSN += "?sslmode=disable" + } + } + m, err := migrate.New("file://"+migrationsDir, migrDSN) if err != nil { t.Fatalf("testutil.MustDB: migrate.New: %v", err) } From b4d6bb8ea0d2130b12353be43d90d8d6b43702eb Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 22:53:17 -0400 Subject: [PATCH 6/8] ci: fix three remaining test failures - observer/client.go: gate the early raw-bytes check with !hexEncoded so 64-char hex strings (also 64 bytes) are correctly hex-decoded rather than treated as a raw key; fixes TestDecodeSigningKey_64CharSeedHex - registry/reliability_test.go: backdate registered_at by 2 hours after registration so the seeded historical data (heartbeats/tasks) falls inside the clamped window; fixes TestUpdateReliabilityIntegration - registry/registry.go: use COALESCE(onion_address,'') in ListActive and GetByPublicKey so a nullable TEXT column never causes "cannot scan NULL into *string"; fixes TestScheduleIntegration_TaskTypeMismatch Co-Authored-By: Claude Sonnet 4.6 --- owm-coordinator/internal/observer/client.go | 7 ++++--- owm-coordinator/internal/registry/registry.go | 4 ++-- owm-coordinator/internal/registry/reliability_test.go | 10 ++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/owm-coordinator/internal/observer/client.go b/owm-coordinator/internal/observer/client.go index 4712b1d..2a72f2a 100644 --- a/owm-coordinator/internal/observer/client.go +++ b/owm-coordinator/internal/observer/client.go @@ -130,9 +130,10 @@ func (c *Client) Submit(ctx context.Context, r 
Receipt) (receiptID string, err e // 32-byte seed material (64-char hex or 32 raw bytes) is expanded via Ed25519 // seed expansion so SignReceipt receives a full private key. func decodeSigningKey(data []byte) ([]byte, error) { - // Raw bytes: check BEFORE TrimSpace — binary keys can contain bytes that - // TrimSpace would strip (e.g. 0x0A, 0x20), corrupting the length. - if len(data) == 32 || len(data) == ed25519.PrivateKeySize { + // Raw binary: check BEFORE TrimSpace so whitespace bytes (0x0A, 0x20, …) + // at the key boundary are not stripped. Hex strings (all bytes in [0-9a-fA-F]) + // of the same length are intentionally excluded and handled below. + if (len(data) == 32 || len(data) == ed25519.PrivateKeySize) && !hexEncoded(data) { return normalizeToFullPrivateKey(data) } data = bytes.TrimSpace(data) diff --git a/owm-coordinator/internal/registry/registry.go b/owm-coordinator/internal/registry/registry.go index 46bc935..e9fd3fa 100644 --- a/owm-coordinator/internal/registry/registry.go +++ b/owm-coordinator/internal/registry/registry.go @@ -202,7 +202,7 @@ func (r *Registry) RecordHeartbeat(ctx context.Context, nodeID uuid.UUID) (pendi // GetByPublicKey retrieves a node by its Ed25519 public key. func (r *Registry) GetByPublicKey(ctx context.Context, pubKeyHex string) (*Node, error) { const q = ` - SELECT node_id, public_key, ln_node_uri, onion_address, tier, vram_gb, ram_gb, + SELECT node_id, public_key, ln_node_uri, COALESCE(onion_address, ''), tier, vram_gb, ram_gb, bandwidth_mbps, supported_task_types, reliability, total_tasks, total_sats, status, registered_at, last_heartbeat FROM nodes WHERE public_key = $1` @@ -223,7 +223,7 @@ func (r *Registry) GetByPublicKey(ctx context.Context, pubKeyHex string) (*Node, // ListActive returns all nodes currently in active status. 
func (r *Registry) ListActive(ctx context.Context) ([]*Node, error) { const q = ` - SELECT node_id, public_key, ln_node_uri, onion_address, tier, vram_gb, ram_gb, + SELECT node_id, public_key, ln_node_uri, COALESCE(onion_address, ''), tier, vram_gb, ram_gb, bandwidth_mbps, supported_task_types, reliability, total_tasks, total_sats, status, registered_at, last_heartbeat FROM nodes WHERE status = 'active' ORDER BY reliability DESC` diff --git a/owm-coordinator/internal/registry/reliability_test.go b/owm-coordinator/internal/registry/reliability_test.go index e65caca..38d51b9 100644 --- a/owm-coordinator/internal/registry/reliability_test.go +++ b/owm-coordinator/internal/registry/reliability_test.go @@ -121,6 +121,16 @@ func TestUpdateReliabilityIntegration(t *testing.T) { t.Fatalf("Activate: %v", err) } + // Backdate registered_at so the window start (max(now-7d, registered_at)) + // is 2 hours ago, allowing the historical heartbeat and task data below to + // fall inside the window. + if _, err := pool.Exec(ctx, + `UPDATE nodes SET registered_at = now() - interval '2 hours' WHERE node_id = $1`, + node.NodeID, + ); err != nil { + t.Fatalf("backdate registered_at: %v", err) + } + // Seed 60 heartbeats directly (representing ~1 hour of uptime). base := time.Now().UTC().Add(-61 * time.Minute) for i := 0; i < 60; i++ { From 278464f5b1e038756f21ebbed019391271d5179b Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 22:57:47 -0400 Subject: [PATCH 7/8] ci: fix TestUpdateReliabilityIntegration window math MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pin registered_at to base-1s (1 second before the earliest seeded heartbeat) so the rolling window is ~61 minutes. The previous fix used a 2-hour interval, giving expectedHB=120 against only 60 seeded heartbeats → uptime=0.5 → reliability=0.375, failing the ≈0.75 target. 
With a ~61-minute window: expectedHB≈61, uptime≈60/61≈0.98, reliability≈0.735, safely within ±0.05 of 0.75. Co-Authored-By: Claude Sonnet 4.6 --- .../internal/registry/reliability_test.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/owm-coordinator/internal/registry/reliability_test.go b/owm-coordinator/internal/registry/reliability_test.go index 38d51b9..7d1a5ca 100644 --- a/owm-coordinator/internal/registry/reliability_test.go +++ b/owm-coordinator/internal/registry/reliability_test.go @@ -121,18 +121,21 @@ func TestUpdateReliabilityIntegration(t *testing.T) { t.Fatalf("Activate: %v", err) } - // Backdate registered_at so the window start (max(now-7d, registered_at)) - // is 2 hours ago, allowing the historical heartbeat and task data below to - // fall inside the window. + // Define base before backdating so registered_at can be anchored to it. + // Heartbeats run from base to base+59m (i.e. now()-61m … now()-1m). + base := time.Now().UTC().Add(-61 * time.Minute) + + // Backdate registered_at to 1 second before the earliest heartbeat so the + // window is ~61 minutes and uptime_fraction ≈ 60/61 ≈ 0.98. Using a 2-hour + // interval would give a 120-minute window with only 60 heartbeats → 0.5. if _, err := pool.Exec(ctx, - `UPDATE nodes SET registered_at = now() - interval '2 hours' WHERE node_id = $1`, - node.NodeID, + `UPDATE nodes SET registered_at = $2 WHERE node_id = $1`, + node.NodeID, base.Add(-time.Second), ); err != nil { t.Fatalf("backdate registered_at: %v", err) } // Seed 60 heartbeats directly (representing ~1 hour of uptime). 
- base := time.Now().UTC().Add(-61 * time.Minute) for i := 0; i < 60; i++ { _, err := pool.Exec(ctx, `INSERT INTO heartbeat_log (node_id, recorded_at) VALUES ($1, $2)`, From 12fc2b370aee79c8d33d81fc0ef8b394c2ed6338 Mon Sep 17 00:00:00 2001 From: CodeByMAB Date: Sun, 17 May 2026 23:18:02 -0400 Subject: [PATCH 8/8] ci: align protobuf generated file headers with CI's protoc version CI installs protoc via apt-get (v3.21.12) and protoc-gen-go-grpc (v1.6.2), then runs git diff --exit-code to guard against stale stubs. The committed files were generated locally with v7.34.0/v1.6.1, causing a spurious diff on every run. Update the header comment-only version strings to match what CI regenerates so the diff check is clean. No functional code changes in the generated files. Co-Authored-By: Claude Sonnet 4.6 --- owm-coordinator/proto/coordinator/v1/coordinator.pb.go | 2 +- owm-coordinator/proto/coordinator/v1/coordinator_grpc.pb.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/owm-coordinator/proto/coordinator/v1/coordinator.pb.go b/owm-coordinator/proto/coordinator/v1/coordinator.pb.go index 2729277..47fff04 100644 --- a/owm-coordinator/proto/coordinator/v1/coordinator.pb.go +++ b/owm-coordinator/proto/coordinator/v1/coordinator.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v7.34.0 +// protoc v3.21.12 // source: proto/coordinator/v1/coordinator.proto package coordinatorv1 diff --git a/owm-coordinator/proto/coordinator/v1/coordinator_grpc.pb.go b/owm-coordinator/proto/coordinator/v1/coordinator_grpc.pb.go index ff92226..819263e 100644 --- a/owm-coordinator/proto/coordinator/v1/coordinator_grpc.pb.go +++ b/owm-coordinator/proto/coordinator/v1/coordinator_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
// versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v7.34.0 +// - protoc-gen-go-grpc v1.6.2 +// - protoc v3.21.12 // source: proto/coordinator/v1/coordinator.proto package coordinatorv1