diff --git a/cmd/workspace/apps/logs.go b/cmd/workspace/apps/logs.go new file mode 100644 index 00000000000..0072494ce1a --- /dev/null +++ b/cmd/workspace/apps/logs.go @@ -0,0 +1,281 @@ +package apps + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "slices" + "strings" + "time" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/apps/logstream" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdgroup" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/config" + "github.com/databricks/databricks-sdk-go/service/apps" + "github.com/gorilla/websocket" + "github.com/spf13/cobra" +) + +const ( + defaultTailLines = 200 + defaultPrefetchWindow = 2 * time.Second + defaultHandshakeTimeout = 30 * time.Second +) + +var allowedSources = []string{"APP", "SYSTEM"} + +func newLogsCommand() *cobra.Command { + var ( + tailLines int + follow bool + outputPath string + streamTimeout time.Duration + searchTerm string + sourceFilters []string + ) + + cmd := &cobra.Command{ + Use: "logs NAME", + Short: "Show Databricks app logs", + Long: `Stream stdout/stderr logs for a Databricks app via its log stream. + +By default the command fetches the most recent logs (up to --tail-lines, default 200) and exits. +Use --follow to continue streaming logs until cancelled, optionally bounding the duration with --timeout. +Server-side filtering is available through --search (same semantics as the Databricks UI) and client-side filtering +via --source APP|SYSTEM. 
Use --output-file to mirror the stream to a local file (created with 0600 permissions).`, + Example: ` # Fetch the last 50 log lines + databricks apps logs my-app --tail-lines 50 + + # Follow logs until interrupted, searching for "ERROR" messages from app sources only + databricks apps logs my-app --follow --search ERROR --source APP + + # Mirror streamed logs to a local file while following for up to 5 minutes + databricks apps logs my-app --follow --timeout 5m --output-file /tmp/my-app.log`, + Args: root.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + if tailLines < 0 { + return errors.New("--tail-lines cannot be negative") + } + + if follow && streamTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, streamTimeout) + defer cancel() + } + + name := args[0] + w := cmdctx.WorkspaceClient(ctx) + app, err := w.Apps.Get(ctx, apps.GetAppRequest{Name: name}) + if err != nil { + return err + } + if app.Url == "" { + return fmt.Errorf("app %s does not have a public URL; deploy and start it before streaming logs", name) + } + + wsURL, err := buildLogsURL(app.Url) + if err != nil { + return err + } + + cfg := cmdctx.ConfigUsed(ctx) + if cfg == nil { + return errors.New("missing workspace configuration") + } + + tokenSource := cfg.GetTokenSource() + if tokenSource == nil { + return errors.New("configuration does not support OAuth tokens") + } + + initialToken, err := tokenSource.Token(ctx) + if err != nil { + return err + } + + tokenProvider := func(ctx context.Context) (string, error) { + tok, err := tokenSource.Token(ctx) + if err != nil { + return "", err + } + return tok.AccessToken, nil + } + + appStatusChecker := func(ctx context.Context) error { + app, err := w.Apps.Get(ctx, apps.GetAppRequest{Name: name}) + if err != nil { + return err + } + if app.ComputeStatus == nil { + return errors.New("app status unavailable") + } + // Check if app is in a 
terminal/stopped state + switch app.ComputeStatus.State { + case apps.ComputeStateStopped, apps.ComputeStateDeleting, apps.ComputeStateError: + return fmt.Errorf("app is %s", app.ComputeStatus.State) + default: + // App is running or transitioning - continue streaming + return nil + } + } + + writer := cmd.OutOrStdout() + var file *os.File + if outputPath != "" { + file, err = os.OpenFile(outputPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o600) + if err != nil { + return err + } + defer file.Close() + writer = io.MultiWriter(writer, file) + } + colorizeLogs := outputPath == "" && cmdio.IsTTY(cmd.OutOrStdout()) + + sourceMap, err := buildSourceFilter(sourceFilters) + if err != nil { + return err + } + + log.Infof(ctx, "Streaming logs for %s (%s)", name, wsURL) + return logstream.Run(ctx, logstream.Config{ + Dialer: newLogStreamDialer(cfg), + URL: wsURL, + Origin: normalizeOrigin(app.Url), + Token: initialToken.AccessToken, + TokenProvider: tokenProvider, + AppStatusChecker: appStatusChecker, + Search: searchTerm, + Sources: sourceMap, + Tail: tailLines, + Follow: follow, + Prefetch: defaultPrefetchWindow, + Writer: writer, + UserAgent: "databricks-cli apps logs", + Colorize: colorizeLogs, + }) + }, + } + + streamGroup := cmdgroup.NewFlagGroup("Streaming") + streamGroup.FlagSet().IntVar(&tailLines, "tail-lines", defaultTailLines, "Number of recent log lines to show before streaming. Set to 0 to show everything.") + streamGroup.FlagSet().BoolVarP(&follow, "follow", "f", false, "Continue streaming logs until interrupted.") + streamGroup.FlagSet().DurationVar(&streamTimeout, "timeout", 0, "Maximum time to stream when --follow is set. 
// normalizeOrigin derives the value of the Origin header sent with the
// WebSocket handshake from the app URL.
//
// The result always has the form "scheme://host[:port]": ws is mapped to http
// and wss to https, while the path, query, fragment and any userinfo are
// dropped. Credentials embedded in the app URL therefore never end up in the
// header, regardless of which scheme the URL used.
//
// It returns "" when the URL cannot be parsed, uses an unsupported scheme, or
// has no host; callers treat the empty string as "do not send an Origin".
func normalizeOrigin(appURL string) string {
	parsed, err := url.Parse(appURL)
	if err != nil {
		return ""
	}

	var scheme string
	switch strings.ToLower(parsed.Scheme) {
	case "http", "ws":
		scheme = "http"
	case "https", "wss":
		scheme = "https"
	default:
		return ""
	}

	// Without a host there is no origin to speak of (e.g. "https:///path" or
	// an opaque URL such as "https:example.com"); "https://" alone would be an
	// invalid header value.
	if parsed.Host == "" {
		return ""
	}

	// Build the origin by hand instead of via parsed.String() so that userinfo
	// and other URL components can never leak into the header.
	return scheme + "://" + parsed.Host
}
struct{}{} + } + if len(filter) == 0 { + return nil, nil + } + return filter, nil +} + +func newLogStreamDialer(cfg *config.Config) *websocket.Dialer { + dialer := &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: defaultHandshakeTimeout, + } + + if cfg == nil { + return dialer + } + + if transport, ok := cfg.HTTPTransport.(*http.Transport); ok && transport != nil { + clone := transport.Clone() + dialer.Proxy = clone.Proxy + dialer.NetDialContext = clone.DialContext + if clone.TLSClientConfig != nil { + dialer.TLSClientConfig = clone.TLSClientConfig.Clone() + } + return dialer + } + + if cfg.InsecureSkipVerify { + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + return dialer +} diff --git a/cmd/workspace/apps/logs_test.go b/cmd/workspace/apps/logs_test.go new file mode 100644 index 00000000000..ca7ea5e7ef9 --- /dev/null +++ b/cmd/workspace/apps/logs_test.go @@ -0,0 +1,85 @@ +package apps + +import ( + "crypto/tls" + "net/http" + "net/url" + "testing" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLogStreamDialerConfiguresProxyAndTLS(t *testing.T) { + t.Run("clones HTTP transport when provided", func(t *testing.T) { + proxyURL, err := url.Parse("http://localhost:8080") + require.NoError(t, err) + + transport := &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + } + + cfg := &config.Config{ + HTTPTransport: transport, + } + + dialer := newLogStreamDialer(cfg) + require.NotNil(t, dialer) + + req := &http.Request{URL: &url.URL{Scheme: "https", Host: "example.com"}} + actualProxy, err := dialer.Proxy(req) + require.NoError(t, err) + assert.Equal(t, proxyURL.String(), actualProxy.String()) + + require.NotNil(t, dialer.TLSClientConfig) + assert.NotSame(t, transport.TLSClientConfig, dialer.TLSClientConfig, "TLS config should be cloned") + assert.Equal(t, 
transport.TLSClientConfig.MinVersion, dialer.TLSClientConfig.MinVersion) + }) + + t.Run("honors insecure skip verify when no transport is supplied", func(t *testing.T) { + cfg := &config.Config{ + InsecureSkipVerify: true, + } + dialer := newLogStreamDialer(cfg) + require.NotNil(t, dialer) + require.NotNil(t, dialer.TLSClientConfig, "expected TLS config when insecure skip verify is set") + assert.True(t, dialer.TLSClientConfig.InsecureSkipVerify) + }) +} + +func TestBuildLogsURLConvertsSchemes(t *testing.T) { + url, err := buildLogsURL("https://example.com/foo") + require.NoError(t, err) + assert.Equal(t, "wss://example.com/foo/logz/stream", url) + + url, err = buildLogsURL("http://example.com/foo") + require.NoError(t, err) + assert.Equal(t, "ws://example.com/foo/logz/stream", url) +} + +func TestBuildLogsURLRejectsUnknownScheme(t *testing.T) { + _, err := buildLogsURL("ftp://example.com/foo") + require.Error(t, err) +} + +func TestNormalizeOrigin(t *testing.T) { + assert.Equal(t, "https://example.com", normalizeOrigin("https://example.com/foo")) + assert.Equal(t, "http://example.com", normalizeOrigin("ws://example.com/foo")) + assert.Equal(t, "https://example.com", normalizeOrigin("wss://example.com/foo")) + assert.Equal(t, "", normalizeOrigin("://invalid")) +} + +func TestBuildSourceFilter(t *testing.T) { + filters, err := buildSourceFilter([]string{"app", "system", ""}) + require.NoError(t, err) + assert.Equal(t, map[string]struct{}{"APP": {}, "SYSTEM": {}}, filters) + + filters, err = buildSourceFilter(nil) + require.NoError(t, err) + assert.Nil(t, filters) + + _, err = buildSourceFilter([]string{"foo"}) + require.Error(t, err) +} diff --git a/cmd/workspace/apps/overrides.go b/cmd/workspace/apps/overrides.go index e1406871771..3bc3871aa97 100644 --- a/cmd/workspace/apps/overrides.go +++ b/cmd/workspace/apps/overrides.go @@ -23,6 +23,9 @@ func listDeploymentsOverride(listDeploymentsCmd *cobra.Command, listDeploymentsR } func init() { + cmdOverrides = 
append(cmdOverrides, func(cmd *cobra.Command) { + cmd.AddCommand(newLogsCommand()) + }) listOverrides = append(listOverrides, listOverride) listDeploymentsOverrides = append(listDeploymentsOverrides, listDeploymentsOverride) } diff --git a/libs/apps/logstream/streamer.go b/libs/apps/logstream/streamer.go new file mode 100644 index 00000000000..e9bd4d15aa1 --- /dev/null +++ b/libs/apps/logstream/streamer.go @@ -0,0 +1,576 @@ +package logstream + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "slices" + "strings" + "sync" + "time" + + "github.com/fatih/color" + "github.com/gorilla/websocket" +) + +const ( + handshakeErrorBodyLimit = 4 * 1024 + defaultUserAgent = "databricks-cli logstream" + initialReconnectBackoff = 200 * time.Millisecond + maxReconnectBackoff = 5 * time.Second + closeCodeUnauthorized = 4401 + closeCodeForbidden = 4403 +) + +// Dialer defines the subset of websocket.Dialer used by the streamer. +type Dialer interface { + DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) +} + +// TokenProvider refreshes tokens when the streamer needs a new bearer token. +type TokenProvider func(context.Context) (string, error) + +// AppStatusChecker checks if the app is still running. +// Returns nil if app is running, or an error if the app is stopped/unavailable. +type AppStatusChecker func(context.Context) error + +// Config holds the options for running a log stream. +type Config struct { + Dialer Dialer + URL string + Origin string + Token string + TokenProvider TokenProvider + AppStatusChecker AppStatusChecker + Search string + Sources map[string]struct{} + Tail int + Follow bool + Prefetch time.Duration + Writer io.Writer + UserAgent string + Colorize bool +} + +// Run connects to the log stream described by cfg and copies frames to the writer. 
+func Run(ctx context.Context, cfg Config) error { + if cfg.Writer == nil { + return errors.New("logstream: writer is required") + } + + streamer := &logStreamer{ + dialer: cfg.Dialer, + url: cfg.URL, + origin: cfg.Origin, + token: cfg.Token, + tokenProvider: cfg.TokenProvider, + appStatusChecker: cfg.AppStatusChecker, + search: cfg.Search, + sources: cfg.Sources, + tail: cfg.Tail, + follow: cfg.Follow, + prefetch: cfg.Prefetch, + writer: cfg.Writer, + userAgent: cfg.UserAgent, + colorize: cfg.Colorize, + } + if streamer.userAgent == "" { + streamer.userAgent = defaultUserAgent + } + return streamer.Run(ctx) +} + +type logStreamer struct { + dialer Dialer + url string + origin string + token string + tokenProvider TokenProvider + appStatusChecker AppStatusChecker + search string + sources map[string]struct{} + tail int + follow bool + prefetch time.Duration + writer io.Writer + tailFlushed bool + userAgent string + colorize bool +} + +// Run establishes the websocket connection and manages reconnections. +// It is not safe to call Run concurrently on the same logStreamer instance. +func (s *logStreamer) Run(ctx context.Context) error { + if s.dialer == nil { + s.dialer = &websocket.Dialer{} + } + + backoff := initialReconnectBackoff + // Backoff timer starts as a zero-value timer; stopTimer handles the first initialization safely. + timer := time.NewTimer(0) + stopTimer(timer, 0) + + for { + shouldContinue, err := func() (bool, error) { + resp, err := s.connectAndConsume(ctx) + if err != nil { + if ctx.Err() != nil { + return false, ctx.Err() + } + + if s.follow && (s.shouldRefreshForStatus(resp) || s.shouldRefreshForError(err)) { + if err := s.refreshToken(ctx); err != nil { + return false, err + } + backoff = initialReconnectBackoff + return true, nil + } + + if !s.follow { + return false, err + } + + // Before retrying, check if the app is still running (if checker is provided). 
+ if s.appStatusChecker != nil { + if statusErr := s.appStatusChecker(ctx); statusErr != nil { + return false, fmt.Errorf("app is no longer available: %w", statusErr) + } + } + + return true, nil + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + backoff = initialReconnectBackoff + if !s.follow { + return false, nil + } + // Connection closed normally while following - check if app is still running. + if s.appStatusChecker != nil { + if statusErr := s.appStatusChecker(ctx); statusErr != nil { + return false, fmt.Errorf("app is no longer available: %w", statusErr) + } + } + return true, nil + }() + if err != nil { + return err + } + + if shouldContinue { + if err := waitForBackoff(ctx, timer, backoff); err != nil { + return err + } + backoff = min(backoff*2, maxReconnectBackoff) + continue + } + + return nil + } +} + +func (s *logStreamer) connectAndConsume(ctx context.Context) (*http.Response, error) { + if err := s.ensureToken(ctx); err != nil { + return nil, err + } + + headers := http.Header{} + headers.Set("Authorization", "Bearer "+s.token) + headers.Set("User-Agent", s.userAgent) + if s.origin != "" { + headers.Set("Origin", s.origin) + } + + conn, resp, err := s.dialer.DialContext(ctx, s.url, headers) + if err != nil { + err = decorateDialError(err, resp) + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + return resp, err + } + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + + stopWatch := watchContext(ctx, conn) + defer stopWatch() + + err = s.consume(ctx, conn) + return nil, err +} + +func (s *logStreamer) consume(ctx context.Context, conn *websocket.Conn) (retErr error) { + initial := []byte(s.search) + if len(initial) == 0 { + initial = []byte("") + } + + if err := conn.WriteMessage(websocket.TextMessage, initial); err != nil { + return err + } + + buffer := &tailBuffer{size: s.tail} + flushed := s.tail == 0 || s.tailFlushed + var flushDeadline time.Time + if s.tail > 0 && s.prefetch > 0 && 
!s.tailFlushed { + flushDeadline = time.Now().Add(s.prefetch) + } + + defer func() { + if s.tail > 0 && !flushed { + if err := buffer.Flush(s.writer); err != nil { + if retErr == nil { + retErr = err + } + return + } + flushed = true + s.tailFlushed = true + } + }() + + for { + if ctx.Err() != nil { + return ctx.Err() + } + + deadline, hasDeadline := ctx.Deadline() + if !flushDeadline.IsZero() { + if !hasDeadline || flushDeadline.Before(deadline) { + deadline = flushDeadline + } + hasDeadline = true + } + if hasDeadline { + _ = conn.SetReadDeadline(deadline) + } else { + _ = conn.SetReadDeadline(time.Time{}) + } + + _, message, err := conn.ReadMessage() + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + if !flushDeadline.IsZero() { + flushDeadline = time.Time{} + if s.tail > 0 && !flushed { + if err := buffer.Flush(s.writer); err != nil { + return err + } + flushed = true + s.tailFlushed = true + if !s.follow { + return nil + } + } + continue + } + } + if handled, closeErr := handleCloseError(err); handled { + return closeErr + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return ctx.Err() + } + return err + } + + if len(message) == 1 && message[0] == 0 { + continue + } + + entry, err := parseLogEntry(message) + if err != nil { + line := formatPlainMessage(message) + if line == "" { + continue + } + stop, err := s.flushOrBufferLine(line, buffer, &flushed, &flushDeadline) + if err != nil { + return err + } + if stop { + return nil + } + continue + } + source := strings.ToUpper(entry.Source) + if len(s.sources) > 0 { + if _, ok := s.sources[source]; !ok { + continue + } + } + line := s.formatLogEntry(entry) + stop, err := s.flushOrBufferLine(line, buffer, &flushed, &flushDeadline) + if err != nil { + return err + } + if stop { + return nil + } + } +} + +func (s *logStreamer) flushOrBufferLine(line string, buffer *tailBuffer, flushed 
// tailBuffer retains the most recent size lines so that only the tail of the
// initial log burst is printed. A size of zero or less disables buffering:
// Add becomes a no-op and the buffer stays empty.
type tailBuffer struct {
	size  int
	lines []string
}

// Add appends line to the buffer, discarding the oldest entries once more
// than size lines are held.
func (b *tailBuffer) Add(line string) {
	if b.size <= 0 {
		return
	}
	b.lines = append(b.lines, line)
	if len(b.lines) > b.size {
		b.lines = slices.Delete(b.lines, 0, len(b.lines)-b.size)
	}
}

// Len reports how many lines are currently buffered.
func (b *tailBuffer) Len() int {
	return len(b.lines)
}

// Flush writes every buffered line to w, one per line, and empties the buffer.
//
// If a write fails, Flush returns that error and keeps only the lines that
// have not been written successfully (starting with the one that failed).
// Callers that retry the flush therefore do not emit the already-written
// lines a second time.
func (b *tailBuffer) Flush(w io.Writer) error {
	if b.size == 0 {
		return nil
	}
	for i, line := range b.lines {
		if _, err := fmt.Fprintln(w, line); err != nil {
			// Lines [0, i) reached the writer; drop them so a retry resumes
			// at the failed line instead of duplicating output.
			b.lines = slices.Delete(b.lines, 0, i)
			return err
		}
	}
	// Release the backing array; the buffer is not expected to refill.
	b.lines = slices.Clip(b.lines[:0])
	return nil
}
nsec := int64((ts - float64(sec)) * 1e9) + t := time.Unix(sec, nsec).UTC() + return t.Format(time.RFC3339) +} + +func (s *logStreamer) ensureToken(ctx context.Context) error { + if s.token != "" || s.tokenProvider == nil { + return nil + } + token, err := s.tokenProvider(ctx) + if err != nil { + return err + } + s.token = token + return nil +} + +func (s *logStreamer) refreshToken(ctx context.Context) error { + if s.tokenProvider == nil { + return errors.New("token refresh unavailable") + } + s.token = "" + return s.ensureToken(ctx) +} + +func decorateDialError(err error, resp *http.Response) error { + if resp == nil { + return err + } + + var bodySnippet string + if resp.Body != nil { + data, readErr := io.ReadAll(io.LimitReader(resp.Body, handshakeErrorBodyLimit)) + _ = resp.Body.Close() + if readErr == nil { + bodySnippet = strings.TrimSpace(string(data)) + } + } + + status := strings.TrimSpace(resp.Status) + if status == "" && resp.StatusCode != 0 { + status = fmt.Sprintf("%d %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + if status == "" { + status = "unknown status" + } + + detail := "HTTP " + status + if bodySnippet != "" { + return fmt.Errorf("%w (%s: %s)", err, detail, bodySnippet) + } + return fmt.Errorf("%w (%s)", err, detail) +} + +func (s *logStreamer) shouldRefreshForStatus(resp *http.Response) bool { + if resp == nil { + return false + } + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + return true + default: + return false + } +} + +func (s *logStreamer) shouldRefreshForError(err error) bool { + var closeErr *websocket.CloseError + if errors.As(err, &closeErr) { + switch closeErr.Code { + case closeCodeUnauthorized, closeCodeForbidden: + return true + } + } + return false +} + +func handleCloseError(err error) (bool, error) { + var closeErr *websocket.CloseError + if !errors.As(err, &closeErr) { + return false, err + } + if closeErr.Code == websocket.CloseNormalClosure || closeErr.Code == 
websocket.CloseGoingAway { + return true, nil + } + return true, fmt.Errorf("log stream closed with code %d (%s): %w", closeErr.Code, strings.TrimSpace(closeErr.Text), err) +} + +func waitForBackoff(ctx context.Context, timer *time.Timer, d time.Duration) error { + if d <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + stopTimer(timer, d) + select { + case <-ctx.Done(): + stopTimer(timer, 0) + return ctx.Err() + case <-timer.C: + return nil + } +} + +func stopTimer(timer *time.Timer, d time.Duration) { + if timer == nil { + return + } + if d <= 0 { + // For a zero duration we only need to stop and drain. + if timer.Stop() { + return + } + drainTimer(timer) + return + } + // For a positive duration, either stop an already-started timer or + // just initialize it when it is still in the zero state. + if !timer.Stop() { + drainTimer(timer) + } + timer.Reset(d) +} + +func drainTimer(timer *time.Timer) { + select { + case <-timer.C: + default: + } +} + +func watchContext(ctx context.Context, conn *websocket.Conn) func() { + var once sync.Once + closeCh := make(chan struct{}) + + closeConn := func() { + once.Do(func() { + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "context canceled"), time.Now().Add(time.Second)) + _ = conn.Close() + }) + } + + go func() { + select { + case <-ctx.Done(): + closeConn() + case <-closeCh: + } + }() + + return func() { + close(closeCh) + closeConn() + } +} diff --git a/libs/apps/logstream/streamer_test.go b/libs/apps/logstream/streamer_test.go new file mode 100644 index 00000000000..58b984e233f --- /dev/null +++ b/libs/apps/logstream/streamer_test.go @@ -0,0 +1,716 @@ +package logstream + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/fatih/color" + "github.com/gorilla/websocket" + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLogStreamerTailBufferFlushes(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() // search token + + for i := 1; i <= 3; i++ { + require.NoError(t, sendEntry(conn, float64(i), fmt.Sprintf("msg%d", i))) + } + time.Sleep(50 * time.Millisecond) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "test", + tail: 2, + follow: false, + prefetch: 25 * time.Millisecond, + writer: buf, + } + + require.NoError(t, streamer.Run(context.Background())) + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.Len(t, lines, 2, "expected only last two log lines") + assert.Contains(t, lines[0], "msg2") + assert.Contains(t, lines[1], "msg3") +} + +func TestLogStreamerTailFlushErrorPropagates(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + + require.NoError(t, sendEntry(conn, 1, "msg1")) + require.NoError(t, sendEntry(conn, 2, "msg2")) + + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + writerErr := errors.New("simulated write failure") + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "test", + tail: 2, + follow: false, + prefetch: 0, + writer: &failWriter{err: writerErr}, + } + + err := streamer.Run(context.Background()) + require.Error(t, err) + assert.Equal(t, writerErr, err) +} + +func TestLogStreamerTrimsCRLFInStructuredEntries(t *testing.T) { + 
t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + require.NoError(t, sendEntry(conn, 123, "line with crlf\r\n")) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + writer: buf, + } + + require.NoError(t, streamer.Run(context.Background())) + output := buf.String() + assert.Contains(t, output, "line with crlf") + assert.NotContains(t, output, "\r") +} + +func TestLogStreamerDialErrorIncludesResponseBody(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"FORBIDDEN","message":"token invalid"}`)) + })) + defer server.Close() + + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "test", + writer: &bytes.Buffer{}, + } + + err := streamer.Run(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "HTTP 403 Forbidden") + assert.Contains(t, err.Error(), "token invalid") +} + +func TestLogStreamerRetriesOnDialFailure(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + require.NoError(t, sendEntry(conn, float64(id), fmt.Sprintf("msg%d", id))) + }) + defer server.Close() + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &flakyDialer{failures: 1, inner: &websocket.Dialer{}}, + url: toWebSocketURL(server.URL), + tail: 0, + follow: true, + prefetch: 0, + writer: buf, + } + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + require.ErrorIs(t, 
streamer.Run(ctx), context.DeadlineExceeded) + assert.Contains(t, buf.String(), "msg1") +} + +func TestLogStreamerSendsSearchTerm(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, msg, err := conn.ReadMessage() + require.NoError(t, err) + assert.Equal(t, "ERROR", string(msg)) + require.NoError(t, sendEntry(conn, 1, "boom")) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "test", + search: "ERROR", + writer: buf, + } + + require.NoError(t, streamer.Run(context.Background())) + assert.Contains(t, buf.String(), "boom") +} + +func TestLogStreamerFiltersSources(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + require.NoError(t, sendEntry(conn, 1, "app")) + require.NoError(t, conn.WriteMessage(websocket.TextMessage, mustJSON(wsEntry{Source: "SYSTEM", Timestamp: 2, Message: "sys"}))) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + sources := map[string]struct{}{"APP": {}} + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "test", + sources: sources, + writer: buf, + } + + require.NoError(t, streamer.Run(context.Background())) + output := strings.TrimSpace(buf.String()) + assert.Contains(t, output, "app") + assert.NotContains(t, output, "sys") +} + +func TestFormatLogEntryColorizesWhenEnabled(t *testing.T) { + original := color.NoColor + color.NoColor = false + defer func() { color.NoColor = original }() + + entry := &wsEntry{Source: "app", 
Timestamp: 1, Message: "hello\n"} + streamer := &logStreamer{colorize: true} + colored := streamer.formatLogEntry(entry) + assert.Contains(t, colored, "\x1b[") + assert.Contains(t, colored, fmt.Sprintf("[%s]", color.HiBlueString("APP"))) + + streamer.colorize = false + plain := streamer.formatLogEntry(entry) + assert.NotContains(t, plain, "\x1b[") + assert.Contains(t, plain, "[APP]") +} + +func mustJSON(entry wsEntry) []byte { + raw, err := json.Marshal(entry) + if err != nil { + panic(err) + } + return raw +} + +func TestTailWithoutPrefetchRespectsTailSize(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + for i := 1; i <= 4; i++ { + require.NoError(t, sendEntry(conn, float64(i), fmt.Sprintf("line%d", i))) + } + time.Sleep(20 * time.Millisecond) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + tail: 2, + prefetch: 0, + writer: buf, + } + + require.NoError(t, streamer.Run(context.Background())) + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.Len(t, lines, 2) + assert.Contains(t, lines[0], "line3") + assert.Contains(t, lines[1], "line4") +} + +func TestCloseErrorPropagatesWhenAbnormal(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(4403, "auth failed"), time.Now().Add(time.Second)) + }) + defer server.Close() + + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + } + + err := streamer.Run(context.Background()) + require.Error(t, err) + 
assert.Contains(t, err.Error(), "log stream closed with code 4403") + assert.Contains(t, err.Error(), "auth failed") +} + +type failWriter struct { + err error +} + +func (f *failWriter) Write([]byte) (int, error) { + return 0, f.err +} + +type flakyDialer struct { + failures int32 + inner Dialer +} + +func (f *flakyDialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*websocket.Conn, *http.Response, error) { + if atomic.LoadInt32(&f.failures) > 0 { + atomic.AddInt32(&f.failures, -1) + return nil, nil, errors.New("transient dial failure") + } + return f.inner.DialContext(ctx, urlStr, requestHeader) +} + +func newTestLogServer(t *testing.T, handler func(int, *websocket.Conn)) *httptest.Server { + upgrader := websocket.Upgrader{} + var connCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := int(connCount.Add(1)) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("failed to upgrade connection: %v", err) + return + } + go handler(id, conn) + })) + + t.Cleanup(func() { + server.CloseClientConnections() + server.Close() + }) + return server +} + +func toWebSocketURL(raw string) string { + return strings.Replace(raw, "http", "ws", 1) +} + +func sendEntry(conn *websocket.Conn, ts float64, message string) error { + payload, err := json.Marshal(wsEntry{ + Source: "APP", + Timestamp: ts, + Message: message, + }) + if err != nil { + return err + } + return conn.WriteMessage(websocket.TextMessage, payload) +} + +func TestLogStreamerTailFlushesWithoutFollow(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + for i := 1; i <= 4; i++ { + require.NoError(t, sendEntry(conn, float64(i), fmt.Sprintf("line%d", i))) + } + time.Sleep(250 * time.Millisecond) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, 
""), time.Now().Add(time.Second)) + }) + defer server.Close() + + writer := newNotifyBuffer() + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + tail: 2, + follow: false, + prefetch: 50 * time.Millisecond, + writer: writer, + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- streamer.Run(ctx) + }() + + require.Eventually(t, writer.hasWrite, 150*time.Millisecond, 10*time.Millisecond, "expected tail logs to flush before the server closed the socket") + require.NoError(t, <-done) + lines := strings.Split(strings.TrimSpace(writer.String()), "\n") + require.Len(t, lines, 2) + assert.Contains(t, lines[0], "line3") + assert.Contains(t, lines[1], "line4") +} + +func TestLogStreamerFollowTailWithoutPrefetchEmitsRequestedLines(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + for i := 1; i <= 4; i++ { + require.NoError(t, sendEntry(conn, float64(i), fmt.Sprintf("line%d", i))) + } + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + writer := newNotifyBuffer() + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + tail: 2, + follow: true, + prefetch: 0, + writer: writer, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- streamer.Run(ctx) + }() + + require.Eventually(t, func() bool { + return strings.Contains(writer.String(), "line4") + }, time.Second, 10*time.Millisecond, "expected to see full tail even when prefetching is disabled") + snapshot := writer.String() + cancel() + + err := <-done + require.ErrorIs(t, err, context.Canceled) + lines := 
strings.Split(strings.TrimSpace(snapshot), "\n") + require.GreaterOrEqual(t, len(lines), 2) + tail := lines[len(lines)-2:] + assert.Contains(t, tail[0], "line3") + assert.Contains(t, tail[1], "line4") +} + +func TestLogStreamerFollowTailDoesNotReplayAfterReconnect(t *testing.T) { + t.Parallel() + + stopCtx, stop := context.WithCancel(context.Background()) + defer stop() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + if id == 1 { + for i := 1; i <= 4; i++ { + require.NoError(t, sendEntry(conn, float64(i), fmt.Sprintf("line%d", i))) + } + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + return + } + + require.NoError(t, sendEntry(conn, 5, "line5")) + require.NoError(t, sendEntry(conn, 6, "line6")) + <-stopCtx.Done() + }) + defer server.Close() + + writer := newNotifyBuffer() + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + tail: 2, + follow: true, + prefetch: 0, + writer: writer, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- streamer.Run(ctx) + }() + + require.Eventually(t, func() bool { + return strings.Contains(writer.String(), "line6") + }, time.Second, 10*time.Millisecond, "expected logs from the second connection") + + cancel() + stop() + + err := <-done + require.ErrorIs(t, err, context.Canceled) + + output := writer.String() + assert.Equal(t, 1, strings.Count(output, "line3"), "line3 emitted more than once") + assert.Equal(t, 1, strings.Count(output, "line4"), "line4 emitted more than once") + assert.Contains(t, output, "line5") + assert.Contains(t, output, "line6") +} + +func TestLogStreamerRefreshesTokenAfterAuthClose(t *testing.T) { + t.Parallel() + + var connCount atomic.Int32 + upgrader := websocket.Upgrader{} + + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := int(connCount.Add(1)) + auth := r.Header.Get("Authorization") + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("failed to upgrade connection: %v", err) + return + } + + go func() { + defer conn.Close() + _, _, _ = conn.ReadMessage() + if id == 1 { + assert.Equal(t, "Bearer expired", auth) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(4403, "auth failed"), time.Now().Add(time.Second)) + return + } + + assert.Equal(t, "Bearer fresh", auth) + if err := sendEntry(conn, 1, "refreshed"); err != nil { + t.Errorf("failed to send entry: %v", err) + } + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }() + })) + t.Cleanup(func() { + server.CloseClientConnections() + server.Close() + }) + + var refreshes atomic.Int32 + tokenProvider := func(ctx context.Context) (string, error) { + if refreshes.Load() > 0 { + return "", errors.New("token refreshed multiple times") + } + refreshes.Add(1) + return "fresh", nil + } + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "expired", + tokenProvider: tokenProvider, + follow: true, + writer: buf, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- streamer.Run(ctx) + }() + + require.Eventually(t, func() bool { + return strings.Contains(buf.String(), "refreshed") + }, time.Second, 10*time.Millisecond, "expected logs after token refresh") + + cancel() + + err := <-done + require.ErrorIs(t, err, context.Canceled) + assert.Equal(t, int32(1), refreshes.Load(), "expected single token refresh") +} + +func TestLogStreamerEmitsPlainTextFrames(t *testing.T) { + t.Parallel() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer 
conn.Close() + _, _, _ = conn.ReadMessage() + require.NoError(t, conn.WriteMessage(websocket.TextMessage, []byte("plain text line"))) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + buf := &bytes.Buffer{} + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + writer: buf, + } + + require.NoError(t, streamer.Run(context.Background())) + assert.Contains(t, buf.String(), "plain text line") +} + +func TestLogStreamerTimeoutStopsQuietFollowStream(t *testing.T) { + t.Parallel() + + stopCtx, stop := context.WithCancel(context.Background()) + defer stop() + + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + _, _, _ = conn.ReadMessage() + <-stopCtx.Done() + }) + defer server.Close() + + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "token", + follow: true, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- streamer.Run(ctx) + }() + + <-ctx.Done() + + select { + case err := <-done: + require.ErrorIs(t, err, context.DeadlineExceeded, "streamer should exit when context times out") + case <-time.After(200 * time.Millisecond): + t.Fatalf("streamer did not exit within 200ms of context deadline") + } +} + +type notifyBuffer struct { + mu sync.Mutex + buf bytes.Buffer + ch chan struct{} +} + +func newNotifyBuffer() *notifyBuffer { + return ¬ifyBuffer{ch: make(chan struct{}, 1)} +} + +func (n *notifyBuffer) Write(p []byte) (int, error) { + n.mu.Lock() + written, err := n.buf.Write(p) + n.mu.Unlock() + if err == nil { + select { + case n.ch <- struct{}{}: + default: + } + } + return written, err +} + +func (n *notifyBuffer) String() string { + n.mu.Lock() + defer n.mu.Unlock() + return n.buf.String() 
+} + +func (n *notifyBuffer) hasWrite() bool { + select { + case <-n.ch: + return true + default: + return false + } +} + +func TestAppStatusCheckerStopsFollowing(t *testing.T) { + t.Parallel() + + var connectionCount atomic.Int32 + server := newTestLogServer(t, func(id int, conn *websocket.Conn) { + defer conn.Close() + connectionCount.Add(1) + _, _, _ = conn.ReadMessage() // search token + + // Send a couple messages then close + require.NoError(t, sendEntry(conn, 1.0, "message before stop")) + time.Sleep(10 * time.Millisecond) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second)) + }) + defer server.Close() + + buf := &bytes.Buffer{} + checkCount := atomic.Int32{} + appStatusChecker := func(ctx context.Context) error { + count := checkCount.Add(1) + // Simulate app being stopped after first reconnect attempt + if count > 1 { + return errors.New("app stopped") + } + return nil + } + + streamer := &logStreamer{ + dialer: &websocket.Dialer{}, + url: toWebSocketURL(server.URL), + token: "test", + follow: true, + writer: buf, + appStatusChecker: appStatusChecker, + } + + err := streamer.Run(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "app is no longer available") + assert.Contains(t, err.Error(), "app stopped") + + // Should have connected twice: initial connection + one reconnect attempt + assert.Equal(t, int32(2), connectionCount.Load()) + // Should have checked app status once (before second reconnect) + assert.GreaterOrEqual(t, checkCount.Load(), int32(1)) +}