diff --git a/protocol/logger/job_logger.go b/protocol/logger/job_logger.go index 5408cd0..57b0ed8 100644 --- a/protocol/logger/job_logger.go +++ b/protocol/logger/job_logger.go @@ -109,18 +109,39 @@ func (logger *WebsocketLivelogger) SendLog(lines *protocol.TimelineRecordFeedLin return wsjson.Write(ctx, ws, lines) } +func (logger *WebsocketLivelogger) ReConnectAndSendLog(lines *protocol.TimelineRecordFeedLinesWrapper) error { + if err := logger.Connect(); err != nil { + return err + } + if err := logger.SendLog(lines); err != nil { + return err + } + return nil +} + +type currentLogger struct { + logger LiveLogger + wsLogger *WebsocketLivelogger +} + +func (c *currentLogger) isValid() bool { + return c != nil && c.logger != nil +} + type WebsocketLiveloggerWithFallback struct { JobRequest *protocol.AgentJobRequestMessage Connection *protocol.VssConnection - currentLogger atomic.Pointer[LiveLogger] + currentLogger atomic.Pointer[currentLogger] FeedStreamURL string ForceWebsock bool } -func (logger *WebsocketLiveloggerWithFallback) initializeVssLogger() LiveLogger { - l := &VssLiveLogger{ - JobRequest: logger.JobRequest, - Connection: logger.Connection, +func (logger *WebsocketLiveloggerWithFallback) initializeVssLogger() *currentLogger { + l := ¤tLogger{ + logger: &VssLiveLogger{ + JobRequest: logger.JobRequest, + Connection: logger.Connection, + }, } _ = logger.replace(l) // Ignore error for cleanup return l @@ -130,7 +151,7 @@ func (logger *WebsocketLiveloggerWithFallback) InitializeVssLogger() { logger.initializeVssLogger() } -func (logger *WebsocketLiveloggerWithFallback) initialize() LiveLogger { +func (logger *WebsocketLiveloggerWithFallback) initialize() *currentLogger { if logger.FeedStreamURL != "" { wslogger := &WebsocketLivelogger{ JobRequest: logger.JobRequest, @@ -139,8 +160,12 @@ func (logger *WebsocketLiveloggerWithFallback) initialize() LiveLogger { } err := wslogger.Connect() if err == nil { - _ = logger.replace(wslogger) // Ignore error for cleanup - return wslogger + cl := ¤tLogger{ + logger: wslogger, + wsLogger: wslogger, + } + _ = logger.replace(cl) // Ignore error for cleanup + return cl } else if logger.Connection.Trace { fmt.Printf("Failed to connect to websocket %s, fallback to vsslogger\n", err.Error()) } @@ -170,65 +195,61 @@ func (e *errorLogger) SendLog(lines *protocol.TimelineRecordFeedLinesWrapper) er return ErrMissingLoggerConnection } -func makePointer[T any](p T) *T { - return &p -} - -func getPointer[T any](p *T) T { - if p == nil { - var zero T - return zero - } - return *p -} - -func (logger *WebsocketLiveloggerWithFallback) replace(n LiveLogger) error { - if currentLogger := logger.currentLogger.Swap(makePointer(n)); getPointer(currentLogger) != nil { - return (*currentLogger).Close() +func (logger *WebsocketLiveloggerWithFallback) replace(n *currentLogger) error { + if currentLogger := logger.currentLogger.Swap(n); currentLogger.isValid() { + return currentLogger.logger.Close() } return nil } func (logger *WebsocketLiveloggerWithFallback) Close() error { - return logger.replace(&errorLogger{}) + return logger.replace(¤tLogger{ + logger: &errorLogger{}, + }) } func (logger *WebsocketLiveloggerWithFallback) sendLogFallback( - err error, reason string, wrapper *protocol.TimelineRecordFeedLinesWrapper, + err error, wrapper *protocol.TimelineRecordFeedLinesWrapper, ) error { if !logger.ForceWebsock { if logger.Connection.Trace { - fmt.Printf("Failed to %s to websocket %s, fallback to vsslogger\n", reason, err.Error()) + fmt.Printf("Failed to send to websocket %s, fallback to vsslogger\n", err.Error()) } currentLogger := logger.initializeVssLogger() - if currentLogger == nil { - return fmt.Errorf("failed to initialize VSS logger after websocket %s failure: %w", reason, err) + if currentLogger.isValid() { + return currentLogger.logger.SendLog(wrapper) } - return currentLogger.SendLog(wrapper) + return fmt.Errorf("failed to initialize VSS logger after websocket send failure: %w", err) } return err } +func (logger *WebsocketLiveloggerWithFallback) getOrInitializeLogger() (*currentLogger, error) { + currentLogger := logger.currentLogger.Load() + if currentLogger.isValid() { + return currentLogger, nil + } + currentLogger = logger.initialize() + if currentLogger.isValid() { + return currentLogger, nil + } + return nil, fmt.Errorf("failed to initialize live logger: no logger instance available (ForceWebsock=%t)", logger.ForceWebsock) +} + func (logger *WebsocketLiveloggerWithFallback) SendLog(wrapper *protocol.TimelineRecordFeedLinesWrapper) error { - currentLogger := getPointer(logger.currentLogger.Load()) - if currentLogger == nil { - currentLogger = logger.initialize() - if currentLogger == nil { - return fmt.Errorf("failed to initialize live logger: no logger instance available (ForceWebsock=%t)", logger.ForceWebsock) - } + currentLogger, err := logger.getOrInitializeLogger() + if err != nil { + return err } - err := currentLogger.SendLog(wrapper) + err = currentLogger.logger.SendLog(wrapper) if err != nil { if logger.Connection.Trace { fmt.Printf("Failed to send webconsole log %s\n", err.Error()) } - if wslogger, ok := currentLogger.(*WebsocketLivelogger); ok { - if err = wslogger.Connect(); err != nil { - return logger.sendLogFallback(err, "reconnect", wrapper) - } - err = currentLogger.SendLog(wrapper) - if err != nil { - return logger.sendLogFallback(err, "send", wrapper) + wsLogger := currentLogger.wsLogger + if wsLogger != nil { + if err = currentLogger.wsLogger.ReConnectAndSendLog(wrapper); err != nil { + return logger.sendLogFallback(err, wrapper) } return nil }