From 127c0480e0ec4470c6910c123ea12ec251285ee0 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:03:56 -0500 Subject: [PATCH 01/11] feat(remediators): multi-strategy dispatch with first-success-wins (Task #15214) Add RemediatorRegistry.RemediateWithStrategies(ctx, []string, problem): tries strategy types in order, stops on first success, returns wrapped last error if all fail; reuses Remediate so circuit-breaker/rate-limit/ cooldown/dry-run all apply. detector.evaluateRemediation builds the ordered list from remCfg.Strategies (each nested .Strategy) when present, else falls back to [remCfg.Strategy] (backward compatible). Tests cover first-fail-then-success ordering, short-circuit, all-fail, and fallback. Note: real config types are Strategy string + Strategies []MonitorRemediationConfig (task premise said []string); per-strategy params don't flow through Remediate today, so dispatch is by type string. --- pkg/detector/detector.go | 55 +++++++++-- pkg/detector/mocks_test.go | 33 +++++++ pkg/detector/remediation_test.go | 162 +++++++++++++++++++++++++++++++ pkg/remediators/registry.go | 45 +++++++++ pkg/remediators/registry_test.go | 150 ++++++++++++++++++++++++++++ 5 files changed, 439 insertions(+), 6 deletions(-) diff --git a/pkg/detector/detector.go b/pkg/detector/detector.go index c9fc6c7..9761c4e 100644 --- a/pkg/detector/detector.go +++ b/pkg/detector/detector.go @@ -133,6 +133,9 @@ type MonitorFactory interface { type RemediationExecutor interface { // Remediate executes the named remediator strategy for the given problem. Remediate(ctx context.Context, remediatorType string, problem types.Problem) error + // RemediateWithStrategies attempts an ordered list of remediator strategy + // types, stopping at the first one that succeeds (first-success-wins). + RemediateWithStrategies(ctx context.Context, strategyTypes []string, problem types.Problem) error // IsDryRun reports whether the executor is running in dry-run mode. IsDryRun() bool } @@ -513,13 +516,22 @@ func (pd *ProblemDetector) evaluateRemediation(status *types.Status) { remCfg := monitorCfg.Remediation + // Build the ordered list of remediation strategy types to attempt. + // When Strategies is non-empty, dispatch each nested strategy in order + // (first-success-wins). Otherwise fall back to the single Strategy so that + // existing single-strategy configs behave exactly as before. + strategyTypes := buildStrategyList(remCfg) + if len(strategyTypes) == 0 { + return + } + for _, cond := range status.Conditions { if cond.Status != types.ConditionFalse { continue } problem := types.Problem{ - Type: remCfg.Strategy, + Type: strategyTypes[0], Resource: cond.Type, Severity: types.ProblemWarning, Message: cond.Message, @@ -531,18 +543,49 @@ func (pd *ProblemDetector) evaluateRemediation(status *types.Status) { }, } - if err := registry.Remediate(pd.ctx, remCfg.Strategy, problem); err != nil { - log.Printf("[WARN] Remediation failed for %s/%s (strategy=%s): %v", - status.Source, cond.Type, remCfg.Strategy, err) + if err := registry.RemediateWithStrategies(pd.ctx, strategyTypes, problem); err != nil { + log.Printf("[WARN] Remediation failed for %s/%s (strategies=%v): %v", + status.Source, cond.Type, strategyTypes, err) pd.stats.IncrementRemediationsFailed() } else { - log.Printf("[INFO] Remediation triggered for %s/%s (strategy=%s, dry-run=%v)", - status.Source, cond.Type, remCfg.Strategy, registry.IsDryRun()) + log.Printf("[INFO] Remediation triggered for %s/%s (strategies=%v, dry-run=%v)", + status.Source, cond.Type, strategyTypes, registry.IsDryRun()) pd.stats.IncrementRemediationsTriggered() } } } +// buildStrategyList returns the ordered list of remediation strategy types to +// attempt for a monitor's remediation config. +// +// When remCfg.Strategies is non-empty, the nested strategies are dispatched in +// order (each strategy's Strategy field, skipping empties). Otherwise it falls +// back to the single remCfg.Strategy, preserving backward compatibility: +// a config with only Strategy set yields []string{Strategy}. +func buildStrategyList(remCfg *types.MonitorRemediationConfig) []string { + if remCfg == nil { + return nil + } + + if len(remCfg.Strategies) > 0 { + strategies := make([]string, 0, len(remCfg.Strategies)) + for _, s := range remCfg.Strategies { + if s.Strategy != "" { + strategies = append(strategies, s.Strategy) + } + } + if len(strategies) > 0 { + return strategies + } + } + + if remCfg.Strategy != "" { + return []string{remCfg.Strategy} + } + + return nil +} + // fanInFromMonitor reads statuses from a monitor and forwards them to the main status channel func (pd *ProblemDetector) fanInFromMonitor(ctx context.Context, statusCh <-chan *types.Status, monitorName string) { log.Printf("[DEBUG] Starting fan-in for monitor %s", monitorName) diff --git a/pkg/detector/mocks_test.go b/pkg/detector/mocks_test.go index a718576..3d5ac8e 100644 --- a/pkg/detector/mocks_test.go +++ b/pkg/detector/mocks_test.go @@ -2,6 +2,7 @@ package detector import ( "context" + "fmt" "sync" "time" @@ -313,6 +314,38 @@ func (m *MockRemediationExecutor) Remediate(_ context.Context, remediatorType st return m.returnErr } +// RemediateWithStrategies records each attempted strategy and returns the +// configured error. It mirrors the real registry's first-success-wins +// semantics: if returnErr is nil the first strategy "succeeds" and remaining +// strategies are not attempted; if returnErr is non-nil every strategy is +// attempted and the error is returned. +func (m *MockRemediationExecutor) RemediateWithStrategies(ctx context.Context, strategyTypes []string, problem types.Problem) error { + if len(strategyTypes) == 0 { + return fmt.Errorf("no remediation strategies provided") + } + + m.mu.Lock() + err := m.returnErr + m.mu.Unlock() + + var lastErr error + for _, strategyType := range strategyTypes { + attemptProblem := problem + attemptProblem.Type = strategyType + // Record the attempt via Remediate so existing call assertions hold. + callErr := m.Remediate(ctx, strategyType, attemptProblem) + if callErr == nil && err == nil { + return nil + } + if err != nil { + lastErr = err + } else { + lastErr = callErr + } + } + return lastErr +} + // IsDryRun implements RemediationExecutor. func (m *MockRemediationExecutor) IsDryRun() bool { m.mu.Lock() diff --git a/pkg/detector/remediation_test.go b/pkg/detector/remediation_test.go index ae60b88..f118f98 100644 --- a/pkg/detector/remediation_test.go +++ b/pkg/detector/remediation_test.go @@ -261,3 +261,165 @@ func TestEvaluateRemediation_NoMonitorRemediationConfig(t *testing.T) { t.Errorf("expected 0 calls when monitor has no remediation config, got %d", exec.CallCount()) } } + +// TestBuildStrategyList covers the ordered strategy-list building and the +// single-strategy fallback used by evaluateRemediation. +func TestBuildStrategyList(t *testing.T) { + tests := []struct { + name string + remCfg *types.MonitorRemediationConfig + want []string + }{ + { + name: "nil config yields nil", + remCfg: nil, + want: nil, + }, + { + name: "single strategy only - fallback", + remCfg: &types.MonitorRemediationConfig{Strategy: "systemd-restart"}, + want: []string{"systemd-restart"}, + }, + { + name: "strategies non-empty preserves order and ignores top-level Strategy", + remCfg: &types.MonitorRemediationConfig{ + Strategy: "node-reboot", // should be ignored when Strategies present + Strategies: []types.MonitorRemediationConfig{ + {Strategy: "systemd-restart"}, + {Strategy: "custom-script"}, + {Strategy: "node-reboot"}, + }, + }, + want: []string{"systemd-restart", "custom-script", "node-reboot"}, + }, + { + name: "empty strategy entries are skipped", + remCfg: &types.MonitorRemediationConfig{ + Strategy: "pod-delete", + Strategies: []types.MonitorRemediationConfig{ + {Strategy: ""}, + {Strategy: "systemd-restart"}, + {Strategy: ""}, + }, + }, + want: []string{"systemd-restart"}, + }, + { + name: "all empty strategy entries fall back to single Strategy", + remCfg: &types.MonitorRemediationConfig{ + Strategy: "pod-delete", + Strategies: []types.MonitorRemediationConfig{ + {Strategy: ""}, + }, + }, + want: []string{"pod-delete"}, + }, + { + name: "no strategy at all yields nil", + remCfg: &types.MonitorRemediationConfig{}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildStrategyList(tt.remCfg) + if len(got) != len(tt.want) { + t.Fatalf("buildStrategyList() = %v, want %v", got, tt.want) + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("buildStrategyList()[%d] = %q, want %q", i, got[i], tt.want[i]) + } + } + }) + } +} + +// TestEvaluateRemediation_MultiStrategyDispatch verifies that a monitor config +// with a Strategies list dispatches each strategy in order through the executor. +func TestEvaluateRemediation_MultiStrategyDispatch(t *testing.T) { + exec := NewMockRemediationExecutor() + // Force all attempts to "fail" so the executor walks every strategy in order. + exec.SetError(fmt.Errorf("simulated failure")) + + monCfg := types.MonitorConfig{ + Name: "multi-monitor", + Type: "test", + Enabled: true, + Interval: 30 * time.Second, + Timeout: 10 * time.Second, + Remediation: &types.MonitorRemediationConfig{ + Enabled: true, + Strategy: "node-reboot", // ignored in favor of Strategies + Strategies: []types.MonitorRemediationConfig{ + {Strategy: "node-reboot"}, + {Strategy: "pod-delete"}, + }, + }, + } + pd, mon := buildDetectorWithRemediation(t, monCfg, exec) + + mon.AddStatusUpdate(unhealthyStatus("multi-monitor", "ServiceHealthy")) + if !pollUntil(t, time.Second, func() bool { return exec.CallCount() == 2 }) { + t.Fatalf("expected 2 ordered executor calls within 1s, got %d", exec.CallCount()) + } + + calls := exec.Calls() + if calls[0].RemediatorType != "node-reboot" { + t.Errorf("first strategy = %q, want node-reboot", calls[0].RemediatorType) + } + if calls[1].RemediatorType != "pod-delete" { + t.Errorf("second strategy = %q, want pod-delete", calls[1].RemediatorType) + } + + // All strategies failed → counts as a failed remediation, not triggered. + snap := pd.GetStatistics() + if snap.GetRemediationsFailed() != 1 { + t.Errorf("expected remediationsFailed=1, got %d", snap.GetRemediationsFailed()) + } + if snap.GetRemediationsTriggered() != 0 { + t.Errorf("expected remediationsTriggered=0, got %d", snap.GetRemediationsTriggered()) + } +} + +// TestEvaluateRemediation_MultiStrategyFirstSuccessWins verifies that when the +// first strategy succeeds, subsequent strategies are not attempted. +func TestEvaluateRemediation_MultiStrategyFirstSuccessWins(t *testing.T) { + exec := NewMockRemediationExecutor() // no error => first attempt succeeds + + monCfg := types.MonitorConfig{ + Name: "multi-monitor", + Type: "test", + Enabled: true, + Interval: 30 * time.Second, + Timeout: 10 * time.Second, + Remediation: &types.MonitorRemediationConfig{ + Enabled: true, + Strategy: "node-reboot", // required by validation; primary strategy + Strategies: []types.MonitorRemediationConfig{ + {Strategy: "node-reboot"}, + {Strategy: "pod-delete"}, + }, + }, + } + pd, mon := buildDetectorWithRemediation(t, monCfg, exec) + + mon.AddStatusUpdate(unhealthyStatus("multi-monitor", "ServiceHealthy")) + if !pollUntil(t, time.Second, func() bool { return exec.CallCount() == 1 }) { + t.Fatalf("expected exactly 1 executor call (first-success-wins) within 1s, got %d", exec.CallCount()) + } + // Give a brief moment to ensure no second call sneaks in. + time.Sleep(50 * time.Millisecond) + if exec.CallCount() != 1 { + t.Errorf("expected 1 call after first success, got %d", exec.CallCount()) + } + if exec.Calls()[0].RemediatorType != "node-reboot" { + t.Errorf("first strategy = %q, want node-reboot", exec.Calls()[0].RemediatorType) + } + + snap := pd.GetStatistics() + if snap.GetRemediationsTriggered() != 1 { + t.Errorf("expected remediationsTriggered=1, got %d", snap.GetRemediationsTriggered()) + } +} diff --git a/pkg/remediators/registry.go b/pkg/remediators/registry.go index 073bc95..3f637f0 100644 --- a/pkg/remediators/registry.go +++ b/pkg/remediators/registry.go @@ -655,6 +655,51 @@ func (r *RemediatorRegistry) Remediate(ctx context.Context, remediatorType strin return nil } +// RemediateWithStrategies attempts an ordered list of remediation strategy +// types, stopping at the first one that succeeds (first-success-wins). +// +// Each strategy type is dispatched through Remediate, so every per-strategy +// safety check (circuit breaker, rate limit, cooldown, max attempts, dry-run) +// is applied identically to single-strategy remediation. The problem's Type +// field is set to the strategy being attempted before each call so that +// history, logging, and problem keys reflect the actual strategy used. +// +// Behavior: +// - Returns nil as soon as any strategy succeeds. +// - If a strategy fails, the next strategy in the list is attempted. +// - If all strategies fail, the last error is returned (wrapped with the +// count of attempted strategies for context). +// - An empty strategyTypes list returns an error (nothing to do). +// +// A list with a single strategy behaves exactly like calling Remediate once, +// preserving backward compatibility for existing single-strategy configs. +func (r *RemediatorRegistry) RemediateWithStrategies(ctx context.Context, strategyTypes []string, problem types.Problem) error { + if len(strategyTypes) == 0 { + return fmt.Errorf("no remediation strategies provided") + } + + var lastErr error + for i, strategyType := range strategyTypes { + // Set the problem type to the strategy being attempted so history, + // logging, and problem keys reflect the actual strategy used. + attemptProblem := problem + attemptProblem.Type = strategyType + + err := r.Remediate(ctx, strategyType, attemptProblem) + if err == nil { + // First success wins - stop here. + return nil + } + + lastErr = err + r.logInfof("Remediation strategy %d/%d (%s) failed, trying next: %v", + i+1, len(strategyTypes), strategyType, err) + } + + return fmt.Errorf("all %d remediation strategies failed, last error: %w", + len(strategyTypes), lastErr) +} + // checkCircuitBreaker checks if the circuit breaker allows remediation. // This must be called with the lock held. func (r *RemediatorRegistry) checkCircuitBreaker() error { diff --git a/pkg/remediators/registry_test.go b/pkg/remediators/registry_test.go index 35988d0..f954fc4 100644 --- a/pkg/remediators/registry_test.go +++ b/pkg/remediators/registry_test.go @@ -1607,3 +1607,153 @@ func TestSetCircuitStateObserver(t *testing.T) { registry.ResetCircuitBreaker() }) } + +// TestRemediateWithStrategies tests ordered multi-strategy dispatch with +// first-success-wins semantics. +func TestRemediateWithStrategies(t *testing.T) { + t.Run("first fails then second succeeds - tries in order, stops at first success", func(t *testing.T) { + registry := NewRegistry(100, 100) + + failMock := newMockRemediator("strategy-a", true) // always fails (first 10 attempts) + successMock := newMockRemediator("strategy-b", false) // always succeeds + neverMock := newMockRemediator("strategy-c", false) // should never be reached + + registry.Register(RemediatorInfo{ + Type: "strategy-a", + Factory: func() (types.Remediator, error) { return failMock, nil }, + }) + registry.Register(RemediatorInfo{ + Type: "strategy-b", + Factory: func() (types.Remediator, error) { return successMock, nil }, + }) + registry.Register(RemediatorInfo{ + Type: "strategy-c", + Factory: func() (types.Remediator, error) { return neverMock, nil }, + }) + + problem := createTestProblem("test-type", "test-resource") + + err := registry.RemediateWithStrategies(context.Background(), + []string{"strategy-a", "strategy-b", "strategy-c"}, problem) + if err != nil { + t.Fatalf("expected success (second strategy), got error: %v", err) + } + + // First strategy should have been attempted (and failed). + if failMock.getCallCount() != 1 { + t.Errorf("strategy-a call count = %d, want 1", failMock.getCallCount()) + } + // Second strategy should have been attempted (and succeeded). + if successMock.getCallCount() != 1 { + t.Errorf("strategy-b call count = %d, want 1", successMock.getCallCount()) + } + // Third strategy must never be reached after a success. + if neverMock.getCallCount() != 0 { + t.Errorf("strategy-c call count = %d, want 0 (should short-circuit)", neverMock.getCallCount()) + } + + // History should show exactly two attempts in order: a (fail), b (success). + history := registry.GetHistory(10) + if len(history) != 2 { + t.Fatalf("history length = %d, want 2", len(history)) + } + if history[0].RemediatorType != "strategy-a" || history[0].Success { + t.Errorf("history[0] = {%s, success=%v}, want {strategy-a, success=false}", + history[0].RemediatorType, history[0].Success) + } + if history[1].RemediatorType != "strategy-b" || !history[1].Success { + t.Errorf("history[1] = {%s, success=%v}, want {strategy-b, success=true}", + history[1].RemediatorType, history[1].Success) + } + }) + + t.Run("single success short-circuits", func(t *testing.T) { + registry := NewRegistry(100, 100) + + successMock := newMockRemediator("only", false) + secondMock := newMockRemediator("second", false) + + registry.Register(RemediatorInfo{ + Type: "only", + Factory: func() (types.Remediator, error) { return successMock, nil }, + }) + registry.Register(RemediatorInfo{ + Type: "second", + Factory: func() (types.Remediator, error) { return secondMock, nil }, + }) + + problem := createTestProblem("test-type", "test-resource") + + err := registry.RemediateWithStrategies(context.Background(), + []string{"only", "second"}, problem) + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if successMock.getCallCount() != 1 { + t.Errorf("first strategy call count = %d, want 1", successMock.getCallCount()) + } + if secondMock.getCallCount() != 0 { + t.Errorf("second strategy call count = %d, want 0 (short-circuit on first success)", + secondMock.getCallCount()) + } + }) + + t.Run("all strategies fail returns error", func(t *testing.T) { + registry := NewRegistry(100, 100) + + failA := newMockRemediator("fail-a", true) + failB := newMockRemediator("fail-b", true) + + registry.Register(RemediatorInfo{ + Type: "fail-a", + Factory: func() (types.Remediator, error) { return failA, nil }, + }) + registry.Register(RemediatorInfo{ + Type: "fail-b", + Factory: func() (types.Remediator, error) { return failB, nil }, + }) + + problem := createTestProblem("test-type", "test-resource") + + err := registry.RemediateWithStrategies(context.Background(), + []string{"fail-a", "fail-b"}, problem) + if err == nil { + t.Fatal("expected error when all strategies fail, got nil") + } + // Both strategies should have been attempted. + if failA.getCallCount() != 1 { + t.Errorf("fail-a call count = %d, want 1", failA.getCallCount()) + } + if failB.getCallCount() != 1 { + t.Errorf("fail-b call count = %d, want 1", failB.getCallCount()) + } + }) + + t.Run("empty strategy list returns error", func(t *testing.T) { + registry := NewRegistry(100, 100) + problem := createTestProblem("test-type", "test-resource") + + err := registry.RemediateWithStrategies(context.Background(), nil, problem) + if err == nil { + t.Fatal("expected error for empty strategy list, got nil") + } + }) + + t.Run("single strategy behaves like Remediate (backward compat)", func(t *testing.T) { + registry := NewRegistry(100, 100) + mock := newMockRemediator("single", false) + registry.Register(RemediatorInfo{ + Type: "single", + Factory: func() (types.Remediator, error) { return mock, nil }, + }) + + problem := createTestProblem("test-type", "test-resource") + if err := registry.RemediateWithStrategies(context.Background(), + []string{"single"}, problem); err != nil { + t.Fatalf("single-strategy dispatch failed: %v", err) + } + if mock.getCallCount() != 1 { + t.Errorf("call count = %d, want 1", mock.getCallCount()) + } + }) +} From 6e376e9dc141a27125e51071e144770e20988234 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:09:51 -0500 Subject: [PATCH 02/11] feat(controller): implement correlator pattern injection + tests (Task #15215) Replace the InjectProblemPattern no-op stub: store injected patterns (mutex-guarded, name-keyed/idempotent) and feed them into detectCommonCauseCorrelation alongside built-ins (deduped by name) via the same findNodesWithAllProblems/MinNodesForCorrelation/confidence logic. Signature now returns error (no existing callers): empty name/ problems rejected. Add correlator_test.go (first tests for the package): detection thresholds, confidence, GetStats, injection validation, and injected-pattern participation. Race-clean. Note: task referenced GetCorrelationSummary/AddPattern which don't exist; real equivalents are GetStats/GetActiveCorrelations. --- pkg/controller/correlator.go | 106 +++++++-- pkg/controller/correlator_test.go | 369 ++++++++++++++++++++++++++++++ 2 files changed, 455 insertions(+), 20 deletions(-) create mode 100644 pkg/controller/correlator_test.go diff --git a/pkg/controller/correlator.go b/pkg/controller/correlator.go index 3a2a500..1d98f00 100644 --- a/pkg/controller/correlator.go +++ b/pkg/controller/correlator.go @@ -35,6 +35,11 @@ type Correlator struct { mu sync.RWMutex active map[string]*Correlation // id -> correlation + // injectedPatterns holds custom common-cause patterns registered at runtime + // via InjectProblemPattern. They participate in common-cause detection + // alongside the built-in patterns. Guarded by mu. + injectedPatterns []injectedPattern + // Tracking for detection nodeReports map[string]*NodeReport // nodeName -> latest report totalNodes int @@ -46,6 +51,16 @@ type Correlator struct { wg sync.WaitGroup } +// injectedPattern is a custom common-cause pattern registered at runtime via +// InjectProblemPattern. It mirrors the shape of the built-in common-cause +// patterns: a set of problem types that, when all present on a node, indicate +// a shared root cause. +type injectedPattern struct { + problems []string + name string + description string +} + // NewCorrelator creates a new Correlator instance. func NewCorrelator(config *CorrelationConfig, storage Storage, metrics *ControllerMetrics, events *EventRecorder) *Correlator { if config == nil { @@ -58,13 +73,14 @@ func NewCorrelator(config *CorrelationConfig, storage Storage, metrics *Controll } return &Correlator{ - config: config, - storage: storage, - metrics: metrics, - events: events, - active: make(map[string]*Correlation), - nodeReports: make(map[string]*NodeReport), - stopCh: make(chan struct{}), + config: config, + storage: storage, + metrics: metrics, + events: events, + active: make(map[string]*Correlation), + nodeReports: make(map[string]*NodeReport), + injectedPatterns: make([]injectedPattern, 0), + stopCh: make(chan struct{}), } } @@ -384,11 +400,7 @@ func (c *Correlator) detectInfrastructureCorrelation(reports []*NodeReport, tota // Example: Memory pressure + Disk pressure on same nodes = resource exhaustion. func (c *Correlator) detectCommonCauseCorrelation(reports []*NodeReport) []*Correlation { // Known common-cause patterns - commonCausePatterns := []struct { - problems []string - name string - description string - }{ + commonCausePatterns := []injectedPattern{ { problems: []string{"MemoryPressure", "DiskPressure"}, name: "resource-exhaustion", @@ -406,6 +418,21 @@ func (c *Correlator) detectCommonCauseCorrelation(reports []*NodeReport) []*Corr }, } + // Append runtime-injected patterns, deduping by name so an injected pattern + // cannot shadow or double-detect a built-in with the same name. + c.mu.RLock() + builtinNames := make(map[string]bool, len(commonCausePatterns)) + for _, p := range commonCausePatterns { + builtinNames[p.name] = true + } + for _, p := range c.injectedPatterns { + if builtinNames[p.name] { + continue + } + commonCausePatterns = append(commonCausePatterns, p) + } + c.mu.RUnlock() + var correlations []*Correlation for _, pattern := range commonCausePatterns { @@ -662,11 +689,53 @@ func (c *Correlator) EvaluateNow(ctx context.Context) { c.evaluate(ctx) } -// InjectProblemPattern allows adding custom problem patterns for detection. -// This is useful for extending the correlator with domain-specific patterns. -func (c *Correlator) InjectProblemPattern(patternType string, problems []string, name, description string) { - // This is a hook for future extensibility - log.Printf("[DEBUG] Pattern injection not yet implemented: %s", name) +// InjectProblemPattern registers a custom common-cause problem pattern for +// detection. Injected patterns participate in common-cause correlation +// detection alongside the built-in patterns: a node that has all of the +// pattern's problem types active is considered a match, and when at least +// MinNodesForCorrelation nodes match, a correlation is produced. +// +// The patternType argument is accepted for forward-compatibility and +// callers should pass CorrelationTypeCommonCause; only common-cause patterns +// are currently supported. A non-empty name and at least one problem type are +// required. Re-injecting a pattern with an existing name replaces the previous +// definition. An error is returned (and nothing is stored) when validation +// fails so callers can surface the problem. +func (c *Correlator) InjectProblemPattern(patternType string, problems []string, name, description string) error { + if strings.TrimSpace(name) == "" { + log.Printf("[WARN] InjectProblemPattern: ignoring pattern with empty name") + return fmt.Errorf("pattern name must not be empty") + } + + if len(problems) == 0 { + log.Printf("[WARN] InjectProblemPattern: ignoring pattern %q with no problem types", name) + return fmt.Errorf("pattern %q must define at least one problem type", name) + } + + // Copy the problem slice so the caller cannot mutate it after injection. + problemsCopy := append([]string{}, problems...) + + c.mu.Lock() + defer c.mu.Unlock() + + // Replace an existing injected pattern with the same name (idempotent + // re-injection) rather than appending a duplicate. + for i := range c.injectedPatterns { + if c.injectedPatterns[i].name == name { + c.injectedPatterns[i] = injectedPattern{problems: problemsCopy, name: name, description: description} + log.Printf("[INFO] InjectProblemPattern: updated pattern %q (problems=%v)", name, problemsCopy) + return nil + } + } + + c.injectedPatterns = append(c.injectedPatterns, injectedPattern{ + problems: problemsCopy, + name: name, + description: description, + }) + + log.Printf("[INFO] InjectProblemPattern: registered pattern %q (type=%s, problems=%v)", name, patternType, problemsCopy) + return nil } // ForceResolve forces a correlation to be resolved (for manual intervention). @@ -698,9 +767,6 @@ func (c *Correlator) ForceResolve(ctx context.Context, correlationID string) err return nil } -// Ensure unused import doesn't cause issues -var _ = strings.TrimSpace - // SetStorage wires a storage backend into the correlator. It must be called // before Start() so that the initial load of active correlations succeeds. // Calling SetStorage after Start() is safe but the load-on-start path will diff --git a/pkg/controller/correlator_test.go b/pkg/controller/correlator_test.go new file mode 100644 index 0000000..a837578 --- /dev/null +++ b/pkg/controller/correlator_test.go @@ -0,0 +1,369 @@ +package controller + +import ( + "context" + "testing" + "time" +) + +// newTestCorrelator returns a Correlator with no storage/metrics/events wired, +// which is sufficient for exercising the in-memory detection and injection +// paths. minNodes controls MinNodesForCorrelation. +func newTestCorrelator(minNodes int) *Correlator { + cfg := &CorrelationConfig{ + Enabled: true, + ClusterWideThreshold: 0.3, + EvaluationInterval: 30 * time.Second, + MinNodesForCorrelation: minNodes, + } + return NewCorrelator(cfg, nil, nil, nil) +} + +// reportWithProblems builds a NodeReport for nodeName whose ActiveProblems are +// the supplied problem types. +func reportWithProblems(nodeName string, problemTypes ...string) *NodeReport { + problems := make([]ProblemSummary, 0, len(problemTypes)) + for _, pt := range problemTypes { + problems = append(problems, ProblemSummary{ + Type: pt, + Severity: "warning", + Message: pt + " active", + Source: "test", + DetectedAt: time.Now(), + LastSeenAt: time.Now(), + }) + } + return &NodeReport{ + NodeName: nodeName, + Timestamp: time.Now(), + OverallHealth: HealthStatusDegraded, + ActiveProblems: problems, + } +} + +// findCorrelationByPattern returns the active correlation whose metadata +// "pattern" key matches name, or nil. +func findCorrelationByPattern(corrs []*Correlation, name string) *Correlation { + for _, c := range corrs { + if p, ok := c.Metadata["pattern"]; ok && p == name { + return c + } + } + return nil +} + +// TestDetectCommonCauseCorrelation exercises the common-cause detection entry +// point for the built-in "resource-exhaustion" pattern (MemoryPressure + +// DiskPressure), both above and below the node threshold. +func TestDetectCommonCauseCorrelation(t *testing.T) { + tests := []struct { + name string + minNodes int + reports []*NodeReport + wantDetect bool + wantNodes []string + wantProblem []string + }{ + { + name: "at threshold detects", + minNodes: 2, + reports: []*NodeReport{ + reportWithProblems("node-a", "MemoryPressure", "DiskPressure"), + reportWithProblems("node-b", "MemoryPressure", "DiskPressure"), + }, + wantDetect: true, + wantNodes: []string{"node-a", "node-b"}, + wantProblem: []string{"MemoryPressure", "DiskPressure"}, + }, + { + name: "below threshold does not detect", + minNodes: 2, + reports: []*NodeReport{ + reportWithProblems("node-a", "MemoryPressure", "DiskPressure"), + reportWithProblems("node-b", "MemoryPressure"), // missing DiskPressure + }, + wantDetect: false, + }, + { + name: "partial problem set on a node is not a match", + minNodes: 1, + reports: []*NodeReport{ + reportWithProblems("node-a", "MemoryPressure"), // only one of two + }, + wantDetect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := newTestCorrelator(tt.minNodes) + for _, r := range tt.reports { + c.UpdateNodeReport(r) + } + + c.EvaluateNow(context.Background()) + + active := c.GetActiveCorrelations() + corr := findCorrelationByPattern(active, "resource-exhaustion") + + if tt.wantDetect { + if corr == nil { + t.Fatalf("expected resource-exhaustion correlation, got none (active=%d)", len(active)) + } + if corr.Type != CorrelationTypeCommonCause { + t.Errorf("type = %q, want %q", corr.Type, CorrelationTypeCommonCause) + } + if len(corr.AffectedNodes) != len(tt.wantNodes) { + t.Errorf("affected nodes = %v, want %v", corr.AffectedNodes, tt.wantNodes) + } + if len(corr.ProblemTypes) != len(tt.wantProblem) { + t.Errorf("problem types = %v, want %v", corr.ProblemTypes, tt.wantProblem) + } + if corr.Status != CorrelationStatusActive { + t.Errorf("status = %q, want %q", corr.Status, CorrelationStatusActive) + } + } else if corr != nil { + t.Fatalf("expected no resource-exhaustion correlation, got %+v", corr) + } + }) + } +} + +// TestDetectInfrastructureConfidence verifies confidence and affected-node +// calculation for the infrastructure detection path. Two of two nodes share a +// problem type => ratio 1.0 => critical severity, confidence 1.0. +func TestDetectInfrastructureConfidence(t *testing.T) { + c := newTestCorrelator(2) + c.UpdateNodeReport(reportWithProblems("node-a", "DNSFailure")) + c.UpdateNodeReport(reportWithProblems("node-b", "DNSFailure")) + + c.EvaluateNow(context.Background()) + + active := c.GetActiveCorrelations() + var infra *Correlation + for _, corr := range active { + if corr.Type == CorrelationTypeInfrastructure { + infra = corr + break + } + } + if infra == nil { + t.Fatalf("expected an infrastructure correlation, got none (active=%d)", len(active)) + } + + if infra.Confidence != 1.0 { + t.Errorf("confidence = %v, want 1.0", infra.Confidence) + } + if infra.Severity != "critical" { + t.Errorf("severity = %q, want critical (ratio >= 0.5)", infra.Severity) + } + if len(infra.AffectedNodes) != 2 { + t.Errorf("affected nodes = %v, want 2", infra.AffectedNodes) + } + if r, ok := infra.Metadata["ratio"].(float64); !ok || r != 1.0 { + t.Errorf("metadata ratio = %v, want 1.0", infra.Metadata["ratio"]) + } +} + +// TestGetStatsAfterDetection verifies the summary-shaped accessor reflects the +// post-detection state. +func TestGetStatsAfterDetection(t *testing.T) { + c := newTestCorrelator(2) + c.UpdateNodeReport(reportWithProblems("node-a", "MemoryPressure", "DiskPressure")) + c.UpdateNodeReport(reportWithProblems("node-b", "MemoryPressure", "DiskPressure")) + + // Before evaluation: no active correlations, 2 tracked nodes. + pre := c.GetStats() + if pre.TrackedNodes != 2 { + t.Errorf("pre TrackedNodes = %d, want 2", pre.TrackedNodes) + } + if pre.ActiveCorrelations != 0 { + t.Errorf("pre ActiveCorrelations = %d, want 0", pre.ActiveCorrelations) + } + + c.EvaluateNow(context.Background()) + + post := c.GetStats() + if post.ActiveCorrelations < 1 { + t.Errorf("post ActiveCorrelations = %d, want >= 1", post.ActiveCorrelations) + } + if post.TrackedNodes != 2 { + t.Errorf("post TrackedNodes = %d, want 2", post.TrackedNodes) + } + if post.LastEvalTime.IsZero() { + t.Errorf("post LastEvalTime should be set after evaluation") + } +} + +// TestInjectProblemPatternValidation covers the validation behavior of the +// injection API. +func TestInjectProblemPatternValidation(t *testing.T) { + tests := []struct { + name string + patternName string + problems []string + wantErr bool + }{ + { + name: "empty name is rejected", + patternName: "", + problems: []string{"A", "B"}, + wantErr: true, + }, + { + name: "whitespace name is rejected", + patternName: " ", + problems: []string{"A", "B"}, + wantErr: true, + }, + { + name: "empty problems is rejected", + patternName: "custom", + problems: nil, + wantErr: true, + }, + { + name: "valid pattern is accepted", + patternName: "custom", + problems: []string{"A", "B"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := newTestCorrelator(2) + err := c.InjectProblemPattern(CorrelationTypeCommonCause, tt.problems, tt.patternName, "desc") + if tt.wantErr && err == nil { + t.Fatalf("expected error, got nil") + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + c.mu.RLock() + stored := len(c.injectedPatterns) + c.mu.RUnlock() + if tt.wantErr && stored != 0 { + t.Errorf("rejected pattern should not be stored, found %d", stored) + } + if !tt.wantErr && stored != 1 { + t.Errorf("valid pattern should be stored once, found %d", stored) + } + }) + } +} + +// TestInjectProblemPatternIdempotentReinjection verifies re-injecting under the +// same name replaces rather than duplicates. +func TestInjectProblemPatternIdempotentReinjection(t *testing.T) { + c := newTestCorrelator(2) + + if err := c.InjectProblemPattern(CorrelationTypeCommonCause, []string{"A", "B"}, "custom", "v1"); err != nil { + t.Fatalf("first inject: %v", err) + } + if err := c.InjectProblemPattern(CorrelationTypeCommonCause, []string{"A", "C"}, "custom", "v2"); err != nil { + t.Fatalf("second inject: %v", err) + } + + c.mu.RLock() + defer c.mu.RUnlock() + if len(c.injectedPatterns) != 1 { + t.Fatalf("expected 1 stored pattern after re-injection, got %d", len(c.injectedPatterns)) + } + got := c.injectedPatterns[0] + if got.description != "v2" || len(got.problems) != 2 || got.problems[1] != "C" { + t.Errorf("re-injection did not replace pattern: %+v", got) + } +} + +// TestInjectedPatternParticipatesInDetection is the core end-to-end check: +// an injected custom pattern must be detected when enough nodes match. +func TestInjectedPatternParticipatesInDetection(t *testing.T) { + c := newTestCorrelator(2) + + customProblems := []string{"GPUOverheat", "ThermalThrottle"} + if err := c.InjectProblemPattern(CorrelationTypeCommonCause, customProblems, "gpu-thermal", "GPU thermal event"); err != nil { + t.Fatalf("inject: %v", err) + } + + // Two nodes share both injected problem types => should match (>= minNodes). + c.UpdateNodeReport(reportWithProblems("node-a", "GPUOverheat", "ThermalThrottle")) + c.UpdateNodeReport(reportWithProblems("node-b", "GPUOverheat", "ThermalThrottle")) + // A third node has only one of the two => must not be counted as a match. + c.UpdateNodeReport(reportWithProblems("node-c", "GPUOverheat")) + + c.EvaluateNow(context.Background()) + + active := c.GetActiveCorrelations() + corr := findCorrelationByPattern(active, "gpu-thermal") + if corr == nil { + t.Fatalf("expected injected gpu-thermal correlation, got none (active=%d)", len(active)) + } + if corr.Type != CorrelationTypeCommonCause { + t.Errorf("type = %q, want %q", corr.Type, CorrelationTypeCommonCause) + } + if len(corr.AffectedNodes) != 2 { + t.Errorf("affected nodes = %v, want exactly [node-a node-b]", corr.AffectedNodes) + } + for _, n := range corr.AffectedNodes { + if n == "node-c" { + t.Errorf("node-c (partial match) should not be in affected nodes: %v", corr.AffectedNodes) + } + } + if len(corr.ProblemTypes) != 2 { + t.Errorf("problem types = %v, want 2", corr.ProblemTypes) + } + // Confidence = matching nodes / total reports = 2/3. + if corr.Confidence <= 0 || corr.Confidence > 1.0 { + t.Errorf("confidence = %v, want in (0,1]", corr.Confidence) + } +} + +// TestInjectedPatternBelowThresholdNotDetected verifies an injected pattern +// matched by too few nodes does not produce a correlation. +func TestInjectedPatternBelowThresholdNotDetected(t *testing.T) { + c := newTestCorrelator(3) // require 3 nodes + + if err := c.InjectProblemPattern(CorrelationTypeCommonCause, []string{"X", "Y"}, "xy-pattern", "desc"); err != nil { + t.Fatalf("inject: %v", err) + } + + c.UpdateNodeReport(reportWithProblems("node-a", "X", "Y")) + c.UpdateNodeReport(reportWithProblems("node-b", "X", "Y")) + // Only 2 match, threshold is 3. + + c.EvaluateNow(context.Background()) + + if corr := findCorrelationByPattern(c.GetActiveCorrelations(), "xy-pattern"); corr != nil { + t.Fatalf("expected no xy-pattern correlation below threshold, got %+v", corr) + } +} + +// TestInjectedPatternDoesNotShadowBuiltin ensures an injected pattern reusing a +// built-in name is deduped (the built-in remains, no duplicate detection). +func TestInjectedPatternDoesNotShadowBuiltin(t *testing.T) { + c := newTestCorrelator(2) + + // Inject using the built-in name "resource-exhaustion" but with different + // problems. The dedupe-by-name logic should drop the injected copy. + if err := c.InjectProblemPattern(CorrelationTypeCommonCause, []string{"Bogus1", "Bogus2"}, "resource-exhaustion", "shadow"); err != nil { + t.Fatalf("inject: %v", err) + } + + // Feed the built-in problem set. + c.UpdateNodeReport(reportWithProblems("node-a", "MemoryPressure", "DiskPressure")) + c.UpdateNodeReport(reportWithProblems("node-b", "MemoryPressure", "DiskPressure")) + + c.EvaluateNow(context.Background()) + + count := 0 + for _, corr := range c.GetActiveCorrelations() { + if p, ok := corr.Metadata["pattern"]; ok && p == "resource-exhaustion" { + count++ + } + } + if count != 1 { + t.Fatalf("expected exactly 1 resource-exhaustion correlation (no shadowing), got %d", count) + } +} From 4881e204067e922608de9f951c8f07642b17ab29 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:20:55 -0500 Subject: [PATCH 03/11] feat(remediators): enforce MaxRemediationsPerMinute token bucket (Task #15216) MaxRemediationsPerMinute was parsed/defaulted(2)/validated but never enforced. Add a perMinuteBucket *rate.Limiter (golang.org/x/time/rate) on RemediatorRegistry via SetMaxRemediationsPerMinute (matching the Set* idiom; 0 = unlimited), wired in main from config. Enforced in Remediate AFTER the non-consuming per-hour window check so a per-hour rejection never burns a per-minute token and a rejected call isn't counted as executed. Updated the stale config comment. Tests: N succeed / N+1 rejected within a minute, and unlimited at 0. go mod tidy promoted x/time (+modernc/sqlite) to direct, dropped unused gogo/protobuf. --- cmd/node-doctor/main.go | 7 ++- go.mod | 5 +- go.sum | 71 ++++++++++-------------- pkg/remediators/registry.go | 68 ++++++++++++++++++++++- pkg/remediators/registry_test.go | 92 ++++++++++++++++++++++++++++++++ pkg/types/config.go | 8 +-- 6 files changed, 198 insertions(+), 53 deletions(-) diff --git a/cmd/node-doctor/main.go b/cmd/node-doctor/main.go index 3c9d2d7..5da9efa 100644 --- a/cmd/node-doctor/main.go +++ b/cmd/node-doctor/main.go @@ -202,8 +202,11 @@ func main() { } remediatorRegistry = remediators.NewRegistry(maxPerHour, historySize) remediatorRegistry.SetDryRun(config.Remediation.DryRun || config.Settings.DryRunMode) - log.Printf("[INFO] Remediator registry initialized (dry-run=%v, maxPerHour=%d)", - remediatorRegistry.IsDryRun(), maxPerHour) + // Wire the per-minute token-bucket burst limit. A value of 0 leaves the + // per-minute check disabled (only the per-hour window applies). + remediatorRegistry.SetMaxRemediationsPerMinute(config.Remediation.MaxRemediationsPerMinute) + log.Printf("[INFO] Remediator registry initialized (dry-run=%v, maxPerHour=%d, maxPerMinute=%d)", + remediatorRegistry.IsDryRun(), maxPerHour, config.Remediation.MaxRemediationsPerMinute) // Wire the controller lease client when coordination is opted in. // The registry's Remediate path checks for a non-nil lease client and diff --git a/go.mod b/go.mod index f6a7423..96b9ef1 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,13 @@ require ( github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_model v0.6.2 golang.org/x/net v0.47.0 + golang.org/x/time v0.14.0 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.35.0 k8s.io/apimachinery v0.35.0 k8s.io/client-go v0.35.0 + modernc.org/sqlite v1.42.1 ) require ( @@ -36,7 +38,6 @@ require ( github.com/go-openapi/swag/stringutils v0.25.3 // indirect github.com/go-openapi/swag/typeutils v0.25.3 // indirect github.com/go-openapi/swag/yamlutils v0.25.3 // indirect - github.com/gogo/protobuf v1.3.2 // indirect github.com/google/gnostic-models v0.7.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -57,7 +58,6 @@ require ( golang.org/x/sys v0.38.0 // indirect golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect - golang.org/x/time v0.14.0 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect @@ -67,7 +67,6 @@ require ( modernc.org/libc v1.66.10 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.42.1 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.1 // indirect diff --git a/go.sum b/go.sum index 52ab942..e7099e0 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= +github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -51,23 +53,17 @@ github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6 github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= -github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/google/gnostic-models v0.7.1 h1:SisTfuFKJSKM5CPZkffwi6coztzzeYUhc3v4yxLWH8c= github.com/google/gnostic-models v0.7.1/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= -github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= +github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= -github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -88,12 +84,10 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/onsi/ginkgo/v2 v2.21.0 h1:7rg/4f3rB88pb5obDgNZrNHrQ4e6WpjonchcpuBRnZM= -github.com/onsi/ginkgo/v2 v2.21.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= github.com/onsi/ginkgo/v2 v2.27.2 h1:LzwLj0b89qtIy6SSASkzlNvX6WktqurSHwkk2ipF/Ns= -github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4= -github.com/onsi/gomega v1.35.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog= +github.com/onsi/ginkgo/v2 v2.27.2/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A= +github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= @@ -118,56 +112,33 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= +golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -181,16 +152,10 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.34.2 h1:fsSUNZhV+bnL6Aqrp6O7lMTy6o5x2C4XLjnh//8SLYY= -k8s.io/api v0.34.2/go.mod h1:MMBPaWlED2a8w4RSeanD76f7opUoypY8TFYkSM+3XHw= k8s.io/api v0.35.0 h1:iBAU5LTyBI9vw3L5glmat1njFK34srdLmktWwLTprlY= k8s.io/api v0.35.0/go.mod h1:AQ0SNTzm4ZAczM03QH42c7l3bih1TbAXYo0DkF8ktnA= -k8s.io/apimachinery v0.34.2 h1:zQ12Uk3eMHPxrsbUJgNF8bTauTVR2WgqJsTmwTE/NW4= -k8s.io/apimachinery v0.34.2/go.mod h1:/GwIlEcWuTX9zKIg2mbw0LRFIsXwrfoVxn+ef0X13lw= k8s.io/apimachinery v0.35.0 h1:Z2L3IHvPVv/MJ7xRxHEtk6GoJElaAqDCCU0S6ncYok8= k8s.io/apimachinery v0.35.0/go.mod h1:jQCgFZFR1F4Ik7hvr2g84RTJSZegBc8yHgFWKn//hns= -k8s.io/client-go v0.34.2 h1:Co6XiknN+uUZqiddlfAjT68184/37PS4QAzYvQvDR8M= -k8s.io/client-go v0.34.2/go.mod h1:2VYDl1XXJsdcAxw7BenFslRQX28Dxz91U9MWKjX97fE= k8s.io/client-go v0.35.0 h1:IAW0ifFbfQQwQmga0UdoH0yvdqrbwMdq9vIFEhRpxBE= k8s.io/client-go v0.35.0/go.mod h1:q2E5AAyqcbeLGPdoRB+Nxe3KYTfPce1Dnu1myQdqz9o= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= @@ -199,14 +164,32 @@ k8s.io/kube-openapi v0.0.0-20251121143641-b6aabc6c6745 h1:c3rI/4s8ibM4vV5UOIlbgk k8s.io/kube-openapi v0.0.0-20251121143641-b6aabc6c6745/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck= k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4= +modernc.org/cc/v4 v4.26.5/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.28.1 h1:wPKYn5EC/mYTqBO373jKjvX2n+3+aK7+sICCv4Fjy1A= +modernc.org/ccgo/v4 v4.28.1/go.mod h1:uD+4RnfrVgE6ec9NGguUNdhqzNIeeomeXf6CL0GTE5Q= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A= modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sqlite v1.42.1 h1:Uq9MgEygn10NFglbbQUhp7yVyRvvoB2tCdK4hxhVfrI= modernc.org/sqlite v1.42.1/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= diff --git a/pkg/remediators/registry.go b/pkg/remediators/registry.go index 3f637f0..d096411 100644 --- a/pkg/remediators/registry.go +++ b/pkg/remediators/registry.go @@ -33,6 +33,7 @@ import ( "sync" "time" + "golang.org/x/time/rate" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -212,6 +213,14 @@ type RemediatorRegistry struct { maxPerHour int // max remediations per hour rateLimitWindow time.Duration + // Per-minute token bucket (optional). When maxPerMinute > 0, perMinuteBucket + // is a token-bucket limiter (burst == maxPerMinute, refill == maxPerMinute per + // minute) that caps the burst rate of remediations within any one minute. When + // maxPerMinute <= 0 the bucket is nil and the per-minute check is skipped + // (unlimited), mirroring how maxPerHour == 0 disables the per-hour window. + perMinuteBucket *rate.Limiter + maxPerMinute int + // History tracking history []RemediationRecord maxHistory int @@ -335,6 +344,30 @@ func (r *RemediatorRegistry) SetLeaseClient(leaseClient *LeaseClient) { r.logInfof("Lease client configured for controller coordination") } +// SetMaxRemediationsPerMinute configures the per-minute token-bucket rate limit. +// +// When n > 0, a rate.Limiter is built with burst n and a refill of n tokens per +// minute (rate.Every(time.Minute / n)). The limiter starts full, so up to n +// remediations may proceed immediately within a minute before further attempts +// are rejected until tokens refill. When n <= 0 the bucket is cleared and the +// per-minute check is skipped entirely (unlimited), matching how maxPerHour == 0 +// disables the per-hour window. +// +// This is the wiring counterpart to config.RemediationConfig.MaxRemediationsPerMinute. +func (r *RemediatorRegistry) SetMaxRemediationsPerMinute(n int) { + r.mu.Lock() + defer r.mu.Unlock() + + if n > 0 { + r.maxPerMinute = n + r.perMinuteBucket = rate.NewLimiter(rate.Every(time.Minute/time.Duration(n)), n) + r.logInfof("Per-minute remediation rate limit configured (max: %d/min)", n) + } else { + r.maxPerMinute = 0 + r.perMinuteBucket = nil + } +} + // SetCircuitStateObserver registers an observer that is notified of circuit // breaker state changes. The observer is called once immediately with the // current state (so a backing metric is correct from the start) and then on @@ -559,13 +592,26 @@ func (r *RemediatorRegistry) Remediate(ctx context.Context, remediatorType strin } r.mu.Unlock() - // Phase 2: Rate limit check + // Phase 2: Rate limit check. + // + // The per-hour sliding window is checked first because it is non-consuming + // (read-only). The per-minute token bucket is checked last and consumes a + // token only when the remediation is otherwise cleared to proceed, so a + // per-hour rejection never burns a per-minute token. A rejection here returns + // before any execution and before recordRateLimitEntry, so a rejected + // remediation is never counted against the per-hour window nor reported as + // executed (history records it as a failed attempt, success=false). r.mu.Lock() if err := r.checkRateLimit(); err != nil { r.mu.Unlock() remediationErr = err return err } + if err := r.checkPerMinuteRate(); err != nil { + r.mu.Unlock() + remediationErr = err + return err + } r.mu.Unlock() // Phase 2.5: Controller lease check (if coordination enabled) @@ -763,6 +809,26 @@ func (r *RemediatorRegistry) checkRateLimit() error { return nil } +// checkPerMinuteRate checks if the per-minute token bucket allows remediation. +// This must be called with the lock held. +// +// It consumes one token from the bucket on success, so it must only be called +// once the remediation is otherwise cleared to proceed (i.e. after the +// non-consuming per-hour window check passes). When no bucket is configured +// (maxPerMinute <= 0) the check is a no-op (unlimited). +func (r *RemediatorRegistry) checkPerMinuteRate() error { + if r.perMinuteBucket == nil { + return nil // Per-minute rate limiting disabled + } + + if !r.perMinuteBucket.Allow() { + return fmt.Errorf("rate limit exceeded: more than %d remediations within a minute (max: %d/min)", + r.maxPerMinute, r.maxPerMinute) + } + + return nil +} + // recordRateLimitEntry records a successful remediation for rate limiting. func (r *RemediatorRegistry) recordRateLimitEntry() { if r.maxPerHour == 0 { diff --git a/pkg/remediators/registry_test.go b/pkg/remediators/registry_test.go index f954fc4..216a080 100644 --- a/pkg/remediators/registry_test.go +++ b/pkg/remediators/registry_test.go @@ -566,6 +566,98 @@ func TestRateLimit(t *testing.T) { }) } +// TestPerMinuteRateLimit tests the per-minute token-bucket rate limit wired via +// SetMaxRemediationsPerMinute (config.RemediationConfig.MaxRemediationsPerMinute). +func TestPerMinuteRateLimit(t *testing.T) { + t.Run("per-minute limit enforced: N succeed, N+1 rejected", func(t *testing.T) { + const n = 2 + // maxPerHour high enough that it does not interfere; per-minute is the gate. + registry := NewRegistry(1000, 100) + registry.SetMaxRemediationsPerMinute(n) + + mock := newMockRemediator("test", false) + registry.Register(RemediatorInfo{ + Type: "test", + Factory: func() (types.Remediator, error) { return mock, nil }, + }) + + problem := createTestProblem("test-type", "test-resource") + + // The limiter starts full (burst == n), so the first n immediate + // remediations pass. + for i := 0; i < n; i++ { + mock.ClearCooldown(problem) + mock.ResetAttempts(problem) + if err := registry.Remediate(context.Background(), "test", problem); err != nil { + t.Fatalf("Remediation %d failed: %v", i+1, err) + } + } + + // The (N+1)th within the same minute must be rejected with a rate-limit error. + mock.ClearCooldown(problem) + mock.ResetAttempts(problem) + err := registry.Remediate(context.Background(), "test", problem) + if err == nil { + t.Fatal("Expected per-minute rate limit error on (N+1)th remediation, got nil") + } + + // The rejected remediation must NOT be executed: the mock's remediate func + // is never invoked for it, so callCount stays at n. + if got := mock.getCallCount(); got != n { + t.Errorf("mock call count = %d, want %d (rejected remediation must not execute)", got, n) + } + + // And it must not be counted against the per-hour window (only n recorded). + stats := registry.GetStats() + if stats.RecentRemediations != n { + t.Errorf("RecentRemediations = %d, want %d (rejected remediation must not consume per-hour window)", + stats.RecentRemediations, n) + } + + // History records the rejected attempt as a failed (not executed) record. + // There should be n+1 records: n successes + 1 failed attempt. + history := registry.GetHistory(0) + if len(history) != n+1 { + t.Fatalf("history length = %d, want %d", len(history), n+1) + } + last := history[len(history)-1] + if last.Success { + t.Error("last history record should be a failed (rejected) attempt, got Success=true") + } + if last.Error == "" { + t.Error("last history record should carry the rate-limit error") + } + }) + + t.Run("per-minute limit 0 is unlimited", func(t *testing.T) { + registry := NewRegistry(1000, 100) + registry.SetMaxRemediationsPerMinute(0) // unlimited + + mock := newMockRemediator("test", false) + registry.Register(RemediatorInfo{ + Type: "test", + Factory: func() (types.Remediator, error) { return mock, nil }, + }) + + problem := createTestProblem("test-type", "test-resource") + + // Execute well more than any default per-minute allowance; none should be + // rejected by the per-minute check (per-hour is high too). + const attempts = 20 + for i := 0; i < attempts; i++ { + mock.ClearCooldown(problem) + mock.ResetAttempts(problem) + if err := registry.Remediate(context.Background(), "test", problem); err != nil { + t.Fatalf("Remediation %d unexpectedly failed with per-minute disabled: %v", i+1, err) + } + } + + if got := mock.getCallCount(); got != attempts { + t.Errorf("mock call count = %d, want %d", got, attempts) + } + }) +} + // TestHistory tests remediation history tracking. func TestHistory(t *testing.T) { t.Run("history records success and failure", func(t *testing.T) { diff --git a/pkg/types/config.go b/pkg/types/config.go index e31394a..471359b 100644 --- a/pkg/types/config.go +++ b/pkg/types/config.go @@ -405,9 +405,11 @@ type RemediationConfig struct { // Safety limits. MaxRemediationsPerHour int `json:"maxRemediationsPerHour,omitempty" yaml:"maxRemediationsPerHour,omitempty"` - // MaxRemediationsPerMinute is parsed but not enforced at runtime; only the - // per-hour bucket (maxPerHour) is wired to RemediatorRegistry. Reserved for - // a future fine-grained burst-limiting feature. + // MaxRemediationsPerMinute is enforced at runtime via the RemediatorRegistry's + // per-minute token bucket (wired through SetMaxRemediationsPerMinute). It caps + // the burst rate of remediations within any one minute, complementing the + // per-hour sliding window (maxPerHour). A value of 0 means unlimited (the + // per-minute check is skipped). MaxRemediationsPerMinute int `json:"maxRemediationsPerMinute,omitempty" yaml:"maxRemediationsPerMinute,omitempty"` // Cooldown configuration From 7e439cec674c44117d69a39464d0203d2a45ccf9 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:35:34 -0500 Subject: [PATCH 04/11] feat(logger): wire structured slog logging from config (Task #15217) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New pkg/logger.Init(cfg) builds a slog JSON/text handler at the configured level writing to stdout/stderr/file (consuming LogOutput/LogFile, which were validated but unused), sets slog default, and bridges the stdlib log package via an io.Writer that strips [LEVEL] prefixes -> slog levels — so all 374 existing log.Printf sites honor the configured format/destination with zero churn. Wired in main after config validation (warn-and-continue on error). Converted 7 detector hot paths to native structured slog with attributes. Updated stale config comment. Tests: JSON parseable, text, file output, prefix-bridge level mapping, level filtering. Also poll-for-stat fix to 3 pre-existing flaky remediation tests exposed by the slog timing change (production logic unchanged). --- cmd/node-doctor/main.go | 22 ++- pkg/detector/detector.go | 31 ++-- pkg/detector/remediation_test.go | 37 +++-- pkg/logger/logger.go | 190 ++++++++++++++++++++++++ pkg/logger/logger_test.go | 247 +++++++++++++++++++++++++++++++ pkg/types/config.go | 6 +- 6 files changed, 503 insertions(+), 30 deletions(-) create mode 100644 pkg/logger/logger.go create mode 100644 pkg/logger/logger_test.go diff --git a/cmd/node-doctor/main.go b/cmd/node-doctor/main.go index 5da9efa..079906b 100644 --- a/cmd/node-doctor/main.go +++ b/cmd/node-doctor/main.go @@ -18,6 +18,7 @@ import ( kubernetesexporter "github.com/supporttools/node-doctor/pkg/exporters/kubernetes" prometheusexporter "github.com/supporttools/node-doctor/pkg/exporters/prometheus" "github.com/supporttools/node-doctor/pkg/health" + "github.com/supporttools/node-doctor/pkg/logger" "github.com/supporttools/node-doctor/pkg/monitors" "github.com/supporttools/node-doctor/pkg/remediators" "github.com/supporttools/node-doctor/pkg/types" @@ -163,6 +164,15 @@ func main() { log.Fatalf("Configuration validation failed: %v", err) } + // Wire structured logging (slog) from the validated config. This installs the + // configured format/level/destination and bridges the standard log package so + // all subsequent log.Printf("[LEVEL] ...") calls below honor it. Errors here + // (e.g. an unopenable log file) are non-fatal: warn and continue on defaults + // rather than crash the daemon over logging setup. + if err := logger.Init(config); err != nil { + log.Printf("[WARN] Structured logging setup failed, continuing with default logging: %v", err) + } + if *validateConfig { log.Printf("[INFO] Configuration validation passed") return @@ -173,11 +183,15 @@ func main() { return } - // Setup basic logging (detailed logging setup would need more implementation) - log.Printf("[INFO] Node Doctor starting on node: %s", config.Settings.NodeName) - log.Printf("[INFO] Log level: %s, format: %s", config.Settings.LogLevel, config.Settings.LogFormat) + // Structured startup banner via slog (logging is wired above). + logger.L().Info("node doctor starting", + "node", config.Settings.NodeName, + "logLevel", config.Settings.LogLevel, + "logFormat", config.Settings.LogFormat, + "logOutput", config.Settings.LogOutput, + ) if config.Settings.DryRunMode { - log.Printf("[WARN] Running in DRY-RUN mode - no actual remediation will be performed") + logger.L().Warn("running in dry-run mode; no actual remediation will be performed") } // Create context for graceful shutdown diff --git a/pkg/detector/detector.go b/pkg/detector/detector.go index 9761c4e..16f926d 100644 --- a/pkg/detector/detector.go +++ b/pkg/detector/detector.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "log/slog" "reflect" "strings" "sync" @@ -450,7 +451,7 @@ func (pd *ProblemDetector) processStatuses() { // processStatus processes a single status update func (pd *ProblemDetector) processStatus(status *types.Status) { - log.Printf("[DEBUG] Processing status from %s", status.Source) + slog.Debug("processing status", "monitor", status.Source) // Update statistics pd.stats.IncrementStatusesReceived() @@ -459,7 +460,7 @@ func (pd *ProblemDetector) processStatus(status *types.Status) { // with a synthetic "blocked" status. This prevents exporters from seeing // misleading results from a monitor whose prerequisites are not satisfied. if blocked, blockedBy := pd.isMonitorBlocked(status.Source); blocked { - log.Printf("[INFO] Monitor %s is blocked: dependency %q is unhealthy; emitting blocked status", status.Source, blockedBy) + slog.Info("monitor blocked by unhealthy dependency", "monitor", status.Source, "blockedBy", blockedBy) status = synthBlockedStatus(status.Source, blockedBy) } @@ -476,7 +477,7 @@ func (pd *ProblemDetector) processStatus(status *types.Status) { // causing duplicate Kubernetes resources. See GitHub issue #7. for _, exporter := range pd.exporters { if err := exporter.ExportStatus(pd.ctx, status); err != nil { - log.Printf("[WARN] Failed to export status to exporter: %v", err) + slog.Warn("failed to export status", "monitor", status.Source, "error", err) pd.stats.IncrementExportsFailed() } else { pd.stats.IncrementExportsSucceeded() @@ -544,12 +545,14 @@ func (pd *ProblemDetector) evaluateRemediation(status *types.Status) { } if err := registry.RemediateWithStrategies(pd.ctx, strategyTypes, problem); err != nil { - log.Printf("[WARN] Remediation failed for %s/%s (strategies=%v): %v", - status.Source, cond.Type, strategyTypes, err) + slog.Warn("remediation failed", + "monitor", status.Source, "condition", cond.Type, + "strategies", strategyTypes, "error", err) pd.stats.IncrementRemediationsFailed() } else { - log.Printf("[INFO] Remediation triggered for %s/%s (strategies=%v, dry-run=%v)", - status.Source, cond.Type, strategyTypes, registry.IsDryRun()) + slog.Info("remediation triggered", + "monitor", status.Source, "condition", cond.Type, + "strategies", strategyTypes, "dryRun", registry.IsDryRun()) pd.stats.IncrementRemediationsTriggered() } } @@ -653,26 +656,28 @@ func (pd *ProblemDetector) handleConfigReload(ctx context.Context, newConfig *ty return nil } - log.Printf("[INFO] Applying configuration reload") + slog.Info("applying configuration reload") // Log summary of changes if !diff.HasChanges() { - log.Printf("[INFO] No configuration changes detected") + slog.Info("no configuration changes detected") pd.emitReloadEvent(types.EventInfo, "NoChanges", "Configuration reload completed with no changes") return nil } - log.Printf("[INFO] Config changes detected: %d monitors added, %d modified, %d removed", - len(diff.MonitorsAdded), len(diff.MonitorsModified), len(diff.MonitorsRemoved)) + slog.Info("config changes detected", + "added", len(diff.MonitorsAdded), + "modified", len(diff.MonitorsModified), + "removed", len(diff.MonitorsRemoved)) // Apply the reload if err := pd.applyConfigReload(ctx, newConfig, diff); err != nil { - log.Printf("[ERROR] Configuration reload failed: %v", err) + slog.Error("configuration reload failed", "error", err) pd.emitReloadEvent(types.EventError, "ReloadFailed", fmt.Sprintf("Configuration reload failed: %v", err)) return fmt.Errorf("configuration reload failed: %w", err) } - log.Printf("[INFO] Configuration reload completed successfully") + slog.Info("configuration reload completed successfully") pd.emitReloadEvent(types.EventInfo, "ReloadSuccess", "Configuration reload completed successfully") return nil diff --git a/pkg/detector/remediation_test.go b/pkg/detector/remediation_test.go index f118f98..3f609fb 100644 --- a/pkg/detector/remediation_test.go +++ b/pkg/detector/remediation_test.go @@ -188,9 +188,14 @@ func TestEvaluateRemediation_UnhealthyConditionTriggersRemediation(t *testing.T) if call.Problem.Resource != "KubeletHealthy" { t.Errorf("expected problem resource KubeletHealthy, got %q", call.Problem.Resource) } - snap := pd.GetStatistics() - if snap.GetRemediationsTriggered() != 1 { - t.Errorf("expected remediationsTriggered=1, got %d", snap.GetRemediationsTriggered()) + // The triggered-stat increment lands after the executor call returns, so poll + // for it rather than reading immediately after the CallCount poll. + if !pollUntil(t, time.Second, func() bool { + s := pd.GetStatistics() + return s.GetRemediationsTriggered() == 1 + }) { + s := pd.GetStatistics() + t.Errorf("expected remediationsTriggered=1, got %d", s.GetRemediationsTriggered()) } } @@ -234,9 +239,14 @@ func TestEvaluateRemediation_DryRunExecutorCalled(t *testing.T) { if !exec.IsDryRun() { t.Error("expected executor.IsDryRun() == true") } - snap := pd.GetStatistics() - if snap.GetRemediationsTriggered() != 1 { - t.Errorf("expected remediationsTriggered=1 for dry-run, got %d", snap.GetRemediationsTriggered()) + // The triggered-stat increment lands after the executor call returns, so poll + // for it rather than reading immediately after the CallCount poll. + if !pollUntil(t, time.Second, func() bool { + s := pd.GetStatistics() + return s.GetRemediationsTriggered() == 1 + }) { + s := pd.GetStatistics() + t.Errorf("expected remediationsTriggered=1 for dry-run, got %d", s.GetRemediationsTriggered()) } } @@ -374,12 +384,19 @@ func TestEvaluateRemediation_MultiStrategyDispatch(t *testing.T) { } // All strategies failed → counts as a failed remediation, not triggered. - snap := pd.GetStatistics() - if snap.GetRemediationsFailed() != 1 { + // The failed-stat increment happens in evaluateRemediation AFTER the executor + // calls complete, so poll for it rather than reading immediately after the + // CallCount poll (which would race the increment). + if !pollUntil(t, time.Second, func() bool { + snap := pd.GetStatistics() + return snap.GetRemediationsFailed() == 1 + }) { + snap := pd.GetStatistics() t.Errorf("expected remediationsFailed=1, got %d", snap.GetRemediationsFailed()) } - if snap.GetRemediationsTriggered() != 0 { - t.Errorf("expected remediationsTriggered=0, got %d", snap.GetRemediationsTriggered()) + snap := pd.GetStatistics() + if got := snap.GetRemediationsTriggered(); got != 0 { + t.Errorf("expected remediationsTriggered=0, got %d", got) } } diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 0000000..50485d0 --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,190 @@ +// Package logger wires structured logging (stdlib log/slog) from the validated +// Node Doctor configuration at startup. +// +// It serves two roles: +// +// 1. It configures slog's default logger from cfg.Settings (LogLevel, +// LogFormat, LogOutput, LogFile) so code that wants real structured logging +// can call logger.L() (or slog.Default()) and get key/value attributes. +// +// 2. It bridges the standard library `log` package onto slog so the large body +// of existing `log.Printf("[LEVEL] ...")` call sites flow through the same +// handler — honoring the configured format and destination — without having +// to be rewritten. The bridge strips a leading "[INFO]"/"[WARN]"/"[ERROR]"/ +// "[DEBUG]" token (case-insensitive) from each line and maps it to the +// corresponding slog level; lines with no recognized prefix log at info. +package logger + +import ( + "context" + "fmt" + "io" + "log" + "log/slog" + "os" + "strings" + + "github.com/supporttools/node-doctor/pkg/types" +) + +// LevelFatal is a synthetic level above Error used to map the "fatal" log level +// (a valid value in config) onto slog, which has no native fatal level. Records +// at this level are always emitted (it sits above Error). +const LevelFatal = slog.Level(12) + +// L returns the active structured logger (slog's default). Code that wants to +// emit structured key/value attributes should call logger.L().Info(...) etc. +func L() *slog.Logger { + return slog.Default() +} + +// ParseLevel maps a config log level string (debug/info/warn/error/fatal) to a +// slog.Level. The match is case-insensitive and tolerant of surrounding +// whitespace. Unknown values fall back to slog.LevelInfo. +func ParseLevel(level string) slog.Level { + switch strings.ToLower(strings.TrimSpace(level)) { + case "debug": + return slog.LevelDebug + case "info": + return slog.LevelInfo + case "warn", "warning": + return slog.LevelWarn + case "error": + return slog.LevelError + case "fatal": + return LevelFatal + default: + return slog.LevelInfo + } +} + +// resolveWriter resolves the output writer for the configured LogOutput. +// Valid LogOutput values (see pkg/types/config.go) are "stdout", "stderr", and +// "file". For "file" it opens (creating if needed) cfg LogFile in append mode +// with 0644 permissions, returning a clear error if LogFile is empty or the +// file cannot be opened. +func resolveWriter(output, file string) (io.Writer, error) { + switch strings.ToLower(strings.TrimSpace(output)) { + case "", "stdout": + return os.Stdout, nil + case "stderr": + return os.Stderr, nil + case "file": + if strings.TrimSpace(file) == "" { + return nil, fmt.Errorf("logOutput is %q but logFile is empty", output) + } + f, err := os.OpenFile(file, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) //nolint:gosec // log path comes from validated operator config + if err != nil { + return nil, fmt.Errorf("open log file %q: %w", file, err) + } + return f, nil + default: + return nil, fmt.Errorf("unsupported logOutput %q (want stdout, stderr or file)", output) + } +} + +// NewHandler builds a slog.Handler for the given format ("json" or "text"), +// writing to w at the given minimum level. Any unrecognized format defaults to +// JSON. It is exported so the construction can be unit-tested without mutating +// global slog/log state. +func NewHandler(w io.Writer, format string, level slog.Level) slog.Handler { + opts := &slog.HandlerOptions{Level: level} + switch strings.ToLower(strings.TrimSpace(format)) { + case "text": + return slog.NewTextHandler(w, opts) + default: // "json" and anything unexpected + return slog.NewJSONHandler(w, opts) + } +} + +// Init configures structured logging from the validated configuration. It: +// +// - resolves the output writer from cfg.Settings.LogOutput / LogFile, +// - builds a JSON or text slog.Handler at the level mapped from LogLevel, +// - installs it as slog's default logger, and +// - redirects the standard `log` package through the prefix bridge so existing +// log.Printf("[LEVEL] ...") calls honor the same format/destination. +// +// On error (e.g. an unopenable log file) Init returns the error WITHOUT mutating +// any global state, so the caller can warn-and-continue on the prior defaults. +func Init(cfg *types.NodeDoctorConfig) error { + if cfg == nil { + return fmt.Errorf("nil config") + } + w, err := resolveWriter(cfg.Settings.LogOutput, cfg.Settings.LogFile) + if err != nil { + return err + } + + level := ParseLevel(cfg.Settings.LogLevel) + handler := NewHandler(w, cfg.Settings.LogFormat, level) + l := slog.New(handler) + + slog.SetDefault(l) + + // Bridge the standard log package onto slog. Clearing the flags prevents the + // stdlib from prepending its own timestamp (slog adds its own time field). + log.SetFlags(0) + log.SetOutput(&bridgeWriter{logger: l}) + + return nil +} + +// bridgeWriter is an io.Writer adapter that routes every standard-library log +// write through slog. Per write it strips a recognized level prefix, trims the +// trailing newline, and emits the remainder at the mapped level. +type bridgeWriter struct { + logger *slog.Logger +} + +// Write implements io.Writer. The stdlib log package issues exactly one Write +// per log call (with a trailing newline), so we treat each write as one record. +func (b *bridgeWriter) Write(p []byte) (int, error) { + n := len(p) + msg := strings.TrimRight(string(p), "\n") + level, msg := splitPrefix(msg) + l := b.logger + if l == nil { + l = slog.Default() + } + l.Log(context.Background(), level, msg) + return n, nil +} + +// splitPrefix strips a leading, optionally bracketed level token from a log +// line and returns the mapped slog level plus the remaining message. The token +// is matched case-insensitively. Recognized tokens: DEBUG, INFO, WARN/WARNING, +// ERROR, FATAL. A line with no recognized prefix maps to info with the message +// returned unchanged. +func splitPrefix(line string) (slog.Level, string) { + trimmed := strings.TrimLeft(line, " \t") + if !strings.HasPrefix(trimmed, "[") { + return slog.LevelInfo, line + } + end := strings.IndexByte(trimmed, ']') + if end < 0 { + return slog.LevelInfo, line + } + token := strings.ToUpper(strings.TrimSpace(trimmed[1:end])) + var level slog.Level + switch token { + case "DEBUG": + level = slog.LevelDebug + case "INFO": + level = slog.LevelInfo + case "WARN", "WARNING": + level = slog.LevelWarn + case "ERROR": + level = slog.LevelError + case "FATAL": + level = LevelFatal + default: + // Not a recognized level token; leave the line intact at info. + return slog.LevelInfo, line + } + rest := strings.TrimLeft(trimmed[end+1:], " \t") + return level, rest +} + +// Ensure bridgeWriter satisfies io.Writer. +var _ io.Writer = (*bridgeWriter)(nil) diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go new file mode 100644 index 0000000..a9ec822 --- /dev/null +++ b/pkg/logger/logger_test.go @@ -0,0 +1,247 @@ +package logger + +import ( + "bytes" + "context" + "encoding/json" + "log" + "log/slog" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/supporttools/node-doctor/pkg/types" +) + +// jsonHandlerLogger builds an slog.Logger writing JSON to buf at the given level, +// without mutating global slog state. +func jsonLogger(buf *bytes.Buffer, level slog.Level) *slog.Logger { + return slog.New(NewHandler(buf, "json", level)) +} + +func TestParseLevel(t *testing.T) { + cases := map[string]slog.Level{ + "debug": slog.LevelDebug, + "INFO": slog.LevelInfo, + " warn ": slog.LevelWarn, + "warning": slog.LevelWarn, + "error": slog.LevelError, + "fatal": LevelFatal, + "": slog.LevelInfo, + "bogus": slog.LevelInfo, + } + for in, want := range cases { + if got := ParseLevel(in); got != want { + t.Errorf("ParseLevel(%q) = %v, want %v", in, got, want) + } + } +} + +func TestNewHandlerJSON(t *testing.T) { + var buf bytes.Buffer + l := jsonLogger(&buf, slog.LevelInfo) + l.Info("hello world", "k", "v") + + var rec map[string]any + if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &rec); err != nil { + t.Fatalf("output is not valid JSON: %v (%q)", err, buf.String()) + } + if rec["msg"] != "hello world" { + t.Errorf("msg = %v, want %q", rec["msg"], "hello world") + } + if rec["level"] != "INFO" { + t.Errorf("level = %v, want INFO", rec["level"]) + } + if rec["k"] != "v" { + t.Errorf("attr k = %v, want v", rec["k"]) + } +} + +func TestNewHandlerText(t *testing.T) { + var buf bytes.Buffer + l := slog.New(NewHandler(&buf, "text", slog.LevelInfo)) + l.Info("human readable") + + out := buf.String() + // Text output is not JSON. + var rec map[string]any + if json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &rec) == nil { + t.Errorf("text handler produced JSON: %q", out) + } + if !strings.Contains(out, "human readable") { + t.Errorf("text output missing message: %q", out) + } + if !strings.Contains(out, "level=INFO") { + t.Errorf("text output missing level: %q", out) + } +} + +func TestLevelFiltering(t *testing.T) { + var buf bytes.Buffer + l := jsonLogger(&buf, slog.LevelInfo) + l.Debug("should be suppressed") + if buf.Len() != 0 { + t.Errorf("debug record emitted at info level: %q", buf.String()) + } + l.Info("should appear") + if !strings.Contains(buf.String(), "should appear") { + t.Errorf("info record suppressed: %q", buf.String()) + } +} + +func TestSplitPrefix(t *testing.T) { + cases := []struct { + in string + level slog.Level + msg string + }{ + {"[ERROR] boom", slog.LevelError, "boom"}, + {"[error] lower", slog.LevelError, "lower"}, + {"[WARN] careful", slog.LevelWarn, "careful"}, + {"[WARNING] careful", slog.LevelWarn, "careful"}, + {"[INFO] note", slog.LevelInfo, "note"}, + {"[DEBUG] trace", slog.LevelDebug, "trace"}, + {"[FATAL] dying", LevelFatal, "dying"}, + {"no prefix here", slog.LevelInfo, "no prefix here"}, + {"[NOTALEVEL] keep", slog.LevelInfo, "[NOTALEVEL] keep"}, + {" [INFO] spaced ", slog.LevelInfo, "spaced"}, + } + for _, c := range cases { + gotLevel, gotMsg := splitPrefix(c.in) + if gotLevel != c.level || strings.TrimRight(gotMsg, " ") != c.msg { + t.Errorf("splitPrefix(%q) = (%v, %q), want (%v, %q)", c.in, gotLevel, gotMsg, c.level, c.msg) + } + } +} + +func TestBridgeWriter(t *testing.T) { + var buf bytes.Buffer + l := jsonLogger(&buf, slog.LevelDebug) + bw := &bridgeWriter{logger: l} + + if _, err := bw.Write([]byte("[ERROR] boom\n")); err != nil { + t.Fatalf("write: %v", err) + } + var rec map[string]any + if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &rec); err != nil { + t.Fatalf("bridged output not JSON: %v (%q)", err, buf.String()) + } + if rec["level"] != "ERROR" { + t.Errorf("bridged level = %v, want ERROR", rec["level"]) + } + if rec["msg"] != "boom" { + t.Errorf("bridged msg = %v, want boom", rec["msg"]) + } + + // No-prefix line maps to info with message intact. + buf.Reset() + if _, err := bw.Write([]byte("plain line\n")); err != nil { + t.Fatalf("write: %v", err) + } + rec = nil + if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &rec); err != nil { + t.Fatalf("bridged output not JSON: %v (%q)", err, buf.String()) + } + if rec["level"] != "INFO" || rec["msg"] != "plain line" { + t.Errorf("no-prefix bridge = (%v, %v), want (INFO, plain line)", rec["level"], rec["msg"]) + } +} + +// withSavedGlobals saves and restores the global slog default and standard log +// package output/flags so Init tests don't leak state into other tests. +func withSavedGlobals(t *testing.T) { + t.Helper() + prevDefault := slog.Default() + prevFlags := log.Flags() + prevPrefix := log.Prefix() + prevOut := log.Writer() + t.Cleanup(func() { + slog.SetDefault(prevDefault) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + log.SetOutput(prevOut) + }) +} + +func TestInitFileOutput(t *testing.T) { + withSavedGlobals(t) + + dir := t.TempDir() + logPath := filepath.Join(dir, "nd.log") + cfg := &types.NodeDoctorConfig{} + cfg.Settings.LogLevel = "info" + cfg.Settings.LogFormat = "json" + cfg.Settings.LogOutput = "file" + cfg.Settings.LogFile = logPath + + if err := Init(cfg); err != nil { + t.Fatalf("Init: %v", err) + } + + // Native structured call and bridged stdlib call both go to the file. + L().Info("structured record", "kind", "native") + log.Printf("[WARN] bridged record") + + data, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("read log file: %v", err) + } + if _, err := os.Stat(logPath); err != nil { + t.Fatalf("log file missing: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + if len(lines) < 2 { + t.Fatalf("expected >=2 log lines, got %d: %q", len(lines), string(data)) + } + + var sawNative, sawBridged bool + for _, line := range lines { + var rec map[string]any + if err := json.Unmarshal([]byte(line), &rec); err != nil { + t.Fatalf("log line not JSON: %v (%q)", err, line) + } + if rec["msg"] == "structured record" && rec["level"] == "INFO" { + sawNative = true + } + if rec["msg"] == "bridged record" && rec["level"] == "WARN" { + sawBridged = true + } + } + if !sawNative { + t.Errorf("native structured record not found in %q", string(data)) + } + if !sawBridged { + t.Errorf("bridged WARN record not found in %q", string(data)) + } +} + +func TestInitFileOutputEmptyFileErrors(t *testing.T) { + withSavedGlobals(t) + + cfg := &types.NodeDoctorConfig{} + cfg.Settings.LogOutput = "file" + cfg.Settings.LogFile = "" + + if err := Init(cfg); err == nil { + t.Fatal("expected error for file output with empty logFile, got nil") + } +} + +func TestInitNilConfig(t *testing.T) { + withSavedGlobals(t) + if err := Init(nil); err == nil { + t.Fatal("expected error for nil config") + } +} + +func TestBridgeWriterContextNotNil(t *testing.T) { + // Guard against a nil context regression in Write. + var buf bytes.Buffer + bw := &bridgeWriter{logger: jsonLogger(&buf, slog.LevelInfo)} + bw.logger.Log(context.Background(), slog.LevelInfo, "ctx ok") + if !strings.Contains(buf.String(), "ctx ok") { + t.Errorf("context log failed: %q", buf.String()) + } +} diff --git a/pkg/types/config.go b/pkg/types/config.go index 471359b..bd74f8a 100644 --- a/pkg/types/config.go +++ b/pkg/types/config.go @@ -137,9 +137,9 @@ type GlobalSettings struct { // Logging configuration LogLevel string `json:"logLevel,omitempty" yaml:"logLevel,omitempty"` LogFormat string `json:"logFormat,omitempty" yaml:"logFormat,omitempty"` - // LogOutput and LogFile are parsed and validated but not yet consumed by the - // agent bootstrap; logging routing is currently configured via environment - // variables. Reserved for a future structured-logging refactor. + // LogOutput and LogFile route structured logging output. They are consumed by + // pkg/logger.Init at startup: LogOutput selects stdout/stderr/file and, when + // "file", LogFile names the append-mode log destination. LogOutput string `json:"logOutput,omitempty" yaml:"logOutput,omitempty"` LogFile string `json:"logFile,omitempty" yaml:"logFile,omitempty"` From 0adcda5cbeb710d5b5e13ea3cfd98de4009b5f39 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:39:35 -0500 Subject: [PATCH 05/11] test(remediators): FallbackOnUnreachable lease tests + storm-risk warnings (Task #15218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add lease_client_test.go covering RequestLease: reachable+grants, reachable+denies, unreachable+fallback=true (FAKE approval — test documents the cluster-wide remediation storm this enables), and unreachable+fallback=false (error). Add prominent storm-risk WARNING above the FallbackOnUnreachable config field and in the helm values comment (value stays false). Committed only my hunks (pre-existing image.tag bump excluded). --- helm/node-doctor/values.yaml | 10 ++ pkg/remediators/lease_client_test.go | 157 +++++++++++++++++++++++++++ pkg/types/config.go | 14 ++- 3 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 pkg/remediators/lease_client_test.go diff --git a/helm/node-doctor/values.yaml b/helm/node-doctor/values.yaml index 492376f..abe66d1 100644 --- a/helm/node-doctor/values.yaml +++ b/helm/node-doctor/values.yaml @@ -543,6 +543,16 @@ controller: # if the controller is unreachable (e.g., during upgrades). # false = safer (blocks remediations if controller is down). # true = higher availability (agents act independently when unreachable). + # + # WARNING - cluster-wide remediation storm risk: + # Setting this to true defeats lease coordination during a controller + # outage. When the controller is down, EVERY node self-approves its own + # lease at the same time, so ALL nodes may remediate (reboot/drain/etc.) + # simultaneously - a thundering-herd, cluster-wide remediation storm. + # Lease coordination exists to serialize remediations (single node at a + # time) and prevent exactly this. Keep this false unless single-node-at-a- + # time coordination is not required and you explicitly want higher + # availability over storm protection. fallbackOnUnreachable: false # Node affinity - prefer control-plane nodes diff --git a/pkg/remediators/lease_client_test.go b/pkg/remediators/lease_client_test.go new file mode 100644 index 0000000..ebca77e --- /dev/null +++ b/pkg/remediators/lease_client_test.go @@ -0,0 +1,157 @@ +package remediators + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/supporttools/node-doctor/pkg/controller" + "github.com/supporttools/node-doctor/pkg/types" +) + +// newTestCoordinationConfig builds a RemediationCoordinationConfig pointed at the +// given controller URL with short timeouts/retries so the unreachable cases fail +// fast rather than blocking the test suite. +func newTestCoordinationConfig(controllerURL string, fallback bool) *types.RemediationCoordinationConfig { + return &types.RemediationCoordinationConfig{ + Enabled: true, + ControllerURL: controllerURL, + LeaseTimeout: 5 * time.Minute, + RequestTimeout: 500 * time.Millisecond, + FallbackOnUnreachable: fallback, + MaxRetries: 1, + RetryInterval: 10 * time.Millisecond, + } +} + +// TestRequestLease_ReachableGrants verifies that when the controller is reachable +// and approves the lease, RequestLease returns Approved=true with no error. +func TestRequestLease_ReachableGrants(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Respond with an approved lease, mirroring the controller's + // LeaseResponse JSON shape (see pkg/controller/types.go). + resp := controller.LeaseResponse{ + LeaseID: "lease-123", + Approved: true, + ExpiresAt: time.Now().Add(5 * time.Minute), + Message: "granted", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + lc, err := NewLeaseClient(newTestCoordinationConfig(server.URL, false), "test-node") + if err != nil { + t.Fatalf("NewLeaseClient returned error: %v", err) + } + + resp, err := lc.RequestLease(context.Background(), "reboot", "test reason") + if err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if resp == nil { + t.Fatal("expected a lease response, got nil") + } + if !resp.Approved { + t.Errorf("expected Approved=true, got Approved=false (message: %q)", resp.Message) + } +} + +// TestRequestLease_ReachableDenies verifies that when the controller is reachable +// but denies the lease (concurrency limit, HTTP 429), RequestLease returns the +// denial with Approved=false and no transport error. +func TestRequestLease_ReachableDenies(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := controller.LeaseResponse{ + Approved: false, + Message: "denied: too many concurrent remediations", + } + w.Header().Set("Content-Type", "application/json") + // 429 is mapped by sendLeaseRequest to a denied (non-error) response. + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + lc, err := NewLeaseClient(newTestCoordinationConfig(server.URL, false), "test-node") + if err != nil { + t.Fatalf("NewLeaseClient returned error: %v", err) + } + + resp, err := lc.RequestLease(context.Background(), "reboot", "test reason") + if err != nil { + t.Fatalf("expected no transport error on denial, got: %v", err) + } + if resp == nil { + t.Fatal("expected a lease response, got nil") + } + if resp.Approved { + t.Error("expected Approved=false (denied), got Approved=true") + } +} + +// TestRequestLease_UnreachableFallbackTrue verifies the fallback path. +// +// STORM RISK WARNING: With FallbackOnUnreachable=true, an unreachable controller +// causes RequestLease to return a FAKE approval (Approved=true) WITHOUT any real +// lease coordination. This path is what ALLOWS A CLUSTER-WIDE REMEDIATION STORM: +// during a controller outage, EVERY DaemonSet node self-approves its own lease +// simultaneously, so ALL nodes can remediate (e.g. reboot/drain) at the same time. +// The lease coordination exists precisely to serialize remediations and prevent +// this thundering-herd behavior; enabling fallback trades that safety for +// availability. This test documents and pins that intentional (dangerous) behavior. +func TestRequestLease_UnreachableFallbackTrue(t *testing.T) { + // Create a server then immediately close it so the URL is valid but + // every request fails to connect (controller is unreachable). + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + unreachableURL := server.URL + server.Close() + + lc, err := NewLeaseClient(newTestCoordinationConfig(unreachableURL, true), "test-node") + if err != nil { + t.Fatalf("NewLeaseClient returned error: %v", err) + } + + resp, err := lc.RequestLease(context.Background(), "reboot", "test reason") + if err != nil { + t.Fatalf("expected no error with fallback enabled, got: %v", err) + } + if resp == nil { + t.Fatal("expected a fallback lease response, got nil") + } + // FAKE approval: no controller ever granted this lease. + if !resp.Approved { + t.Error("expected fallback Approved=true, got Approved=false") + } + if resp.Message == "" { + t.Error("expected a fallback message describing the unreachable controller") + } +} + +// TestRequestLease_UnreachableFallbackFalse verifies the safe (default) path: +// with FallbackOnUnreachable=false, an unreachable controller causes RequestLease +// to return an ERROR and NO approval, blocking remediation rather than risking a +// cluster-wide storm. +func TestRequestLease_UnreachableFallbackFalse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + unreachableURL := server.URL + server.Close() + + lc, err := NewLeaseClient(newTestCoordinationConfig(unreachableURL, false), "test-node") + if err != nil { + t.Fatalf("NewLeaseClient returned error: %v", err) + } + + resp, err := lc.RequestLease(context.Background(), "reboot", "test reason") + if err == nil { + t.Fatal("expected an error when controller is unreachable and fallback is disabled, got nil") + } + if resp != nil { + t.Errorf("expected no lease response on error, got: %+v", resp) + } +} diff --git a/pkg/types/config.go b/pkg/types/config.go index bd74f8a..bc88bc6 100644 --- a/pkg/types/config.go +++ b/pkg/types/config.go @@ -448,8 +448,18 @@ type RemediationCoordinationConfig struct { RequestTimeoutString string `json:"requestTimeout,omitempty" yaml:"requestTimeout,omitempty"` RequestTimeout time.Duration `json:"-" yaml:"-"` - // FallbackOnUnreachable determines behavior when controller is unreachable - // If true, proceed with remediation; if false, block and wait for controller + // FallbackOnUnreachable determines behavior when the controller is unreachable. + // If true, proceed with remediation; if false, block and wait for the controller. + // + // WARNING (cluster-wide remediation storm risk): Setting this to true defeats + // the lease coordination during a controller outage. When the controller is + // down, EVERY DaemonSet node's lease request "succeeds" with a fake self- + // approval simultaneously, so ALL nodes may remediate (reboot/drain/etc.) at + // the same time — a thundering-herd, cluster-wide remediation storm. The lease + // coordination exists precisely to serialize remediations (single node at a + // time) and prevent that. Defaults to false (safe: block when the controller is + // unreachable). Only enable when single-node-at-a-time coordination is not + // required and higher availability is preferred over storm protection. FallbackOnUnreachable bool `json:"fallbackOnUnreachable,omitempty" yaml:"fallbackOnUnreachable,omitempty"` // MaxRetries is the maximum number of lease request retries From c9a480a85c7ad23e53c1324d1085d6bd4d4056d7 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:41:41 -0500 Subject: [PATCH 06/11] build: add test-ci Makefile gate running unit + integration (Task #15219) Add a test-ci target that runs the fast unit suite (-short) AND the integration suite (go test ./test/integration/... no -short, 10m timeout) so coverage gaps like TestLeaseCoordinationFlow and TestCorrelationDetectionFlow (test/integration/controller/) are exercised in one gate. Build-tagged integration tests (kind dual-stack) still need -tags=integration and are not run here. Added to .PHONY. --- Makefile | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index cf78a8f..e0490b9 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ qa-check devils-advocate workflow-status \ gh-status gh-watch gh-logs gh-builds \ check-prerequisites check-docker check-kubectl \ - build test test-integration test-e2e test-all \ + build test test-integration test-e2e test-all test-ci \ test-net-icmp-integration \ lint fmt clean install-deps \ docker-build docker-push \ @@ -223,6 +223,21 @@ test-integration: fi @$(call print_success,"Integration tests completed") +# CI gate: fast unit tests (-short) PLUS integration tests (no -short), so +# coverage gaps like TestLeaseCoordinationFlow / TestCorrelationDetectionFlow +# (in test/integration/controller/) are exercised. Build-tagged integration +# tests (e.g. the kind dual-stack test) still require -tags=integration and are +# not run here. Use this target in CI in place of `test` alone. +test-ci: + @$(call print_status,"Running CI test gate (unit + integration)...") + @go test ./pkg/... ./cmd/... -v -cover -short + @if [ -d "test/integration" ]; then \ + go test ./test/integration/... -v -cover -timeout 10m; \ + else \ + $(call print_warning,"Integration tests not yet implemented (test/integration/ does not exist)"); \ + fi + @$(call print_success,"CI test gate passed") + # End-to-end tests test-e2e: @$(call print_status,"Running E2E tests...") From ce2e6ff2fd53b1d8b0160af091018b95eaa1fd59 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 06:43:59 -0500 Subject: [PATCH 07/11] feat(controller): enable SQLite WAL + synchronous=NORMAL (Task #15220) Apply PRAGMA journal_mode=WAL and synchronous=NORMAL in SQLiteStorage. Initialize (after ping) so concurrent RequestLease reads don't block the single writer and writes are faster while staying crash-safe in WAL mode. Connection pool was already pinned (SetMaxOpenConns/IdleConns(1), ConnMaxLifetime(0)). Add TestSQLiteStorage_WALAndSynchronous (file-backed temp DB) asserting journal_mode=wal, synchronous=1 (NORMAL), and MaxOpenConnections=1. --- pkg/controller/storage.go | 15 ++++++++++++ pkg/controller/storage_test.go | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/pkg/controller/storage.go b/pkg/controller/storage.go index 3ba6a33..e4e6563 100644 --- a/pkg/controller/storage.go +++ b/pkg/controller/storage.go @@ -98,6 +98,21 @@ func (s *SQLiteStorage) Initialize(ctx context.Context) error { return fmt.Errorf("failed to ping database: %w", err) } + // Enable WAL journaling + NORMAL synchronous for better read/write + // concurrency and throughput under many concurrent RequestLease calls. + // WAL lets readers proceed without blocking the single writer; NORMAL is + // durable enough in WAL mode (a crash can only lose the last transaction, + // not corrupt the DB). For in-memory DBs WAL is a no-op (stays "memory"). + for _, pragma := range []string{ + "PRAGMA journal_mode=WAL;", + "PRAGMA synchronous=NORMAL;", + } { + if _, err := db.ExecContext(ctx, pragma); err != nil { + _ = db.Close() + return fmt.Errorf("failed to apply %q: %w", pragma, err) + } + } + s.db = db // Run migrations diff --git a/pkg/controller/storage_test.go b/pkg/controller/storage_test.go index 6a48e3f..adde0ed 100644 --- a/pkg/controller/storage_test.go +++ b/pkg/controller/storage_test.go @@ -2,6 +2,8 @@ package controller import ( "context" + "path/filepath" + "strings" "testing" "time" ) @@ -33,6 +35,49 @@ func newTestStorage(t *testing.T) *SQLiteStorage { return storage } +// TestSQLiteStorage_WALAndSynchronous verifies that Initialize applies +// journal_mode=WAL and synchronous=NORMAL (Task #15220). WAL only takes effect +// for a file-backed database (it is a no-op for :memory:), so this uses a temp +// file. SetMaxOpenConns(1)/SetMaxIdleConns(1)/SetConnMaxLifetime(0) are also +// asserted via DB().Stats() bounds. +func TestSQLiteStorage_WALAndSynchronous(t *testing.T) { + config := &StorageConfig{ + Path: filepath.Join(t.TempDir(), "node-doctor.db"), + Retention: 24 * time.Hour, + } + storage, err := NewSQLiteStorage(config) + if err != nil { + t.Fatalf("NewSQLiteStorage() error = %v", err) + } + ctx := context.Background() + if err := storage.Initialize(ctx); err != nil { + t.Fatalf("Initialize() error = %v", err) + } + t.Cleanup(func() { storage.Close() }) + + var journalMode string + if err := storage.db.QueryRowContext(ctx, "PRAGMA journal_mode;").Scan(&journalMode); err != nil { + t.Fatalf("query journal_mode: %v", err) + } + if !strings.EqualFold(journalMode, "wal") { + t.Errorf("journal_mode = %q, want wal", journalMode) + } + + // PRAGMA synchronous returns the integer mode: 0=OFF, 1=NORMAL, 2=FULL. + var synchronous int + if err := storage.db.QueryRowContext(ctx, "PRAGMA synchronous;").Scan(&synchronous); err != nil { + t.Fatalf("query synchronous: %v", err) + } + if synchronous != 1 { + t.Errorf("synchronous = %d, want 1 (NORMAL)", synchronous) + } + + // Connection pool is pinned to a single connection for SQLite. + if got := storage.db.Stats().MaxOpenConnections; got != 1 { + t.Errorf("MaxOpenConnections = %d, want 1", got) + } +} + func TestNewSQLiteStorage(t *testing.T) { t.Run("creates storage with valid config", func(t *testing.T) { config := &StorageConfig{ From 4eded3f2517d8e631949be6328b75218d2dab5a8 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 08:07:39 -0500 Subject: [PATCH 08/11] feat(remediators): register safe built-in remediators + thread per-strategy config (Task #19263 phase 1) Built-in remediators were never registered, so registry.Remediate failed for every strategy. Add RegisterBuiltinRemediators (wired in main) for the two SAFE strategies: systemd-restart->SystemdRemediator, custom-script-> CustomRemediator. Thread per-strategy params via Problem.Metadata (service/scriptPath/args): buildStrategyList now returns the ordered []MonitorRemediationConfig and the detector loops them (first-success-wins preserved), encoding each strategy's params into the dispatched Problem. Systemd/Custom remediators read params from metadata with config fallback; CustomRemediator keeps abs-path/no-.. validation on the metadata path. node-reboot/pod-delete intentionally NOT registered (destructive; phase 2). Tests: registration, per-strategy metadata threading, dry-run dispatch, scriptPath validation. detector race-clean. --- cmd/node-doctor/main.go | 11 ++ pkg/detector/detector.go | 119 ++++++++++++---- pkg/detector/remediation_test.go | 104 +++++++++++++- pkg/remediators/builtin.go | 98 ++++++++++++++ pkg/remediators/builtin_test.go | 226 +++++++++++++++++++++++++++++++ pkg/remediators/custom.go | 141 ++++++++++++++++--- pkg/remediators/custom_test.go | 21 ++- pkg/remediators/registry.go | 20 +++ pkg/remediators/systemd.go | 105 +++++++++----- pkg/remediators/systemd_test.go | 7 +- 10 files changed, 762 insertions(+), 90 deletions(-) create mode 100644 pkg/remediators/builtin.go create mode 100644 pkg/remediators/builtin_test.go diff --git a/cmd/node-doctor/main.go b/cmd/node-doctor/main.go index 079906b..e89611b 100644 --- a/cmd/node-doctor/main.go +++ b/cmd/node-doctor/main.go @@ -222,6 +222,17 @@ func main() { log.Printf("[INFO] Remediator registry initialized (dry-run=%v, maxPerHour=%d, maxPerMinute=%d)", remediatorRegistry.IsDryRun(), maxPerHour, config.Remediation.MaxRemediationsPerMinute) + // Register the built-in remediator strategies so the detector's dispatch + // (which addresses a remediator by its strategy type) can find one. + // TaskForge #19263 Phase 1 registers ONLY the two SAFE strategies + // (systemd-restart, custom-script); the destructive node-reboot/pod-delete + // strategies are deferred to Phase 2 and remain unregistered (a config + // naming them fails dispatch, which is the desired fail-safe). SetDryRun + // is applied above so the closures pick up the correct dry-run state. + remediators.RegisterBuiltinRemediators(remediatorRegistry, config) + log.Printf("[INFO] Registered built-in remediators: %v (node-reboot/pod-delete deferred to Phase 2)", + remediatorRegistry.GetRegisteredTypes()) + // Wire the controller lease client when coordination is opted in. // The registry's Remediate path checks for a non-nil lease client and // performs RequestLease/ReleaseLease internally; nothing else to wire. diff --git a/pkg/detector/detector.go b/pkg/detector/detector.go index 16f926d..e1d20a3 100644 --- a/pkg/detector/detector.go +++ b/pkg/detector/detector.go @@ -2,6 +2,7 @@ package detector import ( "context" + "encoding/json" "fmt" "log" "log/slog" @@ -517,37 +518,61 @@ func (pd *ProblemDetector) evaluateRemediation(status *types.Status) { remCfg := monitorCfg.Remediation - // Build the ordered list of remediation strategy types to attempt. + // Build the ordered list of remediation strategies to attempt. Each entry + // carries its OWN per-strategy parameters (Service for systemd-restart, + // ScriptPath/Args for custom-script) so that a multi-strategy list with + // differing params threads each strategy's params independently. + // // When Strategies is non-empty, dispatch each nested strategy in order - // (first-success-wins). Otherwise fall back to the single Strategy so that + // (first-success-wins). Otherwise fall back to the single top-level config so // existing single-strategy configs behave exactly as before. - strategyTypes := buildStrategyList(remCfg) - if len(strategyTypes) == 0 { + strategies := buildStrategyList(remCfg) + if len(strategies) == 0 { return } + // strategyTypes is kept purely for logging context (the ordered type names). + strategyTypes := make([]string, len(strategies)) + for i, s := range strategies { + strategyTypes[i] = s.Strategy + } + for _, cond := range status.Conditions { if cond.Status != types.ConditionFalse { continue } - problem := types.Problem{ - Type: strategyTypes[0], - Resource: cond.Type, - Severity: types.ProblemWarning, - Message: cond.Message, - DetectedAt: cond.Transition, - Metadata: map[string]string{ - "source": status.Source, - "condition": cond.Type, - "reason": cond.Reason, - }, + // Dispatch the ordered strategies, threading each strategy's own params + // into the Problem.Metadata before attempting it. First success wins; the + // single-strategy case is just a one-element loop, preserving prior + // single-Strategy behavior. + var lastErr error + succeeded := false + for i, strat := range strategies { + problem := types.Problem{ + Type: strat.Strategy, + Resource: cond.Type, + Severity: types.ProblemWarning, + Message: cond.Message, + DetectedAt: cond.Transition, + Metadata: buildProblemMetadata(status.Source, cond, strat), + } + + err := registry.Remediate(pd.ctx, strat.Strategy, problem) + if err == nil { + succeeded = true + break + } + lastErr = err + slog.Debug("remediation strategy failed, trying next", + "monitor", status.Source, "condition", cond.Type, + "strategy", strat.Strategy, "index", i+1, "of", len(strategies), "error", err) } - if err := registry.RemediateWithStrategies(pd.ctx, strategyTypes, problem); err != nil { + if !succeeded { slog.Warn("remediation failed", "monitor", status.Source, "condition", cond.Type, - "strategies", strategyTypes, "error", err) + "strategies", strategyTypes, "error", lastErr) pd.stats.IncrementRemediationsFailed() } else { slog.Info("remediation triggered", @@ -558,23 +583,57 @@ func (pd *ProblemDetector) evaluateRemediation(status *types.Status) { } } -// buildStrategyList returns the ordered list of remediation strategy types to -// attempt for a monitor's remediation config. +// buildProblemMetadata builds the Problem.Metadata carrier for a single +// remediation strategy attempt. It always includes the source/condition/reason +// context the detector has always set, and additionally threads the strategy's +// OWN per-strategy parameters so the registered remediator can resolve them at +// Remediate time: +// - "service" : strat.Service (systemd-restart target service) +// - "scriptPath" : strat.ScriptPath (custom-script script path) +// - "args" : JSON-encoded strat.Args (custom-script arguments) +// +// Only non-empty params are added so a strategy that does not use a given param +// does not pollute the metadata. +func buildProblemMetadata(source string, cond types.Condition, strat types.MonitorRemediationConfig) map[string]string { + meta := map[string]string{ + "source": source, + "condition": cond.Type, + "reason": cond.Reason, + } + if strat.Service != "" { + meta["service"] = strat.Service + } + if strat.ScriptPath != "" { + meta["scriptPath"] = strat.ScriptPath + } + if len(strat.Args) > 0 { + // JSON-encode args so values containing commas/spaces survive intact. + if encoded, err := json.Marshal(strat.Args); err == nil { + meta["args"] = string(encoded) + } + } + return meta +} + +// buildStrategyList returns the ordered list of remediation strategies to +// attempt for a monitor's remediation config, each carrying its own per-strategy +// parameters (Strategy, Service, ScriptPath, Args). // -// When remCfg.Strategies is non-empty, the nested strategies are dispatched in -// order (each strategy's Strategy field, skipping empties). Otherwise it falls -// back to the single remCfg.Strategy, preserving backward compatibility: -// a config with only Strategy set yields []string{Strategy}. -func buildStrategyList(remCfg *types.MonitorRemediationConfig) []string { +// When remCfg.Strategies is non-empty, the nested strategies are returned in +// order (skipping entries with an empty Strategy). Otherwise it falls back to a +// single entry built from the top-level remCfg fields, preserving backward +// compatibility: a config with only Strategy/Service/ScriptPath set yields a +// one-element list carrying those params. +func buildStrategyList(remCfg *types.MonitorRemediationConfig) []types.MonitorRemediationConfig { if remCfg == nil { return nil } if len(remCfg.Strategies) > 0 { - strategies := make([]string, 0, len(remCfg.Strategies)) + strategies := make([]types.MonitorRemediationConfig, 0, len(remCfg.Strategies)) for _, s := range remCfg.Strategies { if s.Strategy != "" { - strategies = append(strategies, s.Strategy) + strategies = append(strategies, s) } } if len(strategies) > 0 { @@ -583,7 +642,13 @@ func buildStrategyList(remCfg *types.MonitorRemediationConfig) []string { } if remCfg.Strategy != "" { - return []string{remCfg.Strategy} + // Carry the top-level per-strategy params for the single-strategy fallback. + return []types.MonitorRemediationConfig{{ + Strategy: remCfg.Strategy, + Service: remCfg.Service, + ScriptPath: remCfg.ScriptPath, + Args: remCfg.Args, + }} } return nil diff --git a/pkg/detector/remediation_test.go b/pkg/detector/remediation_test.go index 3f609fb..82281a8 100644 --- a/pkg/detector/remediation_test.go +++ b/pkg/detector/remediation_test.go @@ -335,11 +335,11 @@ func TestBuildStrategyList(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got := buildStrategyList(tt.remCfg) if len(got) != len(tt.want) { - t.Fatalf("buildStrategyList() = %v, want %v", got, tt.want) + t.Fatalf("buildStrategyList() len = %d (%v), want %d (%v)", len(got), got, len(tt.want), tt.want) } for i := range got { - if got[i] != tt.want[i] { - t.Errorf("buildStrategyList()[%d] = %q, want %q", i, got[i], tt.want[i]) + if got[i].Strategy != tt.want[i] { + t.Errorf("buildStrategyList()[%d].Strategy = %q, want %q", i, got[i].Strategy, tt.want[i]) } } }) @@ -400,6 +400,104 @@ func TestEvaluateRemediation_MultiStrategyDispatch(t *testing.T) { } } +// TestEvaluateRemediation_ThreadsServiceMetadata verifies that a single-strategy +// systemd-restart config with a Service threads that service into the dispatched +// Problem.Metadata["service"] (TaskForge #19263 Phase 1). +func TestEvaluateRemediation_ThreadsServiceMetadata(t *testing.T) { + exec := NewMockRemediationExecutor() + + monCfg := types.MonitorConfig{ + Name: "svc-monitor", + Type: "test", + Enabled: true, + Interval: 30 * time.Second, + Timeout: 10 * time.Second, + Remediation: &types.MonitorRemediationConfig{ + Enabled: true, + Strategy: "systemd-restart", + Service: "kubelet", + }, + } + pd, mon := buildDetectorWithRemediation(t, monCfg, exec) + + mon.AddStatusUpdate(unhealthyStatus("svc-monitor", "KubeletHealthy")) + if !pollUntil(t, time.Second, func() bool { return exec.CallCount() == 1 }) { + t.Fatalf("expected 1 executor call within 1s, got %d", exec.CallCount()) + } + _ = pd + + call := exec.Calls()[0] + if call.RemediatorType != "systemd-restart" { + t.Errorf("strategy = %q, want systemd-restart", call.RemediatorType) + } + if got := call.Problem.Metadata["service"]; got != "kubelet" { + t.Errorf("Problem.Metadata[service] = %q, want kubelet", got) + } + // Source/condition/reason context preserved. + if got := call.Problem.Metadata["condition"]; got != "KubeletHealthy" { + t.Errorf("Problem.Metadata[condition] = %q, want KubeletHealthy", got) + } +} + +// TestEvaluateRemediation_MultiStrategyThreadsPerStrategyParams verifies that a +// multi-strategy list with DIFFERING params threads each strategy's own params +// into its dispatched Problem.Metadata (service for systemd-restart, scriptPath +// + JSON-encoded args for custom-script). +func TestEvaluateRemediation_MultiStrategyThreadsPerStrategyParams(t *testing.T) { + exec := NewMockRemediationExecutor() + // Force every attempt to fail so the detector walks all strategies in order. + exec.SetError(fmt.Errorf("simulated failure")) + + monCfg := types.MonitorConfig{ + Name: "multi-param-monitor", + Type: "test", + Enabled: true, + Interval: 30 * time.Second, + Timeout: 10 * time.Second, + Remediation: &types.MonitorRemediationConfig{ + Enabled: true, + Strategy: "node-reboot", // top-level (param-free) required by validation; ignored in favor of Strategies + Strategies: []types.MonitorRemediationConfig{ + {Strategy: "systemd-restart", Service: "containerd"}, + {Strategy: "custom-script", ScriptPath: "/opt/fix.sh", Args: []string{"--force", "a b"}}, + }, + }, + } + pd, mon := buildDetectorWithRemediation(t, monCfg, exec) + _ = pd + + mon.AddStatusUpdate(unhealthyStatus("multi-param-monitor", "ServiceHealthy")) + if !pollUntil(t, time.Second, func() bool { return exec.CallCount() == 2 }) { + t.Fatalf("expected 2 ordered executor calls within 1s, got %d", exec.CallCount()) + } + + calls := exec.Calls() + // First strategy: systemd-restart with service=containerd, no scriptPath/args. + if calls[0].RemediatorType != "systemd-restart" { + t.Errorf("first strategy = %q, want systemd-restart", calls[0].RemediatorType) + } + if got := calls[0].Problem.Metadata["service"]; got != "containerd" { + t.Errorf("first Problem.Metadata[service] = %q, want containerd", got) + } + if _, ok := calls[0].Problem.Metadata["scriptPath"]; ok { + t.Errorf("first Problem.Metadata should not carry scriptPath, got %v", calls[0].Problem.Metadata) + } + + // Second strategy: custom-script with scriptPath + JSON-encoded args, no service. + if calls[1].RemediatorType != "custom-script" { + t.Errorf("second strategy = %q, want custom-script", calls[1].RemediatorType) + } + if got := calls[1].Problem.Metadata["scriptPath"]; got != "/opt/fix.sh" { + t.Errorf("second Problem.Metadata[scriptPath] = %q, want /opt/fix.sh", got) + } + if got := calls[1].Problem.Metadata["args"]; got != `["--force","a b"]` { + t.Errorf("second Problem.Metadata[args] = %q, want JSON array", got) + } + if _, ok := calls[1].Problem.Metadata["service"]; ok { + t.Errorf("second Problem.Metadata should not carry service, got %v", calls[1].Problem.Metadata) + } +} + // TestEvaluateRemediation_MultiStrategyFirstSuccessWins verifies that when the // first strategy succeeds, subsequent strategies are not attempted. func TestEvaluateRemediation_MultiStrategyFirstSuccessWins(t *testing.T) { diff --git a/pkg/remediators/builtin.go b/pkg/remediators/builtin.go new file mode 100644 index 0000000..e10d033 --- /dev/null +++ b/pkg/remediators/builtin.go @@ -0,0 +1,98 @@ +package remediators + +import ( + "time" + + "github.com/supporttools/node-doctor/pkg/types" +) + +// Built-in remediator strategy type names. These match the strategy enum in +// pkg/types/config.go (config.validRemediationStrategies) so a strategy named in +// a MonitorRemediationConfig dispatches to the matching registered remediator. +const ( + // StrategySystemdRestart restarts a systemd service. The target service is + // supplied per-call via Problem.Metadata["service"]. + StrategySystemdRestart = "systemd-restart" + + // StrategyCustomScript runs a user-provided script. The script path/args are + // supplied per-call via Problem.Metadata["scriptPath"]/["args"]. + StrategyCustomScript = "custom-script" + + // StrategyNodeReboot and StrategyPodDelete are DESTRUCTIVE strategies that are + // intentionally NOT registered by RegisterBuiltinRemediators (Phase 1). They + // are deferred to Phase 2 (TaskForge #19263 Phase 2). Until then, a config + // that names them will fail dispatch with "unknown remediator type", which is + // the desired fail-safe behavior for un-implemented destructive actions. + StrategyNodeReboot = "node-reboot" + StrategyPodDelete = "pod-delete" +) + +// RegisterBuiltinRemediators registers the SAFE built-in remediator strategies +// with the registry so the detector's remediation dispatch (which addresses a +// remediator by its strategy type) can actually find a remediator. Before this +// wiring (TaskForge #19263) nothing registered any remediator type, so every +// dispatch failed with "unknown remediator type". +// +// Phase 1 registers ONLY the two non-destructive strategies: +// +// - "systemd-restart" -> a SystemdRemediator configured for the restart +// operation. It is a metadata-driven singleton: the per-call target service +// comes from Problem.Metadata["service"] (threaded by the detector from each +// strategy's MonitorRemediationConfig.Service), falling back to any global +// default service if one is ever configured. This means a single registered +// remediator handles kubelet, containerd, docker, etc. +// +// - "custom-script" -> a CustomRemediator with no fixed script path. It is a +// metadata-driven singleton: the per-call script path/args come from +// Problem.Metadata["scriptPath"]/["args"]. The metadata path is held to the +// same absolute-path / no-".." safety validation as a configured path, so the +// deferred construction does NOT weaken any safety check. +// +// The DESTRUCTIVE strategies "node-reboot" and "pod-delete" are intentionally +// NOT registered here. They are Phase 2 work; leaving them unregistered means a +// config naming them fails dispatch (fail-safe) until they are implemented. +// +// dryRun is taken from the registry's own dry-run state so the built-in +// remediators honour the global dry-run/dry-run-mode flag even on the path +// (config.Remediation.DryRun) that the registry already applies, and so a script +// that exists is not actually executed during a dry run. +// +// Register panics on duplicate or empty types; RegisterBuiltinRemediators must +// therefore be called exactly once per registry (main wiring does this). +func RegisterBuiltinRemediators(registry *RemediatorRegistry, cfg *types.NodeDoctorConfig) { + if registry == nil { + return + } + + dryRun := registry.IsDryRun() + + // systemd-restart: restart operation, service resolved per-call from metadata. + registry.Register(RemediatorInfo{ + Type: StrategySystemdRestart, + Factory: func() (types.Remediator, error) { + return NewSystemdRemediator(SystemdConfig{ + Operation: SystemdRestart, + // ServiceName intentionally empty: resolved per-call from + // Problem.Metadata["service"]. Verification is left off by default; + // per-call params do not (yet) carry a verify flag. + DryRun: dryRun, + }) + }, + Description: "Restarts the systemd service named in the triggering monitor's remediation config (Problem.Metadata[\"service\"]).", + }) + + // custom-script: script path/args resolved per-call from metadata. + registry.Register(RemediatorInfo{ + Type: StrategyCustomScript, + Factory: func() (types.Remediator, error) { + return NewCustomRemediator(CustomConfig{ + // ScriptPath intentionally empty: resolved per-call from + // Problem.Metadata["scriptPath"] (validated absolute / no ".."). + Timeout: 5 * time.Minute, + CaptureOutput: true, + DryRun: dryRun, + }) + }, + Description: "Runs the remediation script named in the triggering monitor's remediation config (Problem.Metadata[\"scriptPath\"]/[\"args\"]).", + }) +} diff --git a/pkg/remediators/builtin_test.go b/pkg/remediators/builtin_test.go new file mode 100644 index 0000000..200e07b --- /dev/null +++ b/pkg/remediators/builtin_test.go @@ -0,0 +1,226 @@ +package remediators + +import ( + "context" + "testing" + "time" + + "github.com/supporttools/node-doctor/pkg/types" +) + +// TestRegisterBuiltinRemediators_RegistersSafeStrategiesOnly verifies that +// Phase 1 registers ONLY systemd-restart and custom-script, and explicitly does +// NOT register the destructive node-reboot / pod-delete strategies (Phase 2). +func TestRegisterBuiltinRemediators_RegistersSafeStrategiesOnly(t *testing.T) { + registry := NewRegistry(10, 100) + RegisterBuiltinRemediators(registry, &types.NodeDoctorConfig{}) + + if !registry.IsRegistered(StrategySystemdRestart) { + t.Errorf("expected %q to be registered", StrategySystemdRestart) + } + if !registry.IsRegistered(StrategyCustomScript) { + t.Errorf("expected %q to be registered", StrategyCustomScript) + } + + // Destructive strategies must NOT be registered in Phase 1. + if registry.IsRegistered(StrategyNodeReboot) { + t.Errorf("%q must NOT be registered in Phase 1 (destructive, deferred to Phase 2)", StrategyNodeReboot) + } + if registry.IsRegistered(StrategyPodDelete) { + t.Errorf("%q must NOT be registered in Phase 1 (destructive, deferred to Phase 2)", StrategyPodDelete) + } + + got := registry.GetRegisteredTypes() + if len(got) != 2 { + t.Errorf("expected exactly 2 registered types, got %d: %v", len(got), got) + } +} + +// TestRegisterBuiltinRemediators_NilRegistryNoPanic verifies the nil guard. +func TestRegisterBuiltinRemediators_NilRegistryNoPanic(t *testing.T) { + RegisterBuiltinRemediators(nil, &types.NodeDoctorConfig{}) +} + +// TestBuiltinSystemdRestart_DryRunReachesRemediatorWithMetadataService verifies +// that a Problem carrying metadata["service"] dispatched via +// Remediate("systemd-restart", ...) reaches the SystemdRemediator and succeeds +// in dry-run (no real systemctl runs). +func TestBuiltinSystemdRestart_DryRunReachesRemediatorWithMetadataService(t *testing.T) { + registry := NewRegistry(10, 100) + registry.SetDryRun(true) + RegisterBuiltinRemediators(registry, &types.NodeDoctorConfig{}) + + problem := types.Problem{ + Type: StrategySystemdRestart, + Resource: "KubeletHealthy", + Severity: types.ProblemWarning, + Metadata: map[string]string{metadataKeyService: "kubelet"}, + } + + if err := registry.Remediate(context.Background(), StrategySystemdRestart, problem); err != nil { + t.Fatalf("Remediate(systemd-restart) dry-run failed: %v", err) + } +} + +// TestBuiltinSystemd_MetadataServiceUsedOverConfig verifies that the systemd +// remediator resolves the service from Problem.Metadata at Remediate time and +// actually issues the systemctl command against it (injected executor, no real +// systemctl). +func TestBuiltinSystemd_MetadataServiceUsedOverConfig(t *testing.T) { + // Build the same remediator the builtin factory builds: empty ServiceName, + // restart operation — a metadata-driven singleton. + r, err := NewSystemdRemediator(SystemdConfig{Operation: SystemdRestart}) + if err != nil { + t.Fatalf("NewSystemdRemediator: %v", err) + } + mock := &mockSystemdExecutor{serviceActive: true} + r.SetSystemdExecutor(mock) + + problem := types.Problem{ + Type: StrategySystemdRestart, + Metadata: map[string]string{metadataKeyService: "containerd"}, + } + if err := r.Remediate(context.Background(), problem); err != nil { + t.Fatalf("Remediate: %v", err) + } + + cmds := mock.getExecutedCommands() + foundRestart := false + for _, c := range cmds { + if c == "systemctl [restart containerd]" { + foundRestart = true + } + } + if !foundRestart { + t.Errorf("expected systemctl restart containerd, executed: %v", cmds) + } +} + +// TestBuiltinSystemd_NoServiceFails verifies that a metadata-driven singleton +// with neither config nor metadata service fails at Remediate time with a clear +// error (rather than panicking or restarting an empty service). +func TestBuiltinSystemd_NoServiceFails(t *testing.T) { + r, err := NewSystemdRemediator(SystemdConfig{Operation: SystemdRestart}) + if err != nil { + t.Fatalf("NewSystemdRemediator: %v", err) + } + r.SetSystemdExecutor(&mockSystemdExecutor{}) + + problem := types.Problem{Type: StrategySystemdRestart} // no metadata service + if err := r.Remediate(context.Background(), problem); err == nil { + t.Fatal("expected error when no service is specified, got nil") + } +} + +// TestBuiltinCustomScript_MetadataScriptPathReachesRemediator verifies that a +// Problem carrying metadata["scriptPath"]/["args"] reaches the CustomRemediator, +// which executes the metadata path with the metadata args (injected executor). +func TestBuiltinCustomScript_MetadataScriptPathReachesRemediator(t *testing.T) { + // Build the same remediator the builtin factory builds: empty ScriptPath. + r, err := NewCustomRemediator(CustomConfig{Timeout: time.Minute}) + if err != nil { + t.Fatalf("NewCustomRemediator: %v", err) + } + mock := &mockScriptExecutor{} + r.SetScriptExecutor(mock) + + problem := types.Problem{ + Type: StrategyCustomScript, + Metadata: map[string]string{ + metadataKeyScriptPath: "/opt/remediate.sh", + metadataKeyArgs: `["--force","a b"]`, + }, + } + if err := r.Remediate(context.Background(), problem); err != nil { + t.Fatalf("Remediate: %v", err) + } + + scripts := mock.getExecutedScripts() + if len(scripts) != 1 || scripts[0] != "/opt/remediate.sh" { + t.Fatalf("expected /opt/remediate.sh executed, got %v", scripts) + } + if len(mock.executedArgs) != 1 || len(mock.executedArgs[0]) != 2 || + mock.executedArgs[0][0] != "--force" || mock.executedArgs[0][1] != "a b" { + t.Errorf("expected args [--force, 'a b'], got %v", mock.executedArgs) + } +} + +// TestBuiltinCustomScript_DryRunDispatch verifies the registry dispatch path for +// custom-script in dry-run: the metadata scriptPath reaches the remediator and +// the dry run succeeds WITHOUT executing the script. +func TestBuiltinCustomScript_DryRunDispatch(t *testing.T) { + registry := NewRegistry(10, 100) + registry.SetDryRun(true) + RegisterBuiltinRemediators(registry, &types.NodeDoctorConfig{}) + + problem := types.Problem{ + Type: StrategyCustomScript, + Metadata: map[string]string{metadataKeyScriptPath: "/opt/fix.sh"}, + } + if err := registry.Remediate(context.Background(), StrategyCustomScript, problem); err != nil { + t.Fatalf("Remediate(custom-script) dry-run failed: %v", err) + } +} + +// TestBuiltinCustomScript_InvalidMetadataPathRejected verifies that the metadata +// supplied script path is held to the SAME safety validation (absolute, no "..") +// even on the metadata-driven path: relative and traversal paths are rejected +// and the script is NOT executed. +func TestBuiltinCustomScript_InvalidMetadataPathRejected(t *testing.T) { + cases := []struct { + name string + scriptPath string + }{ + {"relative", "relative/fix.sh"}, + {"traversal", "/opt/../../etc/fix.sh"}, + {"empty (no fallback)", ""}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, err := NewCustomRemediator(CustomConfig{Timeout: time.Minute}) + if err != nil { + t.Fatalf("NewCustomRemediator: %v", err) + } + mock := &mockScriptExecutor{} + r.SetScriptExecutor(mock) + + problem := types.Problem{Type: StrategyCustomScript} + if tc.scriptPath != "" { + problem.Metadata = map[string]string{metadataKeyScriptPath: tc.scriptPath} + } + + if err := r.Remediate(context.Background(), problem); err == nil { + t.Fatalf("expected error for invalid script path %q, got nil", tc.scriptPath) + } + if got := mock.getExecutedScripts(); len(got) != 0 { + t.Errorf("script must NOT be executed for invalid path, executed: %v", got) + } + }) + } +} + +// TestBuiltinCustomScript_InvalidArgsJSONRejected verifies that an unparseable +// args metadata value is rejected rather than silently ignored. +func TestBuiltinCustomScript_InvalidArgsJSONRejected(t *testing.T) { + r, err := NewCustomRemediator(CustomConfig{Timeout: time.Minute}) + if err != nil { + t.Fatalf("NewCustomRemediator: %v", err) + } + mock := &mockScriptExecutor{} + r.SetScriptExecutor(mock) + + problem := types.Problem{ + Type: StrategyCustomScript, + Metadata: map[string]string{ + metadataKeyScriptPath: "/opt/fix.sh", + metadataKeyArgs: "not-json", + }, + } + if err := r.Remediate(context.Background(), problem); err == nil { + t.Fatal("expected error for unparseable args JSON, got nil") + } + if got := mock.getExecutedScripts(); len(got) != 0 { + t.Errorf("script must NOT be executed when args are invalid, executed: %v", got) + } +} diff --git a/pkg/remediators/custom.go b/pkg/remediators/custom.go index e7d6305..37576fb 100644 --- a/pkg/remediators/custom.go +++ b/pkg/remediators/custom.go @@ -2,6 +2,7 @@ package remediators import ( "context" + "encoding/json" "fmt" "os" "os/exec" @@ -131,14 +132,25 @@ func (e *defaultScriptExecutor) CheckScriptSafety(scriptPath string) error { } // NewCustomRemediator creates a new custom script remediator with the given configuration. +// +// ScriptPath may be empty: such a remediator is a metadata-driven singleton +// whose script is supplied per-call via Problem.Metadata["scriptPath"] (see +// RegisterBuiltinRemediators). The per-call path is still subject to the same +// absolute-path / no-".." safety validation at Remediate time. When ScriptPath +// IS set at construction time it is validated immediately (absolute, no ".."). func NewCustomRemediator(config CustomConfig) (*CustomRemediator, error) { // Validate configuration if err := validateCustomConfig(&config); err != nil { return nil, fmt.Errorf("invalid custom config: %w", err) } - // Create base remediator with medium cooldown (5 minutes for custom scripts) - scriptName := filepath.Base(config.ScriptPath) + // Create base remediator with medium cooldown (5 minutes for custom scripts). + // When ScriptPath is empty the remediator is metadata-driven; use a stable + // label so the base remediator name is still unique. + scriptName := "dynamic" + if config.ScriptPath != "" { + scriptName = filepath.Base(config.ScriptPath) + } base, err := NewBaseRemediator( fmt.Sprintf("custom-%s", scriptName), CooldownMedium, @@ -161,20 +173,42 @@ func NewCustomRemediator(config CustomConfig) (*CustomRemediator, error) { return remediator, nil } -// validateCustomConfig validates the custom remediator configuration. -func validateCustomConfig(config *CustomConfig) error { - // Validate script path - if config.ScriptPath == "" { +// validateScriptPath enforces the custom-script safety policy on a script path: +// the path must be absolute and must not contain a ".." traversal component. +// It is applied both at construction time (when ScriptPath is configured) and at +// Remediate time on the metadata-supplied path, so a metadata-driven singleton +// can never bypass the safety check. +func validateScriptPath(scriptPath string) error { + if scriptPath == "" { return fmt.Errorf("script path is required") } + if !filepath.IsAbs(scriptPath) { + return fmt.Errorf("script path must be absolute: %q", scriptPath) + } + // Reject any ".." path-traversal component on the RAW path (not the cleaned + // path): an absolute path like "/opt/../../etc/x" cleans to "/etc/x" with no + // ".." remaining, so checking the cleaned path would let it through. Checking + // the raw components rejects traversal unconditionally, matching the config- + // layer policy (types.MonitorRemediationConfig.Validate). + for _, part := range strings.Split(scriptPath, string(filepath.Separator)) { + if part == ".." { + return fmt.Errorf("script path must not contain '..': %q", scriptPath) + } + } + return nil +} - // Convert to absolute path if relative - if !filepath.IsAbs(config.ScriptPath) { - absPath, err := filepath.Abs(config.ScriptPath) - if err != nil { - return fmt.Errorf("failed to resolve absolute path: %w", err) +// validateCustomConfig validates the custom remediator configuration. +// +// ScriptPath is optional: an empty ScriptPath produces a metadata-driven +// singleton whose script is supplied per-call via Problem.Metadata. When +// ScriptPath IS provided it must pass validateScriptPath (absolute, no ".."). +func validateCustomConfig(config *CustomConfig) error { + // Validate script path when configured (empty => metadata-driven singleton). + if config.ScriptPath != "" { + if err := validateScriptPath(config.ScriptPath); err != nil { + return err } - config.ScriptPath = absPath } // Set defaults @@ -191,8 +225,10 @@ func validateCustomConfig(config *CustomConfig) error { return fmt.Errorf("timeout must not exceed 30 minutes") } - // Set working directory to script's directory if not specified - if config.WorkingDir == "" { + // Set working directory to script's directory if not specified and a script + // path is configured. For metadata-driven singletons the working directory is + // resolved per-call from the metadata script path. + if config.WorkingDir == "" && config.ScriptPath != "" { config.WorkingDir = filepath.Dir(config.ScriptPath) } @@ -205,17 +241,30 @@ func validateCustomConfig(config *CustomConfig) error { } // remediate performs the actual custom script execution. +// +// The script path and args are resolved per-call: when the Problem carries a +// "scriptPath" metadata key (threaded by the detector from the strategy's +// MonitorRemediationConfig.ScriptPath), that path is used — after passing the +// same absolute-path / no-".." safety validation as a configured path — and the +// optional "args" metadata key (JSON-encoded) overrides the configured args. +// When the metadata path is absent the remediator falls back to its +// construction-time config.ScriptPath/ScriptArgs. func (r *CustomRemediator) remediate(ctx context.Context, problem types.Problem) error { - // Safety checks - if err := r.scriptExecutor.CheckScriptSafety(r.config.ScriptPath); err != nil { + scriptPath, scriptArgs, workingDir, err := r.resolveScript(problem) + if err != nil { + return err + } + + // Safety checks (existence/regular-file/executable) on the resolved path. + if err := r.scriptExecutor.CheckScriptSafety(scriptPath); err != nil { return fmt.Errorf("script safety check failed: %w", err) } - r.logInfof("Executing custom remediation script: %s", filepath.Base(r.config.ScriptPath)) + r.logInfof("Executing custom remediation script: %s", filepath.Base(scriptPath)) // Dry-run mode if r.config.DryRun { - r.logInfof("DRY-RUN: Would execute script: %s with args: %v", r.config.ScriptPath, r.config.ScriptArgs) + r.logInfof("DRY-RUN: Would execute script: %s with args: %v", scriptPath, scriptArgs) return nil } @@ -234,10 +283,10 @@ func (r *CustomRemediator) remediate(ctx context.Context, problem types.Problem) // Execute the script stdout, stderr, exitCode, err := r.scriptExecutor.ExecuteScript( execCtx, - r.config.ScriptPath, - r.config.ScriptArgs, + scriptPath, + scriptArgs, env, - r.config.WorkingDir, + workingDir, ) // Log output if configured @@ -268,6 +317,56 @@ func (r *CustomRemediator) remediate(ctx context.Context, problem types.Problem) return nil } +// resolveScript determines the script path, args, and working directory to use +// for this remediation. It prefers the per-call Problem.Metadata["scriptPath"] +// (threaded from the strategy's MonitorRemediationConfig.ScriptPath) and falls +// back to the construction-time config when absent. +// +// The metadata-supplied path is held to the SAME safety policy as a configured +// path (absolute, no ".."), so a metadata-driven singleton can never be tricked +// into running a relative or traversal path. When a metadata path is used and +// no working directory is configured, the working directory defaults to that +// script's directory. +// +// Args precedence: Problem.Metadata["args"] (JSON-encoded array) overrides the +// configured ScriptArgs when present and parseable; an unparseable args value is +// rejected rather than silently ignored. +func (r *CustomRemediator) resolveScript(problem types.Problem) (scriptPath string, args []string, workingDir string, err error) { + scriptPath = r.config.ScriptPath + args = r.config.ScriptArgs + workingDir = r.config.WorkingDir + + metaPath := "" + if problem.Metadata != nil { + metaPath = problem.Metadata[metadataKeyScriptPath] + } + + if metaPath != "" { + // Apply the same absolute-path / no-".." safety policy to the per-call path. + if verr := validateScriptPath(metaPath); verr != nil { + return "", nil, "", fmt.Errorf("invalid script path from problem metadata: %w", verr) + } + scriptPath = filepath.Clean(metaPath) + if workingDir == "" { + workingDir = filepath.Dir(scriptPath) + } + + if raw, ok := problem.Metadata[metadataKeyArgs]; ok && raw != "" { + var parsed []string + if jerr := json.Unmarshal([]byte(raw), &parsed); jerr != nil { + return "", nil, "", fmt.Errorf("invalid args from problem metadata (expected JSON array): %w", jerr) + } + args = parsed + } + } + + if scriptPath == "" { + return "", nil, "", fmt.Errorf("no script path specified (neither problem metadata %q nor config.ScriptPath set)", metadataKeyScriptPath) + } + + return scriptPath, args, workingDir, nil +} + // prepareProblemEnvironment converts problem metadata to environment variables. // This allows scripts to access problem details for context-aware remediation. func (r *CustomRemediator) prepareProblemEnvironment(problem types.Problem) map[string]string { diff --git a/pkg/remediators/custom_test.go b/pkg/remediators/custom_test.go index 4180b72..88010ff 100644 --- a/pkg/remediators/custom_test.go +++ b/pkg/remediators/custom_test.go @@ -140,12 +140,29 @@ func TestNewCustomRemediator(t *testing.T) { wantError: false, }, { - name: "empty script path", + // Empty ScriptPath is now valid: it produces a metadata-driven + // singleton whose script is supplied per-call via Problem.Metadata. + name: "empty script path (metadata-driven singleton)", config: CustomConfig{ ScriptPath: "", }, + wantError: false, + }, + { + name: "relative script path rejected", + config: CustomConfig{ + ScriptPath: "relative/remediate.sh", + }, + wantError: true, + errorMsg: `script path must be absolute: "relative/remediate.sh"`, + }, + { + name: "script path with traversal rejected", + config: CustomConfig{ + ScriptPath: "/usr/local/../../etc/remediate.sh", + }, wantError: true, - errorMsg: "script path is required", + errorMsg: `script path must not contain '..': "/usr/local/../../etc/remediate.sh"`, }, { name: "timeout too short", diff --git a/pkg/remediators/registry.go b/pkg/remediators/registry.go index d096411..813b6e7 100644 --- a/pkg/remediators/registry.go +++ b/pkg/remediators/registry.go @@ -40,6 +40,26 @@ import ( "github.com/supporttools/node-doctor/pkg/types" ) +// Problem.Metadata keys used to thread per-strategy remediation parameters from +// the detector to the registered remediator at dispatch time. The detector +// populates these from the triggering monitor's MonitorRemediationConfig so a +// single registered remediator (built via RegisterBuiltinRemediators) can act +// on per-monitor parameters without registering one remediator per service or +// script. +const ( + // metadataKeyService carries the systemd service name for the + // systemd-restart strategy (from MonitorRemediationConfig.Service). + metadataKeyService = "service" + + // metadataKeyScriptPath carries the absolute script path for the + // custom-script strategy (from MonitorRemediationConfig.ScriptPath). + metadataKeyScriptPath = "scriptPath" + + // metadataKeyArgs carries the JSON-encoded script arguments for the + // custom-script strategy (from MonitorRemediationConfig.Args). + metadataKeyArgs = "args" +) + // CircuitBreakerState represents the state of the circuit breaker. type CircuitBreakerState int diff --git a/pkg/remediators/systemd.go b/pkg/remediators/systemd.go index 19bea86..53dcb8d 100644 --- a/pkg/remediators/systemd.go +++ b/pkg/remediators/systemd.go @@ -106,9 +106,16 @@ func NewSystemdRemediator(config SystemdConfig) (*SystemdRemediator, error) { return nil, fmt.Errorf("invalid systemd config: %w", err) } - // Create base remediator with medium cooldown (5 minutes default for systemd services) + // Create base remediator with medium cooldown (5 minutes default for systemd services). + // When ServiceName is empty the remediator is a metadata-driven singleton (the + // service is resolved per-call from Problem.Metadata["service"]); use a stable + // name so the base remediator is still uniquely identifiable. + serviceLabel := config.ServiceName + if serviceLabel == "" { + serviceLabel = "dynamic" + } base, err := NewBaseRemediator( - fmt.Sprintf("systemd-%s-%s", config.Operation, config.ServiceName), + fmt.Sprintf("systemd-%s-%s", config.Operation, serviceLabel), CooldownMedium, ) if err != nil { @@ -130,11 +137,13 @@ func NewSystemdRemediator(config SystemdConfig) (*SystemdRemediator, error) { } // validateSystemdConfig validates the systemd remediator configuration. +// +// ServiceName is intentionally NOT required here: a systemd remediator may be +// registered as a metadata-driven singleton whose target service is supplied +// per-call via Problem.Metadata["service"]. When neither config.ServiceName nor +// the per-call metadata is set, remediate() fails with a clear error at +// dispatch time rather than at construction time. func validateSystemdConfig(config SystemdConfig) error { - if config.ServiceName == "" { - return fmt.Errorf("service name is required") - } - // Validate operation switch config.Operation { case SystemdRestart, SystemdStop, SystemdStart, SystemdReload: @@ -152,31 +161,44 @@ func validateSystemdConfig(config SystemdConfig) error { } // remediate performs the actual systemd service remediation. +// +// The target service name is resolved per-call: when the Problem carries a +// "service" metadata key (set by the detector from the strategy's +// MonitorRemediationConfig.Service), that value is used; otherwise the +// remediator falls back to its construction-time config.ServiceName. This lets +// a single registered systemd-restart remediator act on whatever service the +// triggering monitor declared (kubelet, containerd, etc.) without registering +// one remediator per service. func (r *SystemdRemediator) remediate(ctx context.Context, problem types.Problem) error { + serviceName := r.resolveServiceName(problem) + if serviceName == "" { + return fmt.Errorf("no systemd service specified (neither problem metadata %q nor config.ServiceName set)", metadataKeyService) + } + // Dry-run mode if r.config.DryRun { - r.logInfof("DRY-RUN: Would execute systemctl %s %s", r.config.Operation, r.config.ServiceName) + r.logInfof("DRY-RUN: Would execute systemctl %s %s", r.config.Operation, serviceName) return nil } // Check service status before remediation - wasActive, err := r.systemdExecutor.IsActive(ctx, r.config.ServiceName) + wasActive, err := r.systemdExecutor.IsActive(ctx, serviceName) if err != nil { r.logWarnf("Failed to check service status before remediation: %v", err) } else { - r.logInfof("Service %s status before remediation: active=%v", r.config.ServiceName, wasActive) + r.logInfof("Service %s status before remediation: active=%v", serviceName, wasActive) } // Execute the systemd operation - if err := r.executeOperation(ctx); err != nil { - return fmt.Errorf("failed to execute %s on %s: %w", r.config.Operation, r.config.ServiceName, err) + if err := r.executeOperation(ctx, serviceName); err != nil { + return fmt.Errorf("failed to execute %s on %s: %w", r.config.Operation, serviceName, err) } - r.logInfof("Successfully executed systemctl %s %s", r.config.Operation, r.config.ServiceName) + r.logInfof("Successfully executed systemctl %s %s", r.config.Operation, serviceName) // Verify service status after remediation if configured if r.config.VerifyStatus { - if err := r.verifyServiceStatus(ctx); err != nil { + if err := r.verifyServiceStatus(ctx, serviceName); err != nil { return fmt.Errorf("service verification failed after %s: %w", r.config.Operation, err) } } @@ -184,26 +206,39 @@ func (r *SystemdRemediator) remediate(ctx context.Context, problem types.Problem return nil } -// executeOperation executes the configured systemd operation. -func (r *SystemdRemediator) executeOperation(ctx context.Context) error { +// resolveServiceName returns the systemd service to act on for this remediation. +// It prefers the per-call "service" metadata key (threaded from the strategy's +// MonitorRemediationConfig.Service) and falls back to the construction-time +// config.ServiceName when the metadata is absent. +func (r *SystemdRemediator) resolveServiceName(problem types.Problem) string { + if problem.Metadata != nil { + if svc := problem.Metadata[metadataKeyService]; svc != "" { + return svc + } + } + return r.config.ServiceName +} + +// executeOperation executes the configured systemd operation against serviceName. +func (r *SystemdRemediator) executeOperation(ctx context.Context, serviceName string) error { switch r.config.Operation { case SystemdRestart: - return r.restart(ctx) + return r.restart(ctx, serviceName) case SystemdStop: - return r.stop(ctx) + return r.stop(ctx, serviceName) case SystemdStart: - return r.start(ctx) + return r.start(ctx, serviceName) case SystemdReload: - return r.reload(ctx) + return r.reload(ctx, serviceName) default: return fmt.Errorf("unknown operation: %s", r.config.Operation) } } // restart restarts the systemd service. -func (r *SystemdRemediator) restart(ctx context.Context) error { - r.logInfof("Restarting service: %s", r.config.ServiceName) - output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "restart", r.config.ServiceName) +func (r *SystemdRemediator) restart(ctx context.Context, serviceName string) error { + r.logInfof("Restarting service: %s", serviceName) + output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "restart", serviceName) if err != nil { return fmt.Errorf("systemctl restart failed: %w (output: %s)", err, output) } @@ -211,9 +246,9 @@ func (r *SystemdRemediator) restart(ctx context.Context) error { } // stop stops the systemd service. -func (r *SystemdRemediator) stop(ctx context.Context) error { - r.logInfof("Stopping service: %s", r.config.ServiceName) - output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "stop", r.config.ServiceName) +func (r *SystemdRemediator) stop(ctx context.Context, serviceName string) error { + r.logInfof("Stopping service: %s", serviceName) + output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "stop", serviceName) if err != nil { return fmt.Errorf("systemctl stop failed: %w (output: %s)", err, output) } @@ -221,9 +256,9 @@ func (r *SystemdRemediator) stop(ctx context.Context) error { } // start starts the systemd service. -func (r *SystemdRemediator) start(ctx context.Context) error { - r.logInfof("Starting service: %s", r.config.ServiceName) - output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "start", r.config.ServiceName) +func (r *SystemdRemediator) start(ctx context.Context, serviceName string) error { + r.logInfof("Starting service: %s", serviceName) + output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "start", serviceName) if err != nil { return fmt.Errorf("systemctl start failed: %w (output: %s)", err, output) } @@ -231,9 +266,9 @@ func (r *SystemdRemediator) start(ctx context.Context) error { } // reload reloads the systemd service configuration. -func (r *SystemdRemediator) reload(ctx context.Context) error { - r.logInfof("Reloading service: %s", r.config.ServiceName) - output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "reload", r.config.ServiceName) +func (r *SystemdRemediator) reload(ctx context.Context, serviceName string) error { + r.logInfof("Reloading service: %s", serviceName) + output, err := r.systemdExecutor.ExecuteSystemctl(ctx, "reload", serviceName) if err != nil { return fmt.Errorf("systemctl reload failed: %w (output: %s)", err, output) } @@ -241,7 +276,7 @@ func (r *SystemdRemediator) reload(ctx context.Context) error { } // verifyServiceStatus verifies that the service is active after remediation. -func (r *SystemdRemediator) verifyServiceStatus(ctx context.Context) error { +func (r *SystemdRemediator) verifyServiceStatus(ctx context.Context, serviceName string) error { // Create a context with timeout for verification verifyCtx, cancel := context.WithTimeout(ctx, r.config.VerifyTimeout) defer cancel() @@ -256,18 +291,18 @@ func (r *SystemdRemediator) verifyServiceStatus(ctx context.Context) error { return fmt.Errorf("timeout waiting for service to become active after %v", r.config.VerifyTimeout) case <-ticker.C: - isActive, err := r.systemdExecutor.IsActive(verifyCtx, r.config.ServiceName) + isActive, err := r.systemdExecutor.IsActive(verifyCtx, serviceName) if err != nil { r.logWarnf("Error checking service status during verification: %v", err) continue } if isActive { - r.logInfof("Service %s is now active", r.config.ServiceName) + r.logInfof("Service %s is now active", serviceName) return nil } - r.logInfof("Waiting for service %s to become active...", r.config.ServiceName) + r.logInfof("Waiting for service %s to become active...", serviceName) } } } diff --git a/pkg/remediators/systemd_test.go b/pkg/remediators/systemd_test.go index d62a297..7c08d50 100644 --- a/pkg/remediators/systemd_test.go +++ b/pkg/remediators/systemd_test.go @@ -142,11 +142,14 @@ func TestNewSystemdRemediator(t *testing.T) { wantErr: false, }, { - name: "missing service name", + // Empty ServiceName is now valid: it produces a metadata-driven + // singleton whose target service is resolved per-call from + // Problem.Metadata["service"] at Remediate time. + name: "missing service name (metadata-driven singleton)", config: SystemdConfig{ Operation: SystemdRestart, }, - wantErr: true, + wantErr: false, }, { name: "invalid operation", From bc9f2eedf89ffc3cbf9a378ca75ef4cf47230433 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 08:22:06 -0500 Subject: [PATCH 09/11] feat(remediators): node-reboot + pod-delete cluster remediators (Task #19263 phase 2) Add RegisterClusterRemediators(registry,cfg,client,nodeName,self...) wiring the two destructive strategies, fail-closed: registers node-reboot + pod-delete ONLY when a k8s client and nodeName are present (else logs and registers nothing). main builds an in-cluster clientset (non-fatal on error) + passes downward-API POD_NAME/POD_NAMESPACE for self-skip. NodeRebootRemediator: cordon -> drain (evict, skipping DaemonSet/mirror/ self pods) -> reboot via injected runner (default systemctl reboot); dry-run returns before ANY mutation; cordon failure aborts reboot; bounded per-phase timeouts; 30m cooldown. PodDeleteRemediator: deletes only the pod named in Problem.Metadata (namespace/pod), refuses mirror/self, dry-run log-only. Tested with fake clientset incl. dry-run-does-not-execute, drain ordering/skips, and fail-closed registration. --- cmd/node-doctor/main.go | 71 +++- pkg/remediators/builtin.go | 94 +++++- pkg/remediators/cluster_register_test.go | 82 +++++ pkg/remediators/node_reboot.go | 413 +++++++++++++++++++++++ pkg/remediators/node_reboot_test.go | 254 ++++++++++++++ pkg/remediators/pod_delete.go | 177 ++++++++++ pkg/remediators/pod_delete_test.go | 121 +++++++ 7 files changed, 1206 insertions(+), 6 deletions(-) create mode 100644 pkg/remediators/cluster_register_test.go create mode 100644 pkg/remediators/node_reboot.go create mode 100644 pkg/remediators/node_reboot_test.go create mode 100644 pkg/remediators/pod_delete.go create mode 100644 pkg/remediators/pod_delete_test.go diff --git a/cmd/node-doctor/main.go b/cmd/node-doctor/main.go index e89611b..d3f3a44 100644 --- a/cmd/node-doctor/main.go +++ b/cmd/node-doctor/main.go @@ -13,6 +13,10 @@ import ( "syscall" "time" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "github.com/supporttools/node-doctor/pkg/detector" httpexporter "github.com/supporttools/node-doctor/pkg/exporters/http" kubernetesexporter "github.com/supporttools/node-doctor/pkg/exporters/kubernetes" @@ -230,7 +234,25 @@ func main() { // naming them fails dispatch, which is the desired fail-safe). SetDryRun // is applied above so the closures pick up the correct dry-run state. remediators.RegisterBuiltinRemediators(remediatorRegistry, config) - log.Printf("[INFO] Registered built-in remediators: %v (node-reboot/pod-delete deferred to Phase 2)", + log.Printf("[INFO] Registered built-in remediators: %v", + remediatorRegistry.GetRegisteredTypes()) + + // Register the DESTRUCTIVE cluster-scoped remediators (node-reboot, + // pod-delete) ONLY when a real in-cluster Kubernetes client and node name + // are available. buildClusterClient returns a nil client (non-fatal) when + // running out-of-cluster; RegisterClusterRemediators then registers + // nothing and logs a warning, so a config naming a destructive strategy + // fails dispatch (fail-closed) rather than doing something dangerous. + clusterClient := buildClusterClient(config) + remediators.RegisterClusterRemediators( + remediatorRegistry, + config, + clusterClient, + config.Settings.NodeName, + os.Getenv("POD_NAME"), + os.Getenv("POD_NAMESPACE"), + ) + log.Printf("[INFO] Remediators registered after cluster wiring: %v", remediatorRegistry.GetRegisteredTypes()) // Wire the controller lease client when coordination is opted in. @@ -532,3 +554,50 @@ func wireLeaseClient(registry *remediators.RemediatorRegistry, config *types.Nod coord.ControllerURL, config.Settings.NodeName, coord.LeaseTimeout) return nil } + +// buildClusterClient builds a Kubernetes clientset for the DESTRUCTIVE +// cluster-scoped remediators (node-reboot, pod-delete). It mirrors the +// auto-detection used by pkg/exporters/kubernetes/client.go (explicit kubeconfig +// if set, otherwise in-cluster config). +// +// On ANY error (no kubeconfig + not in a pod, clientset build failure) it logs +// and returns nil — this is non-fatal. A nil client causes +// RegisterClusterRemediators to skip the destructive strategies (fail-closed), +// which is strictly safer than registering remediators that cannot reach the +// API. +func buildClusterClient(config *types.NodeDoctorConfig) kubernetes.Interface { + var restConfig *rest.Config + var err error + + if config.Settings.Kubeconfig != "" { + restConfig, err = clientcmd.BuildConfigFromFlags("", config.Settings.Kubeconfig) + if err != nil { + log.Printf("[WARN] Could not build Kubernetes config from kubeconfig %q: %v "+ + "(destructive remediators will be skipped)", config.Settings.Kubeconfig, err) + return nil + } + } else { + restConfig, err = rest.InClusterConfig() + if err != nil { + log.Printf("[WARN] Could not build in-cluster Kubernetes config: %v "+ + "(destructive remediators will be skipped — are you running inside a pod?)", err) + return nil + } + } + + if config.Settings.QPS > 0 { + restConfig.QPS = config.Settings.QPS + } + if config.Settings.Burst > 0 { + restConfig.Burst = config.Settings.Burst + } + + clientset, err := kubernetes.NewForConfig(restConfig) + if err != nil { + log.Printf("[WARN] Could not create Kubernetes clientset: %v "+ + "(destructive remediators will be skipped)", err) + return nil + } + + return clientset +} diff --git a/pkg/remediators/builtin.go b/pkg/remediators/builtin.go index e10d033..aadafa4 100644 --- a/pkg/remediators/builtin.go +++ b/pkg/remediators/builtin.go @@ -1,8 +1,11 @@ package remediators import ( + "log" "time" + "k8s.io/client-go/kubernetes" + "github.com/supporttools/node-doctor/pkg/types" ) @@ -18,11 +21,14 @@ const ( // supplied per-call via Problem.Metadata["scriptPath"]/["args"]. StrategyCustomScript = "custom-script" - // StrategyNodeReboot and StrategyPodDelete are DESTRUCTIVE strategies that are - // intentionally NOT registered by RegisterBuiltinRemediators (Phase 1). They - // are deferred to Phase 2 (TaskForge #19263 Phase 2). Until then, a config - // that names them will fail dispatch with "unknown remediator type", which is - // the desired fail-safe behavior for un-implemented destructive actions. + // StrategyNodeReboot and StrategyPodDelete are DESTRUCTIVE strategies. They are + // intentionally NOT registered by RegisterBuiltinRemediators; they are + // registered separately by RegisterClusterRemediators (TaskForge #19263 Phase + // 2) and ONLY when a real Kubernetes client and node name are available. When + // the cluster client is unavailable (e.g. out-of-cluster) they remain + // unregistered and a config naming them fails dispatch with "unknown + // remediator type" — the desired fail-closed behavior for un-actionable + // destructive actions. StrategyNodeReboot = "node-reboot" StrategyPodDelete = "pod-delete" ) @@ -96,3 +102,81 @@ func RegisterBuiltinRemediators(registry *RemediatorRegistry, cfg *types.NodeDoc Description: "Runs the remediation script named in the triggering monitor's remediation config (Problem.Metadata[\"scriptPath\"]/[\"args\"]).", }) } + +// RegisterClusterRemediators registers the DESTRUCTIVE cluster-scoped remediator +// strategies "node-reboot" and "pod-delete" (TaskForge #19263 Phase 2). +// +// These are registered ONLY when a real Kubernetes client AND a non-empty node +// name are available. If either is missing (e.g. node-doctor is running +// out-of-cluster, or the in-cluster client could not be built), they are NOT +// registered and a clear warning is logged. This is the intended fail-closed +// behavior: a config naming a destructive strategy then fails dispatch with +// "unknown remediator type" rather than silently doing something dangerous (or +// nothing). +// +// Both remediators honor the registry's dry-run state (registry.IsDryRun()): in +// dry-run every destructive step is logged but never executed. +// +// - "node-reboot" -> NodeRebootRemediator: cordons the node, drains its pods +// (skipping DaemonSet-owned, mirror/static, and node-doctor's OWN pod), then +// reboots via an injected command runner. Cordon-before-drain-before-reboot +// ordering is enforced; cordon failure aborts the reboot. +// +// - "pod-delete" -> PodDeleteRemediator: deletes the single pod named in +// Problem.Metadata["namespace"]/["pod"]; refuses mirror/static pods and +// node-doctor's own pod; missing metadata is a hard error. +// +// selfPodName/selfPodNamespace identify node-doctor's own pod (typically from +// the downward-API POD_NAME/POD_NAMESPACE env vars) so neither remediator can +// act on the pod running this very process. They may be empty (self-skip simply +// disabled), but supplying them is strongly recommended. +// +// Register panics on duplicate types, so RegisterClusterRemediators must be +// called at most once per registry, after RegisterBuiltinRemediators. +func RegisterClusterRemediators(registry *RemediatorRegistry, cfg *types.NodeDoctorConfig, client kubernetes.Interface, nodeName string, selfPodName, selfPodNamespace string) { + if registry == nil { + return + } + + // Fail closed: without a real client and node name we cannot safely cordon, + // drain, or reboot. Leave the destructive strategies unregistered so any + // config naming them fails dispatch instead of doing something dangerous. + if client == nil || nodeName == "" { + log.Printf("[WARN] Destructive remediators (node-reboot, pod-delete) NOT registered: "+ + "Kubernetes client available=%v, nodeName=%q. A config naming these strategies will "+ + "fail dispatch (fail-closed).", client != nil, nodeName) + return + } + + dryRun := registry.IsDryRun() + + // node-reboot: cordon + drain + reboot (destructive, dry-run-aware). + registry.Register(RemediatorInfo{ + Type: StrategyNodeReboot, + Factory: func() (types.Remediator, error) { + return NewNodeRebootRemediator(client, NodeRebootConfig{ + NodeName: nodeName, + DryRun: dryRun, + SelfPodName: selfPodName, + SelfPodNamespace: selfPodNamespace, + }) + }, + Description: "DESTRUCTIVE (dry-run-aware): cordons, drains (skipping DaemonSet/mirror/self pods), and reboots this node.", + }) + + // pod-delete: deletes the pod named in Problem.Metadata (destructive, dry-run-aware). + registry.Register(RemediatorInfo{ + Type: StrategyPodDelete, + Factory: func() (types.Remediator, error) { + return NewPodDeleteRemediator(client, PodDeleteConfig{ + DryRun: dryRun, + SelfPodName: selfPodName, + SelfPodNamespace: selfPodNamespace, + }) + }, + Description: "DESTRUCTIVE (dry-run-aware): deletes the pod named in Problem.Metadata[\"namespace\"]/[\"pod\"]; refuses mirror/self pods.", + }) + + log.Printf("[INFO] Registered destructive cluster remediators: [%s %s] (dry-run=%v, node=%q)", + StrategyNodeReboot, StrategyPodDelete, dryRun, nodeName) +} diff --git a/pkg/remediators/cluster_register_test.go b/pkg/remediators/cluster_register_test.go new file mode 100644 index 0000000..6b03ff2 --- /dev/null +++ b/pkg/remediators/cluster_register_test.go @@ -0,0 +1,82 @@ +package remediators + +import ( + "testing" + + "k8s.io/client-go/kubernetes/fake" +) + +func hasType(types []string, want string) bool { + for _, t := range types { + if t == want { + return true + } + } + return false +} + +func TestRegisterClusterRemediators_RegistersWhenClientPresent(t *testing.T) { + registry := NewRegistry(10, 100) + client := fake.NewSimpleClientset() + + RegisterClusterRemediators(registry, nil, client, testNodeName, "node-doctor-xyz", "monitoring") + + got := registry.GetRegisteredTypes() + if !hasType(got, StrategyNodeReboot) { + t.Fatalf("expected %q registered, got %v", StrategyNodeReboot, got) + } + if !hasType(got, StrategyPodDelete) { + t.Fatalf("expected %q registered, got %v", StrategyPodDelete, got) + } + + // Factories must build successfully. + if _, err := registry.getOrCreateRemediator(StrategyNodeReboot); err != nil { + t.Fatalf("create node-reboot: %v", err) + } + if _, err := registry.getOrCreateRemediator(StrategyPodDelete); err != nil { + t.Fatalf("create pod-delete: %v", err) + } +} + +func TestRegisterClusterRemediators_NilClientRegistersNothing(t *testing.T) { + registry := NewRegistry(10, 100) + + RegisterClusterRemediators(registry, nil, nil, testNodeName, "", "") + + got := registry.GetRegisteredTypes() + if hasType(got, StrategyNodeReboot) || hasType(got, StrategyPodDelete) { + t.Fatalf("expected NO destructive remediators with nil client, got %v", got) + } +} + +func TestRegisterClusterRemediators_EmptyNodeNameRegistersNothing(t *testing.T) { + registry := NewRegistry(10, 100) + client := fake.NewSimpleClientset() + + RegisterClusterRemediators(registry, nil, client, "", "", "") + + got := registry.GetRegisteredTypes() + if hasType(got, StrategyNodeReboot) || hasType(got, StrategyPodDelete) { + t.Fatalf("expected NO destructive remediators with empty node name, got %v", got) + } +} + +func TestRegisterClusterRemediators_DryRunPropagates(t *testing.T) { + registry := NewRegistry(10, 100) + registry.SetDryRun(true) + client := fake.NewSimpleClientset() + + RegisterClusterRemediators(registry, nil, client, testNodeName, "", "") + + rem, err := registry.getOrCreateRemediator(StrategyNodeReboot) + if err != nil { + t.Fatalf("create node-reboot: %v", err) + } + nr, ok := rem.(*NodeRebootRemediator) + if !ok { + t.Fatalf("unexpected type %T", rem) + } + if !nr.config.DryRun { + t.Fatalf("expected node-reboot remediator to inherit dry-run=true from registry") + } +} diff --git a/pkg/remediators/node_reboot.go b/pkg/remediators/node_reboot.go new file mode 100644 index 0000000..ad9c5a4 --- /dev/null +++ b/pkg/remediators/node_reboot.go @@ -0,0 +1,413 @@ +package remediators + +import ( + "context" + "fmt" + "log" + "os/exec" + "strings" + "time" + + corev1 "k8s.io/api/core/v1" + policyv1 "k8s.io/api/policy/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8stypes "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes" + + "github.com/supporttools/node-doctor/pkg/types" +) + +// mirrorPodAnnotationKey is the annotation Kubernetes places on mirror pods +// (the API representation of static pods managed directly by the kubelet). +// Such pods must NEVER be evicted/deleted: the kubelet recreates them and the +// API delete is meaningless and potentially disruptive to control-plane +// components running as static pods. +const mirrorPodAnnotationKey = "kubernetes.io/config.mirror" + +// daemonSetKind is the OwnerReference kind for DaemonSet-managed pods. Such +// pods tolerate node unschedulability by design and are intentionally skipped +// during drain (matching kubectl drain's default behavior). +const daemonSetKind = "DaemonSet" + +// Default tuning for the node-reboot remediator. These are deliberately +// conservative: a node reboot is the single most destructive action this tool +// can take, so every phase is bounded and best-effort. +const ( + // defaultCordonTimeout bounds the node patch (cordon) call. + defaultCordonTimeout = 30 * time.Second + + // defaultDrainTimeout bounds the entire drain/evict phase. After this the + // reboot still proceeds (best-effort drain), but only ever AFTER a + // successful cordon. + defaultDrainTimeout = 2 * time.Minute + + // defaultPodEvictionGracePeriod is the grace period (seconds) handed to the + // Eviction API / pod delete during drain. + defaultPodEvictionGracePeriod int64 = 30 + + // defaultRebootTimeout bounds the reboot command itself. + defaultRebootTimeout = 30 * time.Second +) + +// defaultRebootCommand is the command invoked to reboot the node. It is +// overridable via NodeRebootConfig.RebootCommand for environments that reboot +// differently (e.g. nsenter into the host, or a custom helper). +var defaultRebootCommand = []string{"systemctl", "reboot"} + +// CommandRunner executes an arbitrary host command. It is injected into the +// NodeRebootRemediator so the actual (destructive) reboot can be mocked in +// tests, mirroring the SystemdExecutor/NetworkExecutor pattern used elsewhere +// in this package. +type CommandRunner interface { + // Run executes name with args and returns the combined output. + Run(ctx context.Context, name string, args ...string) (string, error) +} + +// defaultCommandRunner is the production CommandRunner that actually shells +// out. It is only ever invoked on the non-dry-run reboot path. +type defaultCommandRunner struct{} + +// Run executes the command and returns its combined output. +func (e *defaultCommandRunner) Run(ctx context.Context, name string, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, name, args...) + output, err := cmd.CombinedOutput() + return strings.TrimSpace(string(output)), err +} + +// NodeRebootConfig configures the NodeRebootRemediator. +type NodeRebootConfig struct { + // NodeName is the node this remediator operates on (this node). Required. + NodeName string + + // DryRun, when true, makes every destructive step (cordon, evict, reboot) + // log-only. The reboot command is NEVER executed in dry-run. + DryRun bool + + // RebootCommand overrides the default reboot command. When empty the + // package default ("systemctl reboot") is used. + RebootCommand []string + + // SelfPodName / SelfPodNamespace identify node-doctor's OWN pod so the + // drain never evicts the remediator out from under itself. When known + // (typically from the downward-API POD_NAME/POD_NAMESPACE env vars) the + // matching pod is skipped during drain. + SelfPodName string + SelfPodNamespace string + + // CordonTimeout / DrainTimeout / RebootTimeout / PodEvictionGracePeriod + // bound the respective phases. Zero values fall back to the package + // defaults. + CordonTimeout time.Duration + DrainTimeout time.Duration + RebootTimeout time.Duration + PodEvictionGracePeriod int64 +} + +// NodeRebootRemediator drains and reboots the local node. It is DESTRUCTIVE and +// is only ever registered when a real Kubernetes client and node name are +// available (see RegisterClusterRemediators). It honors dry-run by logging +// every intended action without executing it, cordons before draining, drains +// before rebooting, and skips DaemonSet-owned, mirror/static, and node-doctor's +// own pods during drain. +type NodeRebootRemediator struct { + *BaseRemediator + config NodeRebootConfig + client kubernetes.Interface + runner CommandRunner +} + +// NewNodeRebootRemediator constructs a NodeRebootRemediator. The client and a +// non-empty NodeName are required; without them the remediator could not +// cordon/drain and must not be constructed (RegisterClusterRemediators +// enforces this before calling). +func NewNodeRebootRemediator(client kubernetes.Interface, config NodeRebootConfig) (*NodeRebootRemediator, error) { + if client == nil { + return nil, fmt.Errorf("node-reboot remediator requires a non-nil Kubernetes client") + } + if config.NodeName == "" { + return nil, fmt.Errorf("node-reboot remediator requires a non-empty node name") + } + + if len(config.RebootCommand) == 0 { + config.RebootCommand = append([]string{}, defaultRebootCommand...) + } + if config.CordonTimeout <= 0 { + config.CordonTimeout = defaultCordonTimeout + } + if config.DrainTimeout <= 0 { + config.DrainTimeout = defaultDrainTimeout + } + if config.RebootTimeout <= 0 { + config.RebootTimeout = defaultRebootTimeout + } + if config.PodEvictionGracePeriod <= 0 { + config.PodEvictionGracePeriod = defaultPodEvictionGracePeriod + } + + // Node reboot is the highest-impact action: use the destructive cooldown so + // repeated reboots are heavily rate-limited at the remediator level in + // addition to the registry-wide limits. + base, err := NewBaseRemediator(fmt.Sprintf("node-reboot-%s", config.NodeName), CooldownDestructive) + if err != nil { + return nil, fmt.Errorf("failed to create base remediator: %w", err) + } + + r := &NodeRebootRemediator{ + BaseRemediator: base, + config: config, + client: client, + runner: &defaultCommandRunner{}, + } + + if err := base.SetRemediateFunc(r.remediate); err != nil { + return nil, fmt.Errorf("failed to set remediate function: %w", err) + } + + return r, nil +} + +// SetCommandRunner overrides the reboot command runner (used in tests). +func (r *NodeRebootRemediator) SetCommandRunner(runner CommandRunner) { + if runner != nil { + r.runner = runner + } +} + +// remediate cordons the node, drains its pods (best-effort, with exclusions), +// and finally reboots — strictly in that order. In dry-run every step is +// logged but nothing is mutated and the reboot command is NOT invoked. +func (r *NodeRebootRemediator) remediate(ctx context.Context, problem types.Problem) error { + if r.config.DryRun { + log.Printf("[WARN] [node-reboot] DRY-RUN: would cordon, drain, and reboot node %q (problem=%s). No action taken.", + r.config.NodeName, problem.Type) + // Still walk the (read-only-ish) plan in dry-run so operators see what + // WOULD happen, but never mutate anything. + r.logCordonPlan(ctx) + r.logDrainPlan(ctx) + log.Printf("[WARN] [node-reboot] DRY-RUN: would execute reboot command %v on node %q. Skipped.", + r.config.RebootCommand, r.config.NodeName) + return nil + } + + log.Printf("[WARN] [node-reboot] Beginning DESTRUCTIVE node reboot sequence for node %q (problem=%s)", + r.config.NodeName, problem.Type) + + // Phase 1: cordon. Must succeed before we touch pods or reboot. + if err := r.cordon(ctx); err != nil { + return fmt.Errorf("cordon failed, aborting reboot (node not rebooted): %w", err) + } + + // Phase 2: drain (best-effort). Failures are logged and tolerated, but the + // drain only runs AFTER a successful cordon. + r.drain(ctx) + + // Phase 3: reboot. Only reached after cordon (and best-effort drain). + if err := r.reboot(ctx); err != nil { + return fmt.Errorf("reboot command failed on node %q: %w", r.config.NodeName, err) + } + + log.Printf("[WARN] [node-reboot] Reboot command issued on node %q", r.config.NodeName) + return nil +} + +// cordon patches the node spec.unschedulable=true. It is a no-op if the node is +// already unschedulable. +func (r *NodeRebootRemediator) cordon(ctx context.Context) error { + cordonCtx, cancel := context.WithTimeout(ctx, r.config.CordonTimeout) + defer cancel() + + node, err := r.client.CoreV1().Nodes().Get(cordonCtx, r.config.NodeName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to get node %q: %w", r.config.NodeName, err) + } + + if node.Spec.Unschedulable { + log.Printf("[INFO] [node-reboot] Node %q already cordoned (unschedulable), skipping cordon", r.config.NodeName) + return nil + } + + log.Printf("[WARN] [node-reboot] Cordoning node %q (spec.unschedulable=true)", r.config.NodeName) + patch := []byte(`{"spec":{"unschedulable":true}}`) + if _, err := r.client.CoreV1().Nodes().Patch(cordonCtx, r.config.NodeName, k8stypes.StrategicMergePatchType, patch, metav1.PatchOptions{}); err != nil { + return fmt.Errorf("failed to patch node %q unschedulable: %w", r.config.NodeName, err) + } + return nil +} + +// logCordonPlan logs (without mutating) whether a cordon would occur. Used in +// dry-run only. +func (r *NodeRebootRemediator) logCordonPlan(ctx context.Context) { + cordonCtx, cancel := context.WithTimeout(ctx, r.config.CordonTimeout) + defer cancel() + node, err := r.client.CoreV1().Nodes().Get(cordonCtx, r.config.NodeName, metav1.GetOptions{}) + if err != nil { + log.Printf("[WARN] [node-reboot] DRY-RUN: could not read node %q to plan cordon: %v", r.config.NodeName, err) + return + } + if node.Spec.Unschedulable { + log.Printf("[WARN] [node-reboot] DRY-RUN: node %q already cordoned; cordon would be a no-op", r.config.NodeName) + } else { + log.Printf("[WARN] [node-reboot] DRY-RUN: would cordon node %q (spec.unschedulable=true)", r.config.NodeName) + } +} + +// drain evicts eligible pods on the node, skipping DaemonSet-owned, mirror, and +// self pods. It is best-effort: individual failures are logged and the drain +// continues. The whole phase is bounded by DrainTimeout. +func (r *NodeRebootRemediator) drain(ctx context.Context) { + drainCtx, cancel := context.WithTimeout(ctx, r.config.DrainTimeout) + defer cancel() + + pods, err := r.listNodePods(drainCtx) + if err != nil { + log.Printf("[WARN] [node-reboot] Failed to list pods on node %q for drain (continuing best-effort): %v", + r.config.NodeName, err) + return + } + + log.Printf("[WARN] [node-reboot] Draining node %q: %d pod(s) found", r.config.NodeName, len(pods)) + for i := range pods { + pod := &pods[i] + if skip, reason := r.shouldSkipPod(pod); skip { + log.Printf("[INFO] [node-reboot] Skipping pod %s/%s during drain: %s", pod.Namespace, pod.Name, reason) + continue + } + r.evictPod(drainCtx, pod) + } +} + +// logDrainPlan logs (without evicting) which pods would be evicted/skipped. +// Used in dry-run only. +func (r *NodeRebootRemediator) logDrainPlan(ctx context.Context) { + drainCtx, cancel := context.WithTimeout(ctx, r.config.DrainTimeout) + defer cancel() + pods, err := r.listNodePods(drainCtx) + if err != nil { + log.Printf("[WARN] [node-reboot] DRY-RUN: could not list pods on node %q to plan drain: %v", + r.config.NodeName, err) + return + } + for i := range pods { + pod := &pods[i] + if skip, reason := r.shouldSkipPod(pod); skip { + log.Printf("[WARN] [node-reboot] DRY-RUN: would SKIP pod %s/%s (%s)", pod.Namespace, pod.Name, reason) + continue + } + log.Printf("[WARN] [node-reboot] DRY-RUN: would EVICT pod %s/%s", pod.Namespace, pod.Name) + } +} + +// listNodePods returns the pods scheduled on this node. +func (r *NodeRebootRemediator) listNodePods(ctx context.Context) ([]corev1.Pod, error) { + fieldSelector := fmt.Sprintf("spec.nodeName=%s", r.config.NodeName) + list, err := r.client.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ + FieldSelector: fieldSelector, + }) + if err != nil { + return nil, err + } + return list.Items, nil +} + +// shouldSkipPod returns true (with a reason) if the pod must NOT be evicted +// during drain: mirror/static pods, DaemonSet-owned pods, and node-doctor's own +// pod are always preserved. +func (r *NodeRebootRemediator) shouldSkipPod(pod *corev1.Pod) (bool, string) { + if isMirrorPod(pod) { + return true, "mirror/static pod" + } + if isDaemonSetPod(pod) { + return true, "DaemonSet-owned pod" + } + if r.isSelfPod(pod) { + return true, "node-doctor's own pod" + } + return false, "" +} + +// isSelfPod reports whether pod is node-doctor's own pod (so drain never evicts +// the remediator itself). +func (r *NodeRebootRemediator) isSelfPod(pod *corev1.Pod) bool { + if r.config.SelfPodName == "" { + return false + } + return pod.Name == r.config.SelfPodName && pod.Namespace == r.config.SelfPodNamespace +} + +// evictPod evicts a single pod via the Eviction API, falling back to a graceful +// delete if eviction is unsupported. Failures are logged and tolerated. +func (r *NodeRebootRemediator) evictPod(ctx context.Context, pod *corev1.Pod) { + grace := r.config.PodEvictionGracePeriod + log.Printf("[WARN] [node-reboot] Evicting pod %s/%s (grace=%ds)", pod.Namespace, pod.Name, grace) + + eviction := &policyv1.Eviction{ + ObjectMeta: metav1.ObjectMeta{ + Name: pod.Name, + Namespace: pod.Namespace, + }, + DeleteOptions: &metav1.DeleteOptions{ + GracePeriodSeconds: &grace, + }, + } + + err := r.client.PolicyV1().Evictions(pod.Namespace).Evict(ctx, eviction) + if err == nil { + return + } + + // Eviction not supported (older clusters / fakes) -> fall back to delete. + if apierrors.IsNotFound(err) { + log.Printf("[INFO] [node-reboot] Pod %s/%s already gone during eviction", pod.Namespace, pod.Name) + return + } + log.Printf("[WARN] [node-reboot] Eviction of pod %s/%s failed (%v); falling back to graceful delete", + pod.Namespace, pod.Name, err) + + if delErr := r.client.CoreV1().Pods(pod.Namespace).Delete(ctx, pod.Name, metav1.DeleteOptions{ + GracePeriodSeconds: &grace, + }); delErr != nil && !apierrors.IsNotFound(delErr) { + log.Printf("[WARN] [node-reboot] Graceful delete of pod %s/%s also failed (continuing): %v", + pod.Namespace, pod.Name, delErr) + } +} + +// reboot invokes the configured reboot command via the injected runner. It is +// only ever called on the non-dry-run path, after cordon+drain. +func (r *NodeRebootRemediator) reboot(ctx context.Context) error { + rebootCtx, cancel := context.WithTimeout(ctx, r.config.RebootTimeout) + defer cancel() + + cmd := r.config.RebootCommand + log.Printf("[WARN] [node-reboot] Executing reboot command %v on node %q", cmd, r.config.NodeName) + output, err := r.runner.Run(rebootCtx, cmd[0], cmd[1:]...) + if err != nil { + return fmt.Errorf("reboot command %v failed: %w (output: %s)", cmd, err, output) + } + return nil +} + +// CanRemediate returns true for node-reboot problems. It defers to the base +// remediator's cooldown/attempt gating. +func (r *NodeRebootRemediator) CanRemediate(problem types.Problem) bool { + return r.BaseRemediator.CanRemediate(problem) +} + +// isMirrorPod reports whether a pod is a mirror (static) pod. +func isMirrorPod(pod *corev1.Pod) bool { + if pod.Annotations == nil { + return false + } + _, ok := pod.Annotations[mirrorPodAnnotationKey] + return ok +} + +// isDaemonSetPod reports whether a pod is owned by a DaemonSet. +func isDaemonSetPod(pod *corev1.Pod) bool { + for _, ref := range pod.OwnerReferences { + if ref.Kind == daemonSetKind { + return true + } + } + return false +} diff --git a/pkg/remediators/node_reboot_test.go b/pkg/remediators/node_reboot_test.go new file mode 100644 index 0000000..3b82201 --- /dev/null +++ b/pkg/remediators/node_reboot_test.go @@ -0,0 +1,254 @@ +package remediators + +import ( + "context" + "sync" + "testing" + + corev1 "k8s.io/api/core/v1" + policyv1 "k8s.io/api/policy/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + ktesting "k8s.io/client-go/testing" + + "github.com/supporttools/node-doctor/pkg/types" +) + +const testNodeName = "test-node" + +// recordingRunner records reboot invocations and the order in which they +// occurred relative to other actions. +type recordingRunner struct { + mu sync.Mutex + calls int + args [][]string + onCall func() +} + +func (r *recordingRunner) Run(_ context.Context, name string, args ...string) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + r.calls++ + r.args = append(r.args, append([]string{name}, args...)) + if r.onCall != nil { + r.onCall() + } + return "ok", nil +} + +func (r *recordingRunner) callCount() int { + r.mu.Lock() + defer r.mu.Unlock() + return r.calls +} + +// makePod builds a pod scheduled on testNodeName with the given properties. +func makePod(name, namespace string, opts ...func(*corev1.Pod)) *corev1.Pod { + p := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Spec: corev1.PodSpec{ + NodeName: testNodeName, + }, + } + for _, o := range opts { + o(p) + } + return p +} + +func asDaemonSet(p *corev1.Pod) { + p.OwnerReferences = []metav1.OwnerReference{{Kind: "DaemonSet", Name: "ds"}} +} + +func asMirror(p *corev1.Pod) { + if p.Annotations == nil { + p.Annotations = map[string]string{} + } + p.Annotations[mirrorPodAnnotationKey] = "abc123" +} + +func newTestNode(unschedulable bool) *corev1.Node { + return &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: testNodeName}, + Spec: corev1.NodeSpec{Unschedulable: unschedulable}, + } +} + +func TestNodeReboot_DryRun_DoesNotExecute(t *testing.T) { + node := newTestNode(false) + pod := makePod("app", "default") + client := fake.NewSimpleClientset(node, pod) + runner := &recordingRunner{} + + r, err := NewNodeRebootRemediator(client, NodeRebootConfig{ + NodeName: testNodeName, + DryRun: true, + }) + if err != nil { + t.Fatalf("construct: %v", err) + } + r.SetCommandRunner(runner) + + if err := r.Remediate(context.Background(), types.Problem{Type: StrategyNodeReboot}); err != nil { + t.Fatalf("dry-run remediate: %v", err) + } + + // Reboot runner must NOT have been called. + if runner.callCount() != 0 { + t.Fatalf("dry-run executed reboot runner %d times, want 0", runner.callCount()) + } + + // Node must NOT have been cordoned in dry-run. + got, err := client.CoreV1().Nodes().Get(context.Background(), testNodeName, metav1.GetOptions{}) + if err != nil { + t.Fatalf("get node: %v", err) + } + if got.Spec.Unschedulable { + t.Fatalf("dry-run cordoned the node (Unschedulable=true), want unchanged") + } + + // Pod must still be present. + if _, err := client.CoreV1().Pods("default").Get(context.Background(), "app", metav1.GetOptions{}); err != nil { + t.Fatalf("dry-run deleted/evicted the pod: %v", err) + } +} + +func TestNodeReboot_NonDryRun_CordonsDrainsAndReboots(t *testing.T) { + node := newTestNode(false) + appPod := makePod("app", "default") + dsPod := makePod("ds-pod", "kube-system", asDaemonSet) + mirrorPod := makePod("mirror-pod", "kube-system", asMirror) + selfPod := makePod("node-doctor-xyz", "monitoring") + + client := fake.NewSimpleClientset(node, appPod, dsPod, mirrorPod, selfPod) + + // Track ordering: record whether any pod eviction/deletion happened, and + // whether the reboot occurred before any such drain action. + var ( + mu sync.Mutex + rebooted bool + drainObserved bool + rebootBeforeDrain bool + ) + + // The fake clientset does not delete the pod on an Eviction create, so wire a + // reactor that turns an eviction create on the pods/eviction subresource into + // an actual pod delete (mimicking the real API). Also record drain ordering. + podsGVR := schema.GroupVersionResource{Version: "v1", Resource: "pods"} + client.PrependReactor("create", "pods/eviction", func(action ktesting.Action) (bool, runtime.Object, error) { + mu.Lock() + drainObserved = true + if rebooted { + rebootBeforeDrain = true + } + mu.Unlock() + + evictAction := action.(ktesting.CreateActionImpl) + evicted := evictAction.GetObject().(*policyv1.Eviction) + _ = client.Tracker().Delete(podsGVR, evictAction.GetNamespace(), evicted.Name) + return true, nil, nil + }) + + runner := &recordingRunner{onCall: func() { + mu.Lock() + rebooted = true + mu.Unlock() + }} + + r, err := NewNodeRebootRemediator(client, NodeRebootConfig{ + NodeName: testNodeName, + DryRun: false, + SelfPodName: "node-doctor-xyz", + SelfPodNamespace: "monitoring", + }) + if err != nil { + t.Fatalf("construct: %v", err) + } + r.SetCommandRunner(runner) + + if err := r.Remediate(context.Background(), types.Problem{Type: StrategyNodeReboot}); err != nil { + t.Fatalf("remediate: %v", err) + } + + // Node cordoned. + got, err := client.CoreV1().Nodes().Get(context.Background(), testNodeName, metav1.GetOptions{}) + if err != nil { + t.Fatalf("get node: %v", err) + } + if !got.Spec.Unschedulable { + t.Fatalf("node not cordoned (Unschedulable=false), want true") + } + + // Reboot runner called exactly once. + if runner.callCount() != 1 { + t.Fatalf("reboot runner called %d times, want 1", runner.callCount()) + } + + // Ordering: drain (eviction) must have occurred, and the reboot must NOT have + // happened before the drain. + mu.Lock() + gotDrain := drainObserved + gotRebootBeforeDrain := rebootBeforeDrain + mu.Unlock() + if !gotDrain { + t.Fatalf("expected drain (eviction) to occur before reboot, but no eviction observed") + } + if gotRebootBeforeDrain { + t.Fatalf("reboot occurred BEFORE drain; ordering violated") + } + + // app pod must be gone (evicted or deleted). + if _, err := client.CoreV1().Pods("default").Get(context.Background(), "app", metav1.GetOptions{}); err == nil { + t.Fatalf("app pod still present, expected eviction/deletion") + } + + // DaemonSet pod must remain. + if _, err := client.CoreV1().Pods("kube-system").Get(context.Background(), "ds-pod", metav1.GetOptions{}); err != nil { + t.Fatalf("DaemonSet pod was removed, expected skip: %v", err) + } + // Mirror pod must remain. + if _, err := client.CoreV1().Pods("kube-system").Get(context.Background(), "mirror-pod", metav1.GetOptions{}); err != nil { + t.Fatalf("mirror pod was removed, expected skip: %v", err) + } + // Self pod must remain. + if _, err := client.CoreV1().Pods("monitoring").Get(context.Background(), "node-doctor-xyz", metav1.GetOptions{}); err != nil { + t.Fatalf("self pod was removed, expected skip: %v", err) + } +} + +func TestNodeReboot_CordonFailureAbortsReboot(t *testing.T) { + // No node object -> Get fails -> cordon fails -> reboot must NOT run. + client := fake.NewSimpleClientset() + runner := &recordingRunner{} + + r, err := NewNodeRebootRemediator(client, NodeRebootConfig{ + NodeName: testNodeName, + DryRun: false, + }) + if err != nil { + t.Fatalf("construct: %v", err) + } + r.SetCommandRunner(runner) + + if err := r.Remediate(context.Background(), types.Problem{Type: StrategyNodeReboot}); err == nil { + t.Fatalf("expected error when cordon fails, got nil") + } + if runner.callCount() != 0 { + t.Fatalf("reboot ran %d times despite cordon failure, want 0", runner.callCount()) + } +} + +func TestNodeReboot_RequiresClientAndNodeName(t *testing.T) { + if _, err := NewNodeRebootRemediator(nil, NodeRebootConfig{NodeName: testNodeName}); err == nil { + t.Fatalf("expected error with nil client") + } + client := fake.NewSimpleClientset() + if _, err := NewNodeRebootRemediator(client, NodeRebootConfig{NodeName: ""}); err == nil { + t.Fatalf("expected error with empty node name") + } +} diff --git a/pkg/remediators/pod_delete.go b/pkg/remediators/pod_delete.go new file mode 100644 index 0000000..98a8b5d --- /dev/null +++ b/pkg/remediators/pod_delete.go @@ -0,0 +1,177 @@ +package remediators + +import ( + "context" + "fmt" + "log" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + + "github.com/supporttools/node-doctor/pkg/types" +) + +// Problem.Metadata keys used by the pod-delete strategy to identify the target +// pod. The detector populates these from the triggering monitor's remediation +// config; without BOTH the remediator refuses to act (it never guesses or +// deletes broadly). +const ( + // metadataKeyNamespace carries the target pod's namespace. + metadataKeyNamespace = "namespace" + + // metadataKeyPod carries the target pod's name. + metadataKeyPod = "pod" +) + +const ( + // defaultPodDeleteTimeout bounds the get+delete calls. + defaultPodDeleteTimeout = 30 * time.Second + + // defaultPodDeleteGracePeriod is the grace period (seconds) for the delete. + defaultPodDeleteGracePeriod int64 = 30 +) + +// PodDeleteConfig configures the PodDeleteRemediator. +type PodDeleteConfig struct { + // DryRun, when true, makes the delete log-only (the pod is NOT deleted). + DryRun bool + + // SelfPodName / SelfPodNamespace identify node-doctor's OWN pod so it can + // never delete itself. + SelfPodName string + SelfPodNamespace string + + // GracePeriodSeconds is the delete grace period. Zero falls back to the + // package default. + GracePeriodSeconds int64 + + // Timeout bounds the get+delete API calls. Zero falls back to the default. + Timeout time.Duration +} + +// PodDeleteRemediator deletes a single target pod identified via Problem +// metadata. It is DESTRUCTIVE and only registered when a real Kubernetes client +// is available (see RegisterClusterRemediators). It honors dry-run, refuses to +// delete mirror/static pods, and refuses to delete node-doctor's own pod. +type PodDeleteRemediator struct { + *BaseRemediator + config PodDeleteConfig + client kubernetes.Interface +} + +// NewPodDeleteRemediator constructs a PodDeleteRemediator. A non-nil client is +// required. +func NewPodDeleteRemediator(client kubernetes.Interface, config PodDeleteConfig) (*PodDeleteRemediator, error) { + if client == nil { + return nil, fmt.Errorf("pod-delete remediator requires a non-nil Kubernetes client") + } + + if config.GracePeriodSeconds <= 0 { + config.GracePeriodSeconds = defaultPodDeleteGracePeriod + } + if config.Timeout <= 0 { + config.Timeout = defaultPodDeleteTimeout + } + + base, err := NewBaseRemediator("pod-delete", CooldownMedium) + if err != nil { + return nil, fmt.Errorf("failed to create base remediator: %w", err) + } + + r := &PodDeleteRemediator{ + BaseRemediator: base, + config: config, + client: client, + } + + if err := base.SetRemediateFunc(r.remediate); err != nil { + return nil, fmt.Errorf("failed to set remediate function: %w", err) + } + + return r, nil +} + +// remediate deletes the pod named in Problem.Metadata["namespace"]/["pod"]. +// Missing metadata is a hard error (no broad/guessed deletion). Mirror/static +// pods and node-doctor's own pod are refused. In dry-run the delete is logged +// only. +func (r *PodDeleteRemediator) remediate(ctx context.Context, problem types.Problem) error { + namespace, podName, err := r.resolveTarget(problem) + if err != nil { + return err + } + + // Refuse to delete our own pod regardless of dry-run. + if r.isSelfPod(namespace, podName) { + return fmt.Errorf("refusing to delete node-doctor's own pod %s/%s", namespace, podName) + } + + deleteCtx, cancel := context.WithTimeout(ctx, r.config.Timeout) + defer cancel() + + // Read the pod first so we can refuse mirror/static pods. + pod, err := r.client.CoreV1().Pods(namespace).Get(deleteCtx, podName, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + log.Printf("[INFO] [pod-delete] Target pod %s/%s not found; nothing to delete", namespace, podName) + return nil + } + return fmt.Errorf("failed to get target pod %s/%s: %w", namespace, podName, err) + } + + if isMirrorPod(pod) { + return fmt.Errorf("refusing to delete mirror/static pod %s/%s", namespace, podName) + } + + if r.config.DryRun { + log.Printf("[WARN] [pod-delete] DRY-RUN: would delete pod %s/%s (grace=%ds). No action taken.", + namespace, podName, r.config.GracePeriodSeconds) + return nil + } + + grace := r.config.GracePeriodSeconds + log.Printf("[WARN] [pod-delete] Deleting pod %s/%s (grace=%ds)", namespace, podName, grace) + if err := r.client.CoreV1().Pods(namespace).Delete(deleteCtx, podName, metav1.DeleteOptions{ + GracePeriodSeconds: &grace, + }); err != nil { + if apierrors.IsNotFound(err) { + log.Printf("[INFO] [pod-delete] Pod %s/%s already gone", namespace, podName) + return nil + } + return fmt.Errorf("failed to delete pod %s/%s: %w", namespace, podName, err) + } + + log.Printf("[WARN] [pod-delete] Deleted pod %s/%s", namespace, podName) + return nil +} + +// resolveTarget extracts the target namespace and pod name from problem +// metadata. Both are required. +func (r *PodDeleteRemediator) resolveTarget(problem types.Problem) (namespace, podName string, err error) { + if problem.Metadata == nil { + return "", "", fmt.Errorf("pod-delete requires problem metadata %q and %q but none was provided", metadataKeyNamespace, metadataKeyPod) + } + namespace = problem.Metadata[metadataKeyNamespace] + podName = problem.Metadata[metadataKeyPod] + if namespace == "" || podName == "" { + return "", "", fmt.Errorf("pod-delete requires non-empty problem metadata %q and %q (got namespace=%q pod=%q)", + metadataKeyNamespace, metadataKeyPod, namespace, podName) + } + return namespace, podName, nil +} + +// isSelfPod reports whether the target is node-doctor's own pod. +func (r *PodDeleteRemediator) isSelfPod(namespace, podName string) bool { + if r.config.SelfPodName == "" { + return false + } + return podName == r.config.SelfPodName && namespace == r.config.SelfPodNamespace +} + +// CanRemediate returns true for pod-delete problems, deferring to the base +// remediator's cooldown/attempt gating. +func (r *PodDeleteRemediator) CanRemediate(problem types.Problem) bool { + return r.BaseRemediator.CanRemediate(problem) +} diff --git a/pkg/remediators/pod_delete_test.go b/pkg/remediators/pod_delete_test.go new file mode 100644 index 0000000..406ccc8 --- /dev/null +++ b/pkg/remediators/pod_delete_test.go @@ -0,0 +1,121 @@ +package remediators + +import ( + "context" + "testing" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" + + "github.com/supporttools/node-doctor/pkg/types" +) + +func podDeleteProblem(namespace, pod string) types.Problem { + return types.Problem{ + Type: StrategyPodDelete, + Metadata: map[string]string{ + metadataKeyNamespace: namespace, + metadataKeyPod: pod, + }, + } +} + +func TestPodDelete_DryRun_DoesNotDelete(t *testing.T) { + pod := makePod("victim", "default") + client := fake.NewSimpleClientset(pod) + + r, err := NewPodDeleteRemediator(client, PodDeleteConfig{DryRun: true}) + if err != nil { + t.Fatalf("construct: %v", err) + } + + if err := r.Remediate(context.Background(), podDeleteProblem("default", "victim")); err != nil { + t.Fatalf("dry-run remediate: %v", err) + } + + if _, err := client.CoreV1().Pods("default").Get(context.Background(), "victim", metav1.GetOptions{}); err != nil { + t.Fatalf("dry-run deleted the pod: %v", err) + } +} + +func TestPodDelete_NonDryRun_DeletesTarget(t *testing.T) { + pod := makePod("victim", "default") + client := fake.NewSimpleClientset(pod) + + r, err := NewPodDeleteRemediator(client, PodDeleteConfig{DryRun: false}) + if err != nil { + t.Fatalf("construct: %v", err) + } + + if err := r.Remediate(context.Background(), podDeleteProblem("default", "victim")); err != nil { + t.Fatalf("remediate: %v", err) + } + + if _, err := client.CoreV1().Pods("default").Get(context.Background(), "victim", metav1.GetOptions{}); err == nil { + t.Fatalf("pod still present, expected deletion") + } +} + +func TestPodDelete_MissingMetadata_Errors(t *testing.T) { + client := fake.NewSimpleClientset() + r, err := NewPodDeleteRemediator(client, PodDeleteConfig{}) + if err != nil { + t.Fatalf("construct: %v", err) + } + + cases := []types.Problem{ + {Type: StrategyPodDelete}, // nil metadata + {Type: StrategyPodDelete, Metadata: map[string]string{metadataKeyNamespace: "default"}}, // missing pod + {Type: StrategyPodDelete, Metadata: map[string]string{metadataKeyPod: "victim"}}, // missing namespace + {Type: StrategyPodDelete, Metadata: map[string]string{metadataKeyNamespace: "", metadataKeyPod: ""}}, // empty + } + for i, p := range cases { + if err := r.Remediate(context.Background(), p); err == nil { + t.Fatalf("case %d: expected error for missing/empty metadata, got nil", i) + } + } +} + +func TestPodDelete_RefusesMirrorPod(t *testing.T) { + pod := makePod("mirror", "kube-system", asMirror) + client := fake.NewSimpleClientset(pod) + + r, err := NewPodDeleteRemediator(client, PodDeleteConfig{}) + if err != nil { + t.Fatalf("construct: %v", err) + } + + if err := r.Remediate(context.Background(), podDeleteProblem("kube-system", "mirror")); err == nil { + t.Fatalf("expected refusal to delete mirror pod, got nil") + } + // Pod must remain. + if _, err := client.CoreV1().Pods("kube-system").Get(context.Background(), "mirror", metav1.GetOptions{}); err != nil { + t.Fatalf("mirror pod was removed: %v", err) + } +} + +func TestPodDelete_RefusesSelfPod(t *testing.T) { + pod := makePod("node-doctor-xyz", "monitoring") + client := fake.NewSimpleClientset(pod) + + r, err := NewPodDeleteRemediator(client, PodDeleteConfig{ + SelfPodName: "node-doctor-xyz", + SelfPodNamespace: "monitoring", + }) + if err != nil { + t.Fatalf("construct: %v", err) + } + + if err := r.Remediate(context.Background(), podDeleteProblem("monitoring", "node-doctor-xyz")); err == nil { + t.Fatalf("expected refusal to delete self pod, got nil") + } + if _, err := client.CoreV1().Pods("monitoring").Get(context.Background(), "node-doctor-xyz", metav1.GetOptions{}); err != nil { + t.Fatalf("self pod was removed: %v", err) + } +} + +func TestPodDelete_RequiresClient(t *testing.T) { + if _, err := NewPodDeleteRemediator(nil, PodDeleteConfig{}); err == nil { + t.Fatalf("expected error with nil client") + } +} From a8ce5ca099662bac5ea412414dcb05d15f690165 Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 08:35:37 -0500 Subject: [PATCH 10/11] feat(remediators): register network remediation strategies, close #17222 (Task #19263 phase 3) Add flush-dns/restart-interface/reset-routing/flush-ipv6-route to validRemediationStrategies (+ restart-interface requires an interface) and register all four in RegisterBuiltinRemediators as NetworkRemediator closures (per-op). flush-ipv6-route is now registered + reachable, closing #17222. NetworkRemediator resolves the target interface from Problem.Metadata["interface"] (new MonitorRemediationConfig.Interface field threaded by the detector) with config fallback; AllowMetadataInterface lets the registered restart-interface singleton construct without a fixed NIC. Tests: 6 safe strategies registered, dry-run + executor dispatch of flush-ipv6-route, interface-from-metadata, missing-interface error. --- cmd/node-doctor/main.go | 13 ++- pkg/detector/detector.go | 5 + pkg/detector/remediation_test.go | 36 +++++++ pkg/remediators/builtin.go | 49 +++++++++ pkg/remediators/builtin_test.go | 173 ++++++++++++++++++++++++++++--- pkg/remediators/network.go | 89 ++++++++++++---- pkg/remediators/network_test.go | 2 +- pkg/remediators/registry.go | 6 ++ pkg/types/config.go | 22 +++- pkg/types/config_test.go | 52 ++++++++++ 10 files changed, 406 insertions(+), 41 deletions(-) diff --git a/cmd/node-doctor/main.go b/cmd/node-doctor/main.go index d3f3a44..ec96231 100644 --- a/cmd/node-doctor/main.go +++ b/cmd/node-doctor/main.go @@ -228,11 +228,14 @@ func main() { // Register the built-in remediator strategies so the detector's dispatch // (which addresses a remediator by its strategy type) can find one. - // TaskForge #19263 Phase 1 registers ONLY the two SAFE strategies - // (systemd-restart, custom-script); the destructive node-reboot/pod-delete - // strategies are deferred to Phase 2 and remain unregistered (a config - // naming them fails dispatch, which is the desired fail-safe). SetDryRun - // is applied above so the closures pick up the correct dry-run state. + // TaskForge #19263 registers the SAFE strategies: systemd-restart, + // custom-script, and the four network operations (flush-dns, + // restart-interface, reset-routing, flush-ipv6-route — Phase 3, the last + // closing #17222). The destructive node-reboot/pod-delete strategies are + // NOT registered here; they are registered separately by + // RegisterClusterRemediators only when a cluster client is available (a + // config naming them otherwise fails dispatch, the desired fail-safe). + // SetDryRun is applied above so the closures pick up the correct dry-run state. remediators.RegisterBuiltinRemediators(remediatorRegistry, config) log.Printf("[INFO] Registered built-in remediators: %v", remediatorRegistry.GetRegisteredTypes()) diff --git a/pkg/detector/detector.go b/pkg/detector/detector.go index e1d20a3..b73eaed 100644 --- a/pkg/detector/detector.go +++ b/pkg/detector/detector.go @@ -591,6 +591,7 @@ func (pd *ProblemDetector) evaluateRemediation(status *types.Status) { // - "service" : strat.Service (systemd-restart target service) // - "scriptPath" : strat.ScriptPath (custom-script script path) // - "args" : JSON-encoded strat.Args (custom-script arguments) +// - "interface" : strat.Interface (restart-interface target interface) // // Only non-empty params are added so a strategy that does not use a given param // does not pollute the metadata. @@ -606,6 +607,9 @@ func buildProblemMetadata(source string, cond types.Condition, strat types.Monit if strat.ScriptPath != "" { meta["scriptPath"] = strat.ScriptPath } + if strat.Interface != "" { + meta["interface"] = strat.Interface + } if len(strat.Args) > 0 { // JSON-encode args so values containing commas/spaces survive intact. if encoded, err := json.Marshal(strat.Args); err == nil { @@ -648,6 +652,7 @@ func buildStrategyList(remCfg *types.MonitorRemediationConfig) []types.MonitorRe Service: remCfg.Service, ScriptPath: remCfg.ScriptPath, Args: remCfg.Args, + Interface: remCfg.Interface, }} } diff --git a/pkg/detector/remediation_test.go b/pkg/detector/remediation_test.go index 82281a8..0780ec4 100644 --- a/pkg/detector/remediation_test.go +++ b/pkg/detector/remediation_test.go @@ -498,6 +498,42 @@ func TestEvaluateRemediation_MultiStrategyThreadsPerStrategyParams(t *testing.T) } } +// TestEvaluateRemediation_ThreadsInterfaceMetadata verifies that a +// restart-interface strategy's Interface field is threaded into the dispatched +// Problem.Metadata["interface"] so the NetworkRemediator can resolve it at +// Remediate time (TaskForge #19263 Phase 3). +func TestEvaluateRemediation_ThreadsInterfaceMetadata(t *testing.T) { + exec := NewMockRemediationExecutor() + + monCfg := types.MonitorConfig{ + Name: "iface-monitor", + Type: "test", + Enabled: true, + Interval: 30 * time.Second, + Timeout: 10 * time.Second, + Remediation: &types.MonitorRemediationConfig{ + Enabled: true, + Strategy: "restart-interface", + Interface: "eth0", + }, + } + pd, mon := buildDetectorWithRemediation(t, monCfg, exec) + _ = pd + + mon.AddStatusUpdate(unhealthyStatus("iface-monitor", "InterfaceHealthy")) + if !pollUntil(t, time.Second, func() bool { return exec.CallCount() == 1 }) { + t.Fatalf("expected 1 executor call within 1s, got %d", exec.CallCount()) + } + + call := exec.Calls()[0] + if call.RemediatorType != "restart-interface" { + t.Errorf("strategy = %q, want restart-interface", call.RemediatorType) + } + if got := call.Problem.Metadata["interface"]; got != "eth0" { + t.Errorf("Problem.Metadata[interface] = %q, want eth0", got) + } +} + // TestEvaluateRemediation_MultiStrategyFirstSuccessWins verifies that when the // first strategy succeeds, subsequent strategies are not attempted. func TestEvaluateRemediation_MultiStrategyFirstSuccessWins(t *testing.T) { diff --git a/pkg/remediators/builtin.go b/pkg/remediators/builtin.go index aadafa4..81fa0a8 100644 --- a/pkg/remediators/builtin.go +++ b/pkg/remediators/builtin.go @@ -31,6 +31,19 @@ const ( // destructive actions. StrategyNodeReboot = "node-reboot" StrategyPodDelete = "pod-delete" + + // Network remediation strategies (TaskForge #19263 Phase 3). Each maps + // directly to a NetworkRemediator operation (see NetworkOperation consts in + // network.go). flush-dns / reset-routing / flush-ipv6-route are + // non-destructive cache/route flushes; restart-interface briefly bounces an + // interface (its target comes from Problem.Metadata["interface"]). All four + // are SAFE-enough to live with the built-ins (NOT the destructive cluster + // set). Registering StrategyFlushIPv6Route makes flush-ipv6-route reachable, + // closing Task #17222. + StrategyFlushDNS = string(NetworkFlushDNS) + StrategyRestartInterface = string(NetworkRestartInterface) + StrategyResetRouting = string(NetworkResetRouting) + StrategyFlushIPv6Route = string(NetworkFlushIPv6Route) ) // RegisterBuiltinRemediators registers the SAFE built-in remediator strategies @@ -101,6 +114,42 @@ func RegisterBuiltinRemediators(registry *RemediatorRegistry, cfg *types.NodeDoc }, Description: "Runs the remediation script named in the triggering monitor's remediation config (Problem.Metadata[\"scriptPath\"]/[\"args\"]).", }) + + // Network remediation strategies (TaskForge #19263 Phase 3). Each is keyed by + // its operation string and builds a NewNetworkRemediator for that operation, + // honoring the registry dry-run state. flush-dns / reset-routing / + // flush-ipv6-route take no per-call parameters; restart-interface resolves + // its target interface per-call from Problem.Metadata["interface"] + // (AllowMetadataInterface lets it construct with an empty InterfaceName). + // Registering flush-ipv6-route makes that operation reachable -> closes #17222. + networkOps := []struct { + strategy string + operation NetworkOperation + description string + }{ + {StrategyFlushDNS, NetworkFlushDNS, "Flushes the resolver DNS cache (resolvectl/systemd-resolve; clears A and AAAA)."}, + {StrategyRestartInterface, NetworkRestartInterface, "Bounces the network interface named in Problem.Metadata[\"interface\"] (down then up)."}, + {StrategyResetRouting, NetworkResetRouting, "Flushes the IPv4 routing cache (ip route flush cache)."}, + {StrategyFlushIPv6Route, NetworkFlushIPv6Route, "Flushes the IPv6 route cache (ip -6 route flush cache)."}, + } + for _, op := range networkOps { + op := op // capture per-iteration for the closure + registry.Register(RemediatorInfo{ + Type: op.strategy, + Factory: func() (types.Remediator, error) { + return NewNetworkRemediator(NetworkConfig{ + Operation: op.operation, + // InterfaceName intentionally empty: restart-interface resolves + // it per-call from Problem.Metadata["interface"]. + // AllowMetadataInterface permits construction without a fixed + // interface for the metadata-driven singleton. + AllowMetadataInterface: op.operation == NetworkRestartInterface, + DryRun: dryRun, + }) + }, + Description: op.description, + }) + } } // RegisterClusterRemediators registers the DESTRUCTIVE cluster-scoped remediator diff --git a/pkg/remediators/builtin_test.go b/pkg/remediators/builtin_test.go index 200e07b..06e7dbb 100644 --- a/pkg/remediators/builtin_test.go +++ b/pkg/remediators/builtin_test.go @@ -8,31 +8,46 @@ import ( "github.com/supporttools/node-doctor/pkg/types" ) -// TestRegisterBuiltinRemediators_RegistersSafeStrategiesOnly verifies that -// Phase 1 registers ONLY systemd-restart and custom-script, and explicitly does -// NOT register the destructive node-reboot / pod-delete strategies (Phase 2). +// TestRegisterBuiltinRemediators_RegistersSafeStrategiesOnly verifies that the +// built-in registration registers the safe strategies (systemd-restart, +// custom-script, and the four network operations) and explicitly does NOT +// register the destructive node-reboot / pod-delete strategies (those are +// Phase 2, registered only by RegisterClusterRemediators). func TestRegisterBuiltinRemediators_RegistersSafeStrategiesOnly(t *testing.T) { registry := NewRegistry(10, 100) RegisterBuiltinRemediators(registry, &types.NodeDoctorConfig{}) - if !registry.IsRegistered(StrategySystemdRestart) { - t.Errorf("expected %q to be registered", StrategySystemdRestart) + // All safe built-in strategies must be registered. + safe := []string{ + StrategySystemdRestart, + StrategyCustomScript, + StrategyFlushDNS, + StrategyRestartInterface, + StrategyResetRouting, + StrategyFlushIPv6Route, } - if !registry.IsRegistered(StrategyCustomScript) { - t.Errorf("expected %q to be registered", StrategyCustomScript) + for _, s := range safe { + if !registry.IsRegistered(s) { + t.Errorf("expected %q to be registered", s) + } + } + + // flush-ipv6-route reachable closes #17222. + if !registry.IsRegistered("flush-ipv6-route") { + t.Errorf("flush-ipv6-route must be registered (closes #17222)") } - // Destructive strategies must NOT be registered in Phase 1. + // Destructive strategies must NOT be registered by the built-ins. if registry.IsRegistered(StrategyNodeReboot) { - t.Errorf("%q must NOT be registered in Phase 1 (destructive, deferred to Phase 2)", StrategyNodeReboot) + t.Errorf("%q must NOT be registered by built-ins (destructive, cluster-only)", StrategyNodeReboot) } if registry.IsRegistered(StrategyPodDelete) { - t.Errorf("%q must NOT be registered in Phase 1 (destructive, deferred to Phase 2)", StrategyPodDelete) + t.Errorf("%q must NOT be registered by built-ins (destructive, cluster-only)", StrategyPodDelete) } got := registry.GetRegisteredTypes() - if len(got) != 2 { - t.Errorf("expected exactly 2 registered types, got %d: %v", len(got), got) + if len(got) != 6 { + t.Errorf("expected exactly 6 registered types, got %d: %v", len(got), got) } } @@ -224,3 +239,137 @@ func TestBuiltinCustomScript_InvalidArgsJSONRejected(t *testing.T) { t.Errorf("script must NOT be executed when args are invalid, executed: %v", got) } } + +// TestBuiltinFlushIPv6Route_DryRunDispatch verifies that flush-ipv6-route is +// registered and dispatchable via the registry in dry-run, reaching the +// NetworkRemediator path without executing a real command. This is the wiring +// that makes flush-ipv6-route reachable (closes Task #17222). +func TestBuiltinFlushIPv6Route_DryRunDispatch(t *testing.T) { + registry := NewRegistry(10, 100) + registry.SetDryRun(true) + RegisterBuiltinRemediators(registry, &types.NodeDoctorConfig{}) + + if !registry.IsRegistered(StrategyFlushIPv6Route) { + t.Fatalf("flush-ipv6-route must be registered after RegisterBuiltinRemediators") + } + + problem := types.Problem{ + Type: StrategyFlushIPv6Route, + Resource: "ipv6-routing-table", + Severity: types.ProblemCritical, + } + if err := registry.Remediate(context.Background(), StrategyFlushIPv6Route, problem); err != nil { + t.Fatalf("Remediate(flush-ipv6-route) dry-run failed: %v", err) + } +} + +// TestBuiltinFlushIPv6Route_DryRunNoRealCommand verifies the remediator built by +// the flush-ipv6-route factory honors dry-run: with an injected executor and +// DryRun=true, the ipv6 flush op runs through the remediator but NO real command +// is executed. +func TestBuiltinFlushIPv6Route_DryRunNoRealCommand(t *testing.T) { + r, err := NewNetworkRemediator(NetworkConfig{ + Operation: NetworkFlushIPv6Route, + DryRun: true, + }) + if err != nil { + t.Fatalf("NewNetworkRemediator(flush-ipv6-route): %v", err) + } + mock := &mockNetworkExecutor{} + r.SetNetworkExecutor(mock) + + problem := types.Problem{Type: StrategyFlushIPv6Route} + if err := r.Remediate(context.Background(), problem); err != nil { + t.Fatalf("Remediate(flush-ipv6-route) dry-run: %v", err) + } + if len(mock.executedCommands) != 0 { + t.Errorf("dry-run must execute NO real command, executed: %v", mock.executedCommands) + } +} + +// TestBuiltinFlushIPv6Route_ReachesExecutor verifies a non-dry-run flush-ipv6-route +// remediation reaches the injected executor and issues the ipv6 flush command. +func TestBuiltinFlushIPv6Route_ReachesExecutor(t *testing.T) { + r, err := NewNetworkRemediator(NetworkConfig{Operation: NetworkFlushIPv6Route}) + if err != nil { + t.Fatalf("NewNetworkRemediator(flush-ipv6-route): %v", err) + } + mock := &mockNetworkExecutor{} + r.SetNetworkExecutor(mock) + + problem := types.Problem{Type: StrategyFlushIPv6Route} + if err := r.Remediate(context.Background(), problem); err != nil { + t.Fatalf("Remediate(flush-ipv6-route): %v", err) + } + + found := false + for _, c := range mock.executedCommands { + if c == "ip -6 route flush cache" { + found = true + } + } + if !found { + t.Errorf("expected 'ip -6 route flush cache', executed: %v", mock.executedCommands) + } +} + +// TestBuiltinRestartInterface_MetadataInterface verifies the restart-interface +// remediator built by the builtin factory (AllowMetadataInterface, empty fixed +// interface) resolves its target interface from Problem.Metadata["interface"] +// and bounces it via the injected executor. +func TestBuiltinRestartInterface_MetadataInterface(t *testing.T) { + // Same construction the builtin factory uses for restart-interface. + r, err := NewNetworkRemediator(NetworkConfig{ + Operation: NetworkRestartInterface, + AllowMetadataInterface: true, + }) + if err != nil { + t.Fatalf("NewNetworkRemediator(restart-interface): %v", err) + } + mock := &mockNetworkExecutor{interfaceExists: true} + r.SetNetworkExecutor(mock) + + problem := types.Problem{ + Type: StrategyRestartInterface, + Metadata: map[string]string{metadataKeyInterface: "eth0"}, + } + if err := r.Remediate(context.Background(), problem); err != nil { + t.Fatalf("Remediate(restart-interface): %v", err) + } + + wantDown, wantUp := false, false + for _, c := range mock.executedCommands { + if c == "ip link set eth0 down" { + wantDown = true + } + if c == "ip link set eth0 up" { + wantUp = true + } + } + if !wantDown || !wantUp { + t.Errorf("expected eth0 to be bounced (down+up), executed: %v", mock.executedCommands) + } +} + +// TestBuiltinRestartInterface_MissingInterfaceFails verifies a restart-interface +// remediation with neither config nor metadata interface fails at Remediate time +// (rather than acting on an empty interface), and executes no command. +func TestBuiltinRestartInterface_MissingInterfaceFails(t *testing.T) { + r, err := NewNetworkRemediator(NetworkConfig{ + Operation: NetworkRestartInterface, + AllowMetadataInterface: true, + }) + if err != nil { + t.Fatalf("NewNetworkRemediator(restart-interface): %v", err) + } + mock := &mockNetworkExecutor{interfaceExists: true} + r.SetNetworkExecutor(mock) + + problem := types.Problem{Type: StrategyRestartInterface} // no interface metadata + if err := r.Remediate(context.Background(), problem); err == nil { + t.Fatal("expected error when no interface is specified, got nil") + } + if len(mock.executedCommands) != 0 { + t.Errorf("no command must run when interface is missing, executed: %v", mock.executedCommands) + } +} diff --git a/pkg/remediators/network.go b/pkg/remediators/network.go index 77c3964..2a90376 100644 --- a/pkg/remediators/network.go +++ b/pkg/remediators/network.go @@ -39,6 +39,15 @@ type NetworkConfig struct { // BackupRouting when true, backs up routing table before reset (for ResetRouting) BackupRouting bool + // AllowMetadataInterface allows constructing a restart-interface remediator + // without a fixed InterfaceName, deferring resolution to Remediate time from + // Problem.Metadata["interface"]. This is set for the metadata-driven + // singleton registered by RegisterBuiltinRemediators so a single remediator + // can bounce whatever interface the triggering monitor named. When false (the + // default) a restart-interface config without an InterfaceName fails + // construction as before. + AllowMetadataInterface bool + // VerifyAfter when true, verifies the operation succeeded VerifyAfter bool @@ -159,8 +168,14 @@ func validateNetworkConfig(config NetworkConfig) error { return fmt.Errorf("invalid operation: %s (must be flush-dns, restart-interface, reset-routing, or flush-ipv6-route)", config.Operation) } - // RestartInterface requires an interface name - if config.Operation == NetworkRestartInterface && config.InterfaceName == "" { + // RestartInterface requires an interface name. The name may instead be + // supplied per-call via Problem.Metadata["interface"] (the metadata-driven + // singleton built by RegisterBuiltinRemediators), so only enforce a + // construction-time interface when the remediator allows per-call metadata + // resolution to be disabled. AllowMetadataInterface signals that the missing + // InterfaceName will be resolved at Remediate time and must not fail + // construction. + if config.Operation == NetworkRestartInterface && config.InterfaceName == "" && !config.AllowMetadataInterface { return fmt.Errorf("interface name is required for restart-interface operation") } @@ -172,16 +187,44 @@ func validateNetworkConfig(config NetworkConfig) error { return nil } +// resolveInterfaceName returns the network interface to act on for this +// remediation. It prefers the per-call "interface" metadata key (threaded from +// the strategy's MonitorRemediationConfig.Interface) and falls back to the +// construction-time config.InterfaceName when the metadata is absent. This +// mirrors SystemdRemediator.resolveServiceName so a single registered +// restart-interface remediator can bounce whatever interface the triggering +// monitor declared. +func (r *NetworkRemediator) resolveInterfaceName(problem types.Problem) string { + if problem.Metadata != nil { + if iface := problem.Metadata[metadataKeyInterface]; iface != "" { + return iface + } + } + return r.config.InterfaceName +} + // remediate performs the actual network remediation. +// +// For restart-interface the target interface is resolved per-call: when the +// Problem carries an "interface" metadata key (set by the detector from the +// strategy's MonitorRemediationConfig.Interface), that value is used; otherwise +// the remediator falls back to its construction-time config.InterfaceName. func (r *NetworkRemediator) remediate(ctx context.Context, problem types.Problem) error { + interfaceName := r.resolveInterfaceName(problem) + // Dry-run mode if r.config.DryRun { r.logInfof("DRY-RUN: Would execute network operation: %s", r.config.Operation) return nil } + // restart-interface needs a concrete interface name (from config or metadata). + if r.config.Operation == NetworkRestartInterface && interfaceName == "" { + return fmt.Errorf("no network interface specified for restart-interface (neither problem metadata %q nor config.InterfaceName set)", metadataKeyInterface) + } + // Execute the network operation - if err := r.executeOperation(ctx); err != nil { + if err := r.executeOperation(ctx, interfaceName); err != nil { return fmt.Errorf("failed to execute %s: %w", r.config.Operation, err) } @@ -189,7 +232,7 @@ func (r *NetworkRemediator) remediate(ctx context.Context, problem types.Problem // Verify operation success if configured if r.config.VerifyAfter { - if err := r.verifyOperation(ctx); err != nil { + if err := r.verifyOperation(ctx, interfaceName); err != nil { return fmt.Errorf("operation verification failed: %w", err) } } @@ -197,13 +240,14 @@ func (r *NetworkRemediator) remediate(ctx context.Context, problem types.Problem return nil } -// executeOperation executes the configured network operation. -func (r *NetworkRemediator) executeOperation(ctx context.Context) error { +// executeOperation executes the configured network operation against the +// resolved interfaceName (used only by restart-interface). +func (r *NetworkRemediator) executeOperation(ctx context.Context, interfaceName string) error { switch r.config.Operation { case NetworkFlushDNS: return r.flushDNS(ctx) case NetworkRestartInterface: - return r.restartInterface(ctx) + return r.restartInterface(ctx, interfaceName) case NetworkResetRouting: return r.resetRouting(ctx) case NetworkFlushIPv6Route: @@ -240,21 +284,22 @@ func (r *NetworkRemediator) flushDNS(ctx context.Context) error { } // restartInterface restarts a network interface by bringing it down and then up. -func (r *NetworkRemediator) restartInterface(ctx context.Context) error { - r.logInfof("Restarting network interface: %s", r.config.InterfaceName) +// interfaceName is resolved per-call (metadata or config) by remediate. +func (r *NetworkRemediator) restartInterface(ctx context.Context, interfaceName string) error { + r.logInfof("Restarting network interface: %s", interfaceName) // Safety check: verify interface exists - exists, err := r.networkExecutor.InterfaceExists(ctx, r.config.InterfaceName) + exists, err := r.networkExecutor.InterfaceExists(ctx, interfaceName) if err != nil { return fmt.Errorf("failed to verify interface exists: %w", err) } if !exists { - return fmt.Errorf("interface %s does not exist", r.config.InterfaceName) + return fmt.Errorf("interface %s does not exist", interfaceName) } // Bring interface down - r.logInfof("Bringing interface %s down", r.config.InterfaceName) - output, err := r.networkExecutor.ExecuteCommand(ctx, "ip", "link", "set", r.config.InterfaceName, "down") + r.logInfof("Bringing interface %s down", interfaceName) + output, err := r.networkExecutor.ExecuteCommand(ctx, "ip", "link", "set", interfaceName, "down") if err != nil { return fmt.Errorf("failed to bring interface down: %w (output: %s)", err, output) } @@ -263,8 +308,8 @@ func (r *NetworkRemediator) restartInterface(ctx context.Context) error { time.Sleep(500 * time.Millisecond) // Bring interface up - r.logInfof("Bringing interface %s up", r.config.InterfaceName) - output, err = r.networkExecutor.ExecuteCommand(ctx, "ip", "link", "set", r.config.InterfaceName, "up") + r.logInfof("Bringing interface %s up", interfaceName) + output, err = r.networkExecutor.ExecuteCommand(ctx, "ip", "link", "set", interfaceName, "up") if err != nil { return fmt.Errorf("failed to bring interface up: %w (output: %s)", err, output) } @@ -340,7 +385,7 @@ func (r *NetworkRemediator) flushIPv6RouteCache(ctx context.Context) error { } // verifyOperation verifies that the network operation succeeded. -func (r *NetworkRemediator) verifyOperation(ctx context.Context) error { +func (r *NetworkRemediator) verifyOperation(ctx context.Context, interfaceName string) error { // Create a context with timeout for verification verifyCtx, cancel := context.WithTimeout(ctx, r.config.VerifyTimeout) defer cancel() @@ -353,7 +398,7 @@ func (r *NetworkRemediator) verifyOperation(ctx context.Context) error { case NetworkRestartInterface: // Verify interface is up after restart - return r.verifyInterfaceUp(verifyCtx) + return r.verifyInterfaceUp(verifyCtx, interfaceName) case NetworkResetRouting: // Verify routing table exists after reset @@ -370,28 +415,28 @@ func (r *NetworkRemediator) verifyOperation(ctx context.Context) error { } // verifyInterfaceUp verifies that an interface is up after restart. -func (r *NetworkRemediator) verifyInterfaceUp(ctx context.Context) error { +func (r *NetworkRemediator) verifyInterfaceUp(ctx context.Context, interfaceName string) error { ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() for { select { case <-ctx.Done(): - return fmt.Errorf("timeout waiting for interface %s to come up after %v", r.config.InterfaceName, r.config.VerifyTimeout) + return fmt.Errorf("timeout waiting for interface %s to come up after %v", interfaceName, r.config.VerifyTimeout) case <-ticker.C: - isUp, err := r.networkExecutor.IsInterfaceUp(ctx, r.config.InterfaceName) + isUp, err := r.networkExecutor.IsInterfaceUp(ctx, interfaceName) if err != nil { r.logWarnf("Error checking interface status during verification: %v", err) continue } if isUp { - r.logInfof("Interface %s is up", r.config.InterfaceName) + r.logInfof("Interface %s is up", interfaceName) return nil } - r.logInfof("Waiting for interface %s to come up...", r.config.InterfaceName) + r.logInfof("Waiting for interface %s to come up...", interfaceName) } } } diff --git a/pkg/remediators/network_test.go b/pkg/remediators/network_test.go index 295a79f..6ef0d37 100644 --- a/pkg/remediators/network_test.go +++ b/pkg/remediators/network_test.go @@ -1099,7 +1099,7 @@ func TestNetworkRemediator_VerifyOperation_FlushIPv6Route(t *testing.T) { r.SetNetworkExecutor(mockExec) ctx := context.Background() - if err := r.verifyOperation(ctx); err != nil { + if err := r.verifyOperation(ctx, ""); err != nil { t.Errorf("verifyOperation() unexpected error: %v", err) } diff --git a/pkg/remediators/registry.go b/pkg/remediators/registry.go index 813b6e7..007dab4 100644 --- a/pkg/remediators/registry.go +++ b/pkg/remediators/registry.go @@ -58,6 +58,12 @@ const ( // metadataKeyArgs carries the JSON-encoded script arguments for the // custom-script strategy (from MonitorRemediationConfig.Args). metadataKeyArgs = "args" + + // metadataKeyInterface carries the network interface name for the + // restart-interface strategy (from MonitorRemediationConfig.Interface). + // It is consumed by NetworkRemediator at Remediate time, mirroring how + // metadataKeyService is consumed by SystemdRemediator. + metadataKeyInterface = "interface" ) // CircuitBreakerState represents the state of the circuit breaker. diff --git a/pkg/types/config.go b/pkg/types/config.go index bc88bc6..e6d618b 100644 --- a/pkg/types/config.go +++ b/pkg/types/config.go @@ -73,6 +73,13 @@ var ( "custom-script": true, "node-reboot": true, "pod-delete": true, + // Network remediation operations (TaskForge #19263 Phase 3). These map + // directly to NetworkRemediator operations and are registered as safe + // built-in strategies. "flush-ipv6-route" closes Task #17222. + "flush-dns": true, + "restart-interface": true, + "reset-routing": true, + "flush-ipv6-route": true, } // Minimum interval thresholds (conservative settings to prevent system overload) @@ -215,6 +222,12 @@ type MonitorRemediationConfig struct { // ScriptPath is the path to remediation script (for custom-script strategy) ScriptPath string `json:"scriptPath,omitempty" yaml:"scriptPath,omitempty"` + // Interface is the network interface name (for the restart-interface + // network remediation strategy). Examples: "eth0", "ens3". It is threaded + // to the NetworkRemediator via Problem.Metadata["interface"] at dispatch + // time, mirroring how Service is threaded for systemd-restart. + Interface string `json:"interface,omitempty" yaml:"interface,omitempty"` + // Args are arguments to pass to the script Args []string `json:"args,omitempty" yaml:"args,omitempty"` @@ -1255,7 +1268,7 @@ func (r *MonitorRemediationConfig) Validate() error { return fmt.Errorf("strategy is required when remediation is enabled") } if !validRemediationStrategies[r.Strategy] { - return fmt.Errorf("invalid strategy %q, must be one of: systemd-restart, custom-script, node-reboot, pod-delete", r.Strategy) + return fmt.Errorf("invalid strategy %q, must be one of: systemd-restart, custom-script, node-reboot, pod-delete, flush-dns, restart-interface, reset-routing, flush-ipv6-route", r.Strategy) } // Strategy-specific validation @@ -1277,6 +1290,13 @@ func (r *MonitorRemediationConfig) Validate() error { } // Note: File existence check skipped to support containerized deployments // where scripts may be mounted at runtime + case "restart-interface": + // The restart-interface network strategy must name the interface to + // bounce, mirroring how systemd-restart requires Service. The value is + // threaded to the NetworkRemediator via Problem.Metadata["interface"]. + if r.Interface == "" { + return fmt.Errorf("interface is required for restart-interface strategy") + } } // Validate cooldown is positive diff --git a/pkg/types/config_test.go b/pkg/types/config_test.go index 5d1e47c..986b008 100644 --- a/pkg/types/config_test.go +++ b/pkg/types/config_test.go @@ -601,6 +601,58 @@ func TestMonitorRemediationConfigValidation(t *testing.T) { wantErr: true, errMsg: "maxAttempts must be positive", }, + { + name: "valid flush-dns network strategy", + input: MonitorRemediationConfig{ + Enabled: true, + Strategy: "flush-dns", + Cooldown: 5 * time.Minute, + MaxAttempts: 3, + }, + wantErr: false, + }, + { + name: "valid reset-routing network strategy", + input: MonitorRemediationConfig{ + Enabled: true, + Strategy: "reset-routing", + Cooldown: 5 * time.Minute, + MaxAttempts: 3, + }, + wantErr: false, + }, + { + name: "valid flush-ipv6-route network strategy", + input: MonitorRemediationConfig{ + Enabled: true, + Strategy: "flush-ipv6-route", + Cooldown: 5 * time.Minute, + MaxAttempts: 3, + }, + wantErr: false, + }, + { + name: "valid restart-interface network strategy with interface", + input: MonitorRemediationConfig{ + Enabled: true, + Strategy: "restart-interface", + Interface: "eth0", + Cooldown: 5 * time.Minute, + MaxAttempts: 3, + }, + wantErr: false, + }, + { + name: "restart-interface without interface", + input: MonitorRemediationConfig{ + Enabled: true, + Strategy: "restart-interface", + Cooldown: 5 * time.Minute, + MaxAttempts: 3, + }, + wantErr: true, + errMsg: "interface is required for restart-interface strategy", + }, } for _, tt := range tests { From 24d20c7b561080afdb346eadf1ffe1a5c8958c3e Mon Sep 17 00:00:00 2001 From: Matthew Mattox Date: Thu, 25 Jun 2026 08:53:19 -0500 Subject: [PATCH 11/11] test(prometheus): eliminate freePort TOCTOU bind flake (Task #19260) Remove the freePort helper (bind :0 -> read port -> close -> rebind race). The exporter now captures its real bound address in Start (from the eagerly-bound listener) and exposes BoundAddr()/BoundPort(); tests construct an ephemeral-port exporter (Port 0, 9100 default applied only on the production path) and read the actual port back instead of pre-allocating. Migrated all freePort call sites. Verified: package -count=20 and -race -count=5 pass with zero bind flakes; full -short suite stable. --- .../prometheus/dns_health_score_test.go | 33 +-- pkg/exporters/prometheus/exporter.go | 86 ++++++- pkg/exporters/prometheus/exporter_test.go | 210 ++++-------------- pkg/exporters/prometheus/integration_test.go | 134 +++-------- pkg/exporters/prometheus/testhelpers_test.go | 15 -- 5 files changed, 175 insertions(+), 303 deletions(-) diff --git a/pkg/exporters/prometheus/dns_health_score_test.go b/pkg/exporters/prometheus/dns_health_score_test.go index 38d21d3..8519a66 100644 --- a/pkg/exporters/prometheus/dns_health_score_test.go +++ b/pkg/exporters/prometheus/dns_health_score_test.go @@ -74,24 +74,18 @@ func TestDNSHealthScore_GaugeValueRecording(t *testing.T) { // TestDNSHealthScore_LabelCorrectness verifies that the node label reflects the // configured node name and the nameserver label contains the nameserver address. func TestDNSHealthScore_LabelCorrectness(t *testing.T) { - port := freePort(t) - cfg := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } settings := &types.GlobalSettings{NodeName: "my-custom-node"} - exp, err := NewPrometheusExporter(cfg, settings) + exp, err := newEphemeralExporter(settings) if err != nil { - t.Fatalf("NewPrometheusExporter: %v", err) + t.Fatalf("newEphemeralExporter: %v", err) } ctx := context.Background() if err := exp.Start(ctx); err != nil { t.Fatalf("Start: %v", err) } t.Cleanup(func() { exp.Stop() }) + port := exp.BoundPort() if err := waitForServerReady(fmt.Sprintf("localhost:%d", port), 5*time.Second); err != nil { t.Fatalf("server not ready: %v", err) } @@ -226,26 +220,23 @@ func TestDNSHealthScore_AllInsufficientDataClearsGauges(t *testing.T) { // ── helpers ────────────────────────────────────────────────────────────────── -// newStartedExporter creates and starts a PrometheusExporter on a free port -// with namespace "test" and node name "test-node". It registers t.Cleanup to -// stop the exporter and blocks until the server is ready. +// newStartedExporter creates and starts a PrometheusExporter on an OS-assigned +// ephemeral port with namespace "test" and node name "test-node". It registers +// t.Cleanup to stop the exporter and blocks until the server is ready. The +// returned port is the real bound port (read back via BoundPort), so there is no +// close-then-rebind window for another process to grab — the exporter binds the +// port exactly once. func newStartedExporter(t *testing.T) (*PrometheusExporter, int) { t.Helper() - port := freePort(t) - cfg := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - exp, err := NewPrometheusExporter(cfg, &types.GlobalSettings{NodeName: "test-node"}) + exp, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { - t.Fatalf("NewPrometheusExporter: %v", err) + t.Fatalf("newEphemeralExporter: %v", err) } if err := exp.Start(context.Background()); err != nil { t.Fatalf("Start: %v", err) } t.Cleanup(func() { exp.Stop() }) + port := exp.BoundPort() if err := waitForServerReady(fmt.Sprintf("localhost:%d", port), 5*time.Second); err != nil { t.Fatalf("server not ready: %v", err) } diff --git a/pkg/exporters/prometheus/exporter.go b/pkg/exporters/prometheus/exporter.go index 3f2279c..0383be6 100644 --- a/pkg/exporters/prometheus/exporter.go +++ b/pkg/exporters/prometheus/exporter.go @@ -4,8 +4,10 @@ import ( "context" "fmt" "log" + "net" "net/http" "runtime" + "strconv" "sync" "time" @@ -26,6 +28,19 @@ type PrometheusExporter struct { mu sync.RWMutex started bool + // boundAddr is the actual host:port the HTTP server is listening on, captured + // from the bound net.Listener after Start. When an ephemeral port (0) is + // requested, this is the only place the real, kernel-assigned port can be + // read. It is guarded by mu and exposed via BoundAddr/BoundPort. + boundAddr string + + // ephemeral, when true, suppresses the production Port==0 -> 9100 default so + // the server binds an OS-assigned ephemeral port (bind :0). This is a test + // seam (see newEphemeralExporter) that lets tests bind hermetically and read + // the real port back via BoundPort, eliminating the freePort close-then-rebind + // TOCTOU race. + ephemeral bool + // consecutiveFailures tracks the running count of failed exports since the // last successful export. It backs the ExporterConsecutiveFailures gauge and // is guarded by mu to avoid racy read-modify-write on the gauge itself. @@ -34,6 +49,29 @@ type PrometheusExporter struct { // NewPrometheusExporter creates a new Prometheus exporter with the given configuration func NewPrometheusExporter(config *types.PrometheusExporterConfig, settings *types.GlobalSettings) (*PrometheusExporter, error) { + return newPrometheusExporter(config, settings, false) +} + +// newEphemeralExporter is a test-only constructor that builds an exporter which +// binds an OS-assigned ephemeral port (bind :0) instead of applying the +// production Port==0 -> 9100 default. After Start, the real bound port is +// available via BoundPort/BoundAddr. This makes the exporter test path hermetic: +// the port is bound exactly once, with no close-then-rebind window for another +// process to grab it. +func newEphemeralExporter(settings *types.GlobalSettings) (*PrometheusExporter, error) { + config := &types.PrometheusExporterConfig{ + Enabled: true, + Port: 0, // ephemeral: bound by the OS, read back via BoundPort + Path: "/metrics", + Namespace: "test", + } + return newPrometheusExporter(config, settings, true) +} + +// newPrometheusExporter is the shared constructor. When ephemeral is true the +// Port==0 -> 9100 production default is skipped so the server binds an OS-assigned +// ephemeral port. +func newPrometheusExporter(config *types.PrometheusExporterConfig, settings *types.GlobalSettings, ephemeral bool) (*PrometheusExporter, error) { if config == nil { return nil, fmt.Errorf("config cannot be nil") } @@ -54,8 +92,10 @@ func NewPrometheusExporter(config *types.PrometheusExporterConfig, settings *typ return nil, fmt.Errorf("node name is required") } - // Set defaults - if config.Port == 0 { + // Set defaults. In ephemeral mode, leave Port at 0 so the OS assigns a free + // port at bind time (the real port is read back via BoundPort after Start); + // otherwise apply the production default of 9100. + if config.Port == 0 && !ephemeral { config.Port = 9100 } if config.BindAddress == "" { @@ -96,6 +136,7 @@ func NewPrometheusExporter(config *types.PrometheusExporterConfig, settings *typ metrics: metrics, startTime: time.Now(), activeProblems: make(map[string]*types.Problem), + ephemeral: ephemeral, } log.Printf("[INFO] Created Prometheus exporter on port %d with namespace '%s'", @@ -126,6 +167,11 @@ func (e *PrometheusExporter) Start(ctx context.Context) error { } e.server = server + // Capture the actual bound address from the listener (server.Addr is set to + // ln.Addr().String() by startHTTPServer). This is set synchronously before + // Start returns and is the authoritative source for the real port, which + // matters when an ephemeral port (0) was requested. + e.boundAddr = server.Addr e.started = true log.Printf("[INFO] Prometheus exporter started successfully on %s%s", @@ -134,6 +180,35 @@ func (e *PrometheusExporter) Start(ctx context.Context) error { return nil } +// BoundAddr returns the actual host:port the exporter's HTTP server is listening +// on, as captured from the bound listener after Start. It returns "" before Start +// has succeeded. When an ephemeral port (0) was requested, this reports the real, +// kernel-assigned port. Safe for concurrent use. +func (e *PrometheusExporter) BoundAddr() string { + e.mu.RLock() + defer e.mu.RUnlock() + return e.boundAddr +} + +// BoundPort returns the actual port the exporter's HTTP server is listening on, +// extracted from BoundAddr. It returns 0 before Start has succeeded or if the +// bound address cannot be parsed. Safe for concurrent use. +func (e *PrometheusExporter) BoundPort() int { + addr := e.BoundAddr() + if addr == "" { + return 0 + } + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 0 + } + port, err := strconv.Atoi(portStr) + if err != nil { + return 0 + } + return port +} + // Stop gracefully stops the Prometheus exporter func (e *PrometheusExporter) Stop() error { e.mu.Lock() @@ -546,8 +621,10 @@ func (e *PrometheusExporter) Reload(config interface{}) error { oldConfig := e.config - // Set defaults for new config - if prometheusConfig.Port == 0 { + // Set defaults for new config. In ephemeral mode (test seam), leave Port at 0 + // so a restart binds a fresh OS-assigned port rather than the production 9100 + // default; the real port is read back via BoundPort. + if prometheusConfig.Port == 0 && !e.ephemeral { prometheusConfig.Port = 9100 } if prometheusConfig.BindAddress == "" { @@ -614,6 +691,7 @@ func (e *PrometheusExporter) Reload(config interface{}) error { return fmt.Errorf("failed to start new HTTP server: %w", err) } e.server = server + e.boundAddr = server.Addr log.Printf("[INFO] Prometheus server restarted on %s%s", server.Addr, prometheusConfig.Path) } diff --git a/pkg/exporters/prometheus/exporter_test.go b/pkg/exporters/prometheus/exporter_test.go index 84772dd..169a1da 100644 --- a/pkg/exporters/prometheus/exporter_test.go +++ b/pkg/exporters/prometheus/exporter_test.go @@ -129,16 +129,7 @@ func TestNewPrometheusExporter(t *testing.T) { } func TestPrometheusExporterLifecycle(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -152,12 +143,13 @@ func TestPrometheusExporterLifecycle(t *testing.T) { } // Wait for server to be ready before making requests - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to connect to metrics server: %v", err) } @@ -187,16 +179,7 @@ func TestPrometheusExporterLifecycle(t *testing.T) { } func TestExportStatus(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -260,16 +243,7 @@ func TestExportStatus(t *testing.T) { // MonitorCycleLastTimestamp) and classifies the cycle result based on the // presence of a ConditionFalse condition. func TestExportStatusRecordsMonitorCycle(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -332,16 +306,7 @@ func TestExportStatusRecordsMonitorCycle(t *testing.T) { } func TestExportProblem(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -391,16 +356,7 @@ func TestExportProblem(t *testing.T) { } func TestConcurrentExports(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -480,16 +436,7 @@ func contains(s, substr string) bool { } func TestIsReloadable(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -501,12 +448,15 @@ func TestIsReloadable(t *testing.T) { } func TestReload(t *testing.T) { - port1 := freePort(t) - port2 := freePort(t) - + // Bind ephemerally (Port 0) so each Start/Reload-driven server restart binds + // an OS-assigned port atomically — there is no freePort close-then-rebind + // window for another process to grab the port (TaskForge #19260). The + // port-comparison restart logic itself is unit-covered by + // TestNeedsServerRestart; here the restart-and-rebind path is exercised + // hermetically via path/namespace/subsystem/label changes. config := &types.PrometheusExporterConfig{ Enabled: true, - Port: port1, + Port: 0, Path: "/metrics", Namespace: "test", Subsystem: "sub1", @@ -514,7 +464,7 @@ func TestReload(t *testing.T) { } settings := &types.GlobalSettings{NodeName: "test-node"} - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newPrometheusExporter(config, settings, true) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -563,7 +513,7 @@ func TestReload(t *testing.T) { name: "same config - no restart needed", newConfig: &types.PrometheusExporterConfig{ Enabled: true, - Port: port1, + Port: 0, Path: "/metrics", Namespace: "test", Subsystem: "sub1", @@ -572,21 +522,10 @@ func TestReload(t *testing.T) { expectedError: false, }, { - name: "different port - restart needed", + name: "different path - restart needed (rebinds ephemeral port)", newConfig: &types.PrometheusExporterConfig{ Enabled: true, - Port: port2, - Path: "/metrics", - Namespace: "test", - Subsystem: "sub1", - }, - expectedError: false, - }, - { - name: "different path - restart needed", - newConfig: &types.PrometheusExporterConfig{ - Enabled: true, - Port: port2, + Port: 0, Path: "/custom-metrics", Namespace: "test", }, @@ -596,7 +535,7 @@ func TestReload(t *testing.T) { name: "different namespace - metrics recreation needed", newConfig: &types.PrometheusExporterConfig{ Enabled: true, - Port: port2, + Port: 0, Path: "/custom-metrics", Namespace: "new_namespace", }, @@ -606,7 +545,7 @@ func TestReload(t *testing.T) { name: "different subsystem - metrics recreation needed", newConfig: &types.PrometheusExporterConfig{ Enabled: true, - Port: port2, + Port: 0, Path: "/custom-metrics", Namespace: "new_namespace", Subsystem: "new_subsystem", @@ -617,7 +556,7 @@ func TestReload(t *testing.T) { name: "different labels - metrics recreation needed", newConfig: &types.PrometheusExporterConfig{ Enabled: true, - Port: port2, + Port: 0, Path: "/custom-metrics", Namespace: "new_namespace", Labels: map[string]string{"env": "prod", "region": "us-east"}, @@ -647,18 +586,9 @@ func TestReload(t *testing.T) { } func TestReloadNotStarted(t *testing.T) { - port1 := freePort(t) - port2 := freePort(t) - - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port1, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + // Not started, so Reload never binds a server; ephemeral construction keeps + // the path hermetic and consistent with the rest of the suite. + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -666,7 +596,7 @@ func TestReloadNotStarted(t *testing.T) { // Reload without starting - should work but not restart server newConfig := &types.PrometheusExporterConfig{ Enabled: true, - Port: port2, + Port: 0, Path: "/new-metrics", Namespace: "new_test", } @@ -678,16 +608,7 @@ func TestReloadNotStarted(t *testing.T) { } func TestNeedsServerRestart(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -774,16 +695,7 @@ func TestNeedsServerRestart(t *testing.T) { } func TestNeedsMetricsRecreation(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -1135,16 +1047,7 @@ func TestShutdownServer(t *testing.T) { } // Test shutdown of valid server - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -1156,13 +1059,14 @@ func TestShutdownServer(t *testing.T) { } // Wait for server to be ready using deterministic readiness check - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } // Verify server is running - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d/health", port)) if err != nil { t.Fatalf("failed to connect to server: %v", err) } @@ -1181,23 +1085,14 @@ func TestShutdownServer(t *testing.T) { } client := &http.Client{Timeout: 500 * time.Millisecond} - _, err = client.Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + _, err = client.Get(fmt.Sprintf("http://localhost:%d/health", port)) if err == nil { t.Errorf("expected error connecting to stopped server") } } func TestExportProblemBeforeStart(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -1260,14 +1155,7 @@ func TestPrometheusExporter_StartBindFailure(t *testing.T) { // TestNewPrometheusExporter_DualStackDefault verifies an empty BindAddress // defaults to "::" (dual-stack) in the constructor. func TestNewPrometheusExporter_DualStackDefault(t *testing.T) { - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: freePort(t), - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -1280,17 +1168,8 @@ func TestNewPrometheusExporter_DualStackDefault(t *testing.T) { // the default "::" (dual-stack) BindAddress and serves /metrics. The bind has an // automatic IPv4 fallback, so this passes whether or not IPv6 is available. func TestPrometheusExporter_DualStackServesRequest(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - // BindAddress intentionally left empty -> defaults to "::". - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + // BindAddress intentionally left empty -> defaults to "::"; ephemeral port. + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -1299,11 +1178,12 @@ func TestPrometheusExporter_DualStackServesRequest(t *testing.T) { } defer func() { _ = exporter.Stop() }() + port := exporter.BoundPort() addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to connect to metrics server: %v", err) } @@ -1316,17 +1196,18 @@ func TestPrometheusExporter_DualStackServesRequest(t *testing.T) { // TestPrometheusExporter_ExplicitBindAddressHonored verifies an explicit // BindAddress is used as-is and serves a request. func TestPrometheusExporter_ExplicitBindAddressHonored(t *testing.T) { - port := freePort(t) + // Explicit BindAddress with an ephemeral (0) port: the listener binds + // 127.0.0.1:0 and the real port is read back via BoundPort. config := &types.PrometheusExporterConfig{ Enabled: true, BindAddress: "127.0.0.1", - Port: port, + Port: 0, Path: "/metrics", Namespace: "test", } settings := &types.GlobalSettings{NodeName: "test-node"} - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newPrometheusExporter(config, settings, true) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -1343,11 +1224,12 @@ func TestPrometheusExporter_ExplicitBindAddressHonored(t *testing.T) { t.Errorf("bound host = %q, want 127.0.0.1", host) } + port := exporter.BoundPort() addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to connect to metrics server: %v", err) } diff --git a/pkg/exporters/prometheus/integration_test.go b/pkg/exporters/prometheus/integration_test.go index 989d0d8..a0fdee4 100644 --- a/pkg/exporters/prometheus/integration_test.go +++ b/pkg/exporters/prometheus/integration_test.go @@ -15,16 +15,7 @@ import ( ) func TestMetricsEndpoint(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -36,13 +27,14 @@ func TestMetricsEndpoint(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } // Test metrics endpoint - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to get metrics: %v", err) } @@ -82,16 +74,7 @@ func TestMetricsEndpoint(t *testing.T) { } func TestHealthEndpoint(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -103,13 +86,14 @@ func TestHealthEndpoint(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } // Test health endpoint - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d/health", config.Port)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d/health", port)) if err != nil { t.Fatalf("failed to get health: %v", err) } @@ -143,16 +127,7 @@ func TestHealthEndpoint(t *testing.T) { } func TestPrometheusFormat(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -164,7 +139,8 @@ func TestPrometheusFormat(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } @@ -203,7 +179,7 @@ func TestPrometheusFormat(t *testing.T) { exporter.ExportProblem(ctx, problem) // Get metrics - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to get metrics: %v", err) } @@ -266,16 +242,7 @@ func TestPrometheusFormat(t *testing.T) { } func TestConditionStatusGauge(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -287,7 +254,8 @@ func TestConditionStatusGauge(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } @@ -309,7 +277,7 @@ func TestConditionStatusGauge(t *testing.T) { exporter.ExportStatus(ctx, status) // Scrape and verify gauge == 1 - body := scrapeMetrics(t, config.Port, config.Path) + body := scrapeMetrics(t, port, "/metrics") if !strings.Contains(body, `test_condition_status{condition_type="NetworkPartitioned"`) { t.Error("condition_status metric not found for NetworkPartitioned") } @@ -334,7 +302,7 @@ func TestConditionStatusGauge(t *testing.T) { exporter.ExportStatus(ctx, status2) // Scrape and verify gauge == 0 - body = scrapeMetrics(t, config.Port, config.Path) + body = scrapeMetrics(t, port, "/metrics") if !containsMetricWithValue(body, "test_condition_status", "NetworkPartitioned", "0") { t.Error("expected condition_status=0 for NetworkPartitioned=False, but gauge did not update") } @@ -371,16 +339,7 @@ func containsMetricWithValue(body, metricName, conditionType, value string) bool } func TestConcurrentScrapes(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -392,7 +351,8 @@ func TestConcurrentScrapes(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } @@ -410,7 +370,7 @@ func TestConcurrentScrapes(t *testing.T) { defer wg.Done() for j := 0; j < numScrapes; j++ { - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { errCh <- fmt.Errorf("scrape %d-%d failed: %w", id, j, err) return @@ -444,16 +404,7 @@ func TestConcurrentScrapes(t *testing.T) { } func TestServerShutdown(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -464,13 +415,14 @@ func TestServerShutdown(t *testing.T) { t.Fatalf("failed to start exporter: %v", err) } - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } // Verify server is running - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to connect before shutdown: %v", err) } @@ -488,23 +440,14 @@ func TestServerShutdown(t *testing.T) { } // Verify server is no longer running - _, err = newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + _, err = newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err == nil { t.Error("expected connection to fail after shutdown") } } func TestMetricValues(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -516,7 +459,8 @@ func TestMetricValues(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } @@ -543,7 +487,7 @@ func TestMetricValues(t *testing.T) { } // Get metrics - resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", config.Port, config.Path)) + resp, err := newTestHTTPClient().Get(fmt.Sprintf("http://localhost:%d%s", port, "/metrics")) if err != nil { t.Fatalf("failed to get metrics: %v", err) } @@ -610,16 +554,7 @@ func TestMetricValues(t *testing.T) { } func TestNameserverHealthScoreStaleMetricCleanup(t *testing.T) { - port := freePort(t) - config := &types.PrometheusExporterConfig{ - Enabled: true, - Port: port, - Path: "/metrics", - Namespace: "test", - } - settings := &types.GlobalSettings{NodeName: "test-node"} - - exporter, err := NewPrometheusExporter(config, settings) + exporter, err := newEphemeralExporter(&types.GlobalSettings{NodeName: "test-node"}) if err != nil { t.Fatalf("failed to create exporter: %v", err) } @@ -630,7 +565,8 @@ func TestNameserverHealthScoreStaleMetricCleanup(t *testing.T) { } defer exporter.Stop() - addr := fmt.Sprintf("localhost:%d", config.Port) + port := exporter.BoundPort() + addr := fmt.Sprintf("localhost:%d", port) if err := waitForServerReady(addr, 5*time.Second); err != nil { t.Fatalf("server never became ready: %v", err) } @@ -647,7 +583,7 @@ func TestNameserverHealthScoreStaleMetricCleanup(t *testing.T) { t.Fatalf("failed to export status round 1: %v", err) } - body := scrapeMetrics(t, config.Port, config.Path) + body := scrapeMetrics(t, port, "/metrics") if !strings.Contains(body, `nameserver="8.8.8.8"`) { t.Error("round 1: expected 8.8.8.8 metric to be present") } @@ -666,7 +602,7 @@ func TestNameserverHealthScoreStaleMetricCleanup(t *testing.T) { t.Fatalf("failed to export status round 2: %v", err) } - body = scrapeMetrics(t, config.Port, config.Path) + body = scrapeMetrics(t, port, "/metrics") if !strings.Contains(body, `nameserver="8.8.8.8"`) { t.Error("round 2: expected 8.8.8.8 metric to still be present") } diff --git a/pkg/exporters/prometheus/testhelpers_test.go b/pkg/exporters/prometheus/testhelpers_test.go index 0da739c..b1b72e9 100644 --- a/pkg/exporters/prometheus/testhelpers_test.go +++ b/pkg/exporters/prometheus/testhelpers_test.go @@ -4,24 +4,9 @@ import ( "context" "net" "net/http" - "testing" "time" ) -// freePort allocates a free TCP port on the loopback interface and immediately -// releases it. There is a small TOCTOU window, but this is acceptable for tests -// and far preferable to hardcoded port numbers that can collide across test runs. -func freePort(t *testing.T) int { - t.Helper() - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("freePort: %v", err) - } - port := ln.Addr().(*net.TCPAddr).Port - ln.Close() - return port -} - // waitForServerReady polls until the server at addr is accepting TCP connections, // or until timeout elapses. Use after Start() to avoid race-prone time.Sleep calls. func waitForServerReady(addr string, timeout time.Duration) error {