diff --git a/src/pkg/cli/client/byoc/aws/byoc.go b/src/pkg/cli/client/byoc/aws/byoc.go index bb7a5b4c2..e6ee4bc08 100644 --- a/src/pkg/cli/client/byoc/aws/byoc.go +++ b/src/pkg/cli/client/byoc/aws/byoc.go @@ -713,7 +713,7 @@ func (b *ByocAws) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (ite if err != nil { return nil, AnnotateAwsError(err) } - logSeq = cw.Flatten(cdSeq) + logSeq = pkg.Flatten(cdSeq) // No need to filter events by etag because we only show logs from the specified task ID } else { logSeq, err = b.queryOrTailLogs(ctx, cwClient, req) @@ -785,7 +785,7 @@ func (b *ByocAws) queryOrTailLogs(ctx context.Context, cwClient cw.LogsClient, r if err != nil { return nil, err } - return cw.Flatten(logSeq), nil + return pkg.Flatten(logSeq), nil } else { logSeq, err := cw.QueryLogGroups( ctx, diff --git a/src/pkg/cli/client/byoc/aws/byoc_test.go b/src/pkg/cli/client/byoc/aws/byoc_test.go index be50d6618..8ce8f8946 100644 --- a/src/pkg/cli/client/byoc/aws/byoc_test.go +++ b/src/pkg/cli/client/byoc/aws/byoc_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/DefangLabs/defang/src/pkg" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc" "github.com/DefangLabs/defang/src/pkg/clouds/aws" "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" @@ -693,7 +694,7 @@ func TestQueryCdLogs(t *testing.T) { require.NoError(t, err) // Flatten and collect - logSeq := cw.Flatten(batchSeq) + logSeq := pkg.Flatten(batchSeq) events := collectEvents(t, logSeq) assert.Len(t, events, tt.wantCount) }) diff --git a/src/pkg/cli/client/byoc/gcp/byoc.go b/src/pkg/cli/client/byoc/gcp/byoc.go index 044c1035b..0571cacc0 100644 --- a/src/pkg/cli/client/byoc/gcp/byoc.go +++ b/src/pkg/cli/client/byoc/gcp/byoc.go @@ -586,17 +586,17 @@ func (b *ByocGcp) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) now := time.Now() subscribeStream.query.AddSince(now) // Do no query historical events - return subscribeStream.Follow(now) + return subscribeStream.Follow(ctx, now) } func (b *ByocGcp) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { logStream := b.getLogStream(ctx, b.driver, req) if req.Follow { - return logStream.Follow(req.Since.AsTime()) + return logStream.Follow(ctx, req.Since.AsTime()) } else if req.Since.IsValid() { - return logStream.Head(req.Limit), nil + return logStream.Head(ctx, req.Limit), nil } - return logStream.Tail(req.Limit), nil + return logStream.Tail(ctx, req.Limit), nil } func (b *ByocGcp) getLogStream(ctx context.Context, gcpLogsClient GcpLogsClient, req *defangv1.TailRequest) *LogStream { diff --git a/src/pkg/cli/client/byoc/gcp/byoc_test.go b/src/pkg/cli/client/byoc/gcp/byoc_test.go index 87f54eec7..67b35fa7d 100644 --- a/src/pkg/cli/client/byoc/gcp/byoc_test.go +++ b/src/pkg/cli/client/byoc/gcp/byoc_test.go @@ -4,7 +4,7 @@ import ( "context" "encoding/base64" "errors" - "io" + "iter" "os" "testing" "time" @@ -51,16 +51,24 @@ func TestSetUpCD(t *testing.T) { } type MockGcpLogsClient struct { - lister gcp.Lister - tailer gcp.Tailer + listEntries []*loggingpb.LogEntry + tailEntries []*loggingpb.LogEntry } -func (m MockGcpLogsClient) ListLogEntries(ctx context.Context, query string, order gcp.Order) (gcp.Lister, error) { - return m.lister, nil +func mockEntryIter(entries []*loggingpb.LogEntry) iter.Seq2[[]*loggingpb.LogEntry, error] { + return func(yield func([]*loggingpb.LogEntry, error) bool) { + if !yield(entries, nil) { + return + } + } } -func (m MockGcpLogsClient) NewTailer(ctx context.Context) (gcp.Tailer, error) { - return m.tailer, nil +func (m MockGcpLogsClient) ListLogEntries(ctx context.Context, query string, order gcp.Order) (iter.Seq2[[]*loggingpb.LogEntry, error], error) { + return mockEntryIter(m.listEntries), nil +} + +func (m MockGcpLogsClient) TailLogEntries(ctx context.Context, query string) (iter.Seq2[[]*loggingpb.LogEntry, error], error) { + return mockEntryIter(m.tailEntries), nil } func (m MockGcpLogsClient) GetExecutionEnv(ctx context.Context, executionName string) (map[string]string, error) { return nil, nil @@ -77,35 +85,6 @@ func (m MockGcpLogsClient) GetBuildInfo(ctx context.Context, buildId string) (*g }, nil } -type MockGcpLoggingLister struct { - logEntries []*loggingpb.LogEntry -} - -func (m *MockGcpLoggingLister) Next() (*loggingpb.LogEntry, error) { - if len(m.logEntries) > 0 { - entry := m.logEntries[0] - m.logEntries = m.logEntries[1:] - return entry, nil - } - return nil, io.EOF -} - -type MockGcpLoggingTailer struct { - MockGcpLoggingLister -} - -func (m *MockGcpLoggingTailer) Close() error { - return nil -} - -func (m *MockGcpLoggingTailer) Start(ctx context.Context, query string) error { - return nil -} - -func (m *MockGcpLoggingTailer) Next(ctx context.Context) (*loggingpb.LogEntry, error) { - return m.MockGcpLoggingLister.Next() -} - func TestGetLogStream(t *testing.T) { tests := []struct { name string @@ -175,10 +154,7 @@ func TestGetLogStream(t *testing.T) { b := NewByocProvider(ctx, "testTenantID", "") b.cdExecution = tt.cdExecution - driver := &MockGcpLogsClient{ - lister: &MockGcpLoggingLister{}, - tailer: &MockGcpLoggingTailer{}, - } + driver := &MockGcpLogsClient{} logStream := b.getLogStream(ctx, driver, tt.req) diff --git a/src/pkg/cli/client/byoc/gcp/stream.go b/src/pkg/cli/client/byoc/gcp/stream.go index e9b1e1abf..681304692 100644 --- a/src/pkg/cli/client/byoc/gcp/stream.go +++ b/src/pkg/cli/client/byoc/gcp/stream.go @@ -28,24 +28,22 @@ type LogParser[T any] func(*loggingpb.LogEntry) ([]*T, error) type LogFilter[T any] func(entry T) T type GcpLogsClient interface { - ListLogEntries(ctx context.Context, query string, order gcp.Order) (gcp.Lister, error) - NewTailer(ctx context.Context) (gcp.Tailer, error) + ListLogEntries(ctx context.Context, query string, order gcp.Order) (iter.Seq2[[]*loggingpb.LogEntry, error], error) + TailLogEntries(ctx context.Context, query string) (iter.Seq2[[]*loggingpb.LogEntry, error], error) GetExecutionEnv(ctx context.Context, executionName string) (map[string]string, error) GetProjectID() gcp.ProjectId GetBuildInfo(ctx context.Context, buildId string) (*gcp.BuildTag, error) } type ServerStream[T any] struct { - ctx context.Context gcpLogsClient GcpLogsClient parse LogParser[T] filters []LogFilter[*T] query *Query } -func NewServerStream[T any](ctx context.Context, gcpLogsClient GcpLogsClient, parse LogParser[T], filters ...LogFilter[*T]) *ServerStream[T] { +func NewServerStream[T any](gcpLogsClient GcpLogsClient, parse LogParser[T], filters ...LogFilter[*T]) *ServerStream[T] { return &ServerStream[T]{ - ctx: ctx, gcpLogsClient: gcpLogsClient, parse: parse, filters: filters, @@ -63,41 +61,36 @@ func isContextCanceledError(err error) bool { } // Follow returns an iterator that queries historical logs then tails live logs. -func (s *ServerStream[T]) Follow(start time.Time) (iter.Seq2[*T, error], error) { - tailer, err := s.gcpLogsClient.NewTailer(s.ctx) +func (s *ServerStream[T]) Follow(ctx context.Context, start time.Time) (iter.Seq2[*T, error], error) { + query := s.query.GetQuery() + shouldList := !start.IsZero() && start.Unix() > 0 && time.Since(start) > 10*time.Millisecond + // Establish tail connection eagerly so the server starts buffering entries while we list historical logs + tailIter, err := s.gcpLogsClient.TailLogEntries(ctx, query) if err != nil { return nil, err } - query := s.query.GetQuery() - shouldList := !start.IsZero() && start.Unix() > 0 && time.Since(start) > 10*time.Millisecond term.Debugf("Query and tail logs since %v with query: \n%v", start, query) return func(yield func(*T, error) bool) { - defer tailer.Close() // Only query older logs if start time is more than 10ms ago if shouldList { - lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderAscending) + listIter, err := s.gcpLogsClient.ListLogEntries(ctx, query, gcp.OrderAscending) if err != nil { yield(nil, err) return } - if !s.yieldList(yield, lister, 0) { + if !s.yieldList(yield, listIter, 0) { return } } - // Start tailing logs after all older logs are processed - if err := tailer.Start(s.ctx, query); err != nil { - yield(nil, err) - return - } - for { - entry, err := tailer.Next(s.ctx) + // Tail live logs after all older logs are processed + for entries, err := range tailIter { if err != nil { - if context.Cause(s.ctx) == io.EOF || errors.Is(err, io.EOF) { + if context.Cause(ctx) == io.EOF || errors.Is(err, io.EOF) { return } if isContextCanceledError(err) { - if cause := context.Cause(s.ctx); cause != nil { + if cause := context.Cause(ctx); cause != nil { yield(nil, cause) } return @@ -105,48 +98,50 @@ func (s *ServerStream[T]) Follow(start time.Time) (iter.Seq2[*T, error], error) yield(nil, err) return } - resps, err := s.parseAndFilter(entry) - if err != nil { - yield(nil, err) - return - } - for _, resp := range resps { - if !yield(resp, nil) { + for _, entry := range entries { + resps, err := s.parseAndFilter(entry) + if err != nil { + yield(nil, err) return } + for _, resp := range resps { + if !yield(resp, nil) { + return + } + } } } }, nil } // Head returns an iterator that queries logs in ascending order. -func (s *ServerStream[T]) Head(limit int32) iter.Seq2[*T, error] { +func (s *ServerStream[T]) Head(ctx context.Context, limit int32) iter.Seq2[*T, error] { query := s.query.GetQuery() term.Debugf("Query logs with query: \n%v", query) return func(yield func(*T, error) bool) { - lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderAscending) + listIter, err := s.gcpLogsClient.ListLogEntries(ctx, query, gcp.OrderAscending) if err != nil { yield(nil, err) return } - s.yieldList(yield, lister, limit) + s.yieldList(yield, listIter, limit) } } // Tail returns an iterator that queries logs in descending order, reversing if a limit is set. -func (s *ServerStream[T]) Tail(limit int32) iter.Seq2[*T, error] { +func (s *ServerStream[T]) Tail(ctx context.Context, limit int32) iter.Seq2[*T, error] { query := s.query.GetQuery() term.Debugf("Query logs with query: \n%v", query) return func(yield func(*T, error) bool) { - lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderDescending) + listIter, err := s.gcpLogsClient.ListLogEntries(ctx, query, gcp.OrderDescending) if err != nil { yield(nil, err) return } if limit == 0 { - s.yieldList(yield, lister, 0) + s.yieldList(yield, listIter, 0) } else { - buffer, err := s.listToBuffer(lister, limit) + buffer, err := s.listToBuffer(listIter, limit) if err != nil { yield(nil, err) return @@ -161,51 +156,52 @@ func (s *ServerStream[T]) Tail(limit int32) iter.Seq2[*T, error] { } } -// yieldList yields items from lister to yield. Returns true if iteration completed -// (EOF or limit reached), false if the consumer stopped or an error was yielded. -func (s *ServerStream[T]) yieldList(yield func(*T, error) bool, lister gcp.Lister, limit int32) bool { +// yieldList yields items from entries to yield. Returns true if iteration completed +// (end of entries or limit reached), false if the consumer stopped or an error was yielded. +func (s *ServerStream[T]) yieldList(yield func(*T, error) bool, seq iter.Seq2[[]*loggingpb.LogEntry, error], limit int32) bool { count := int32(0) - for { - if limit > 0 && count >= limit { - return true - } - entry, err := lister.Next() + for entries, err := range seq { if err != nil { - if errors.Is(err, io.EOF) { - return true - } yield(nil, err) return false } - resps, err := s.parseAndFilter(entry) - if err != nil { - yield(nil, err) - return false - } - for _, resp := range resps { - count++ - if !yield(resp, nil) { + for _, entry := range entries { + resps, err := s.parseAndFilter(entry) + if err != nil { + yield(nil, err) return false } + for _, resp := range resps { + count++ + if !yield(resp, nil) { + return false + } + if limit > 0 && count >= limit { + return true + } + } } } + return true } -func (s *ServerStream[T]) listToBuffer(lister gcp.Lister, limit int32) ([]*T, error) { +func (s *ServerStream[T]) listToBuffer(seq iter.Seq2[[]*loggingpb.LogEntry, error], limit int32) ([]*T, error) { buffer := make([]*T, 0, limit) - for range limit { - entry, err := lister.Next() + for entries, err := range seq { if err != nil { - if errors.Is(err, io.EOF) { - return buffer, nil - } return nil, err } - resps, err := s.parseAndFilter(entry) - if err != nil { - return nil, err + for _, entry := range entries { + resps, err := s.parseAndFilter(entry) + if err != nil { + return nil, err + } + buffer = append(buffer, resps...) + if len(buffer) >= int(limit) { + buffer = buffer[:limit] + return buffer, nil + } } - buffer = append(buffer, resps...) } return buffer, nil } @@ -251,7 +247,7 @@ func NewLogStream(ctx context.Context, gcpLogsClient GcpLogsClient, services []s return entry }) - ss := NewServerStream(ctx, gcpLogsClient, getLogEntryParser(ctx, gcpLogsClient), restoreServiceName) + ss := NewServerStream(gcpLogsClient, getLogEntryParser(ctx, gcpLogsClient), restoreServiceName) ss.query = NewLogQuery(gcpLogsClient.GetProjectID()) return &LogStream{ServerStream: ss} } @@ -312,7 +308,7 @@ func NewSubscribeStream(ctx context.Context, driver GcpLogsClient, waitForCD boo }), ) - ss := NewServerStream(ctx, driver, getActivityParser(ctx, driver, waitForCD, etag), filters...) + ss := NewServerStream(driver, getActivityParser(ctx, driver, waitForCD, etag), filters...) ss.query = NewSubscribeQuery() return &SubscribeStream{ServerStream: ss} } diff --git a/src/pkg/cli/client/byoc/gcp/stream_test.go b/src/pkg/cli/client/byoc/gcp/stream_test.go index 0ffc17009..0b7f549f2 100644 --- a/src/pkg/cli/client/byoc/gcp/stream_test.go +++ b/src/pkg/cli/client/byoc/gcp/stream_test.go @@ -1,9 +1,12 @@ package gcp import ( + "context" + "errors" "iter" "strconv" "testing" + "time" "cloud.google.com/go/logging/apiv2/loggingpb" "github.com/DefangLabs/defang/src/pkg/clouds/gcp" @@ -67,6 +70,45 @@ func makeMockLogEntries(n int) []*loggingpb.LogEntry { return logEntries } +func newTestStream(t *testing.T, listEntries, tailEntries []*loggingpb.LogEntry) *ServerStream[defangv1.TailResponse] { + t.Helper() + ctx := t.Context() + projectId := gcp.ProjectId("test-project-12345") + services := []string{} + restoreServiceName := getServiceNameRestorer(services, gcp.SafeLabelValue, + func(entry *defangv1.TailResponse) string { return entry.Service }, + func(entry *defangv1.TailResponse, name string) *defangv1.TailResponse { + entry.Service = name + return entry + }) + + mockClient := &MockGcpLogsClient{ + listEntries: listEntries, + tailEntries: tailEntries, + } + + stream := NewServerStream( + mockClient, + getLogEntryParser(ctx, mockClient), + restoreServiceName, + ) + stream.query = NewLogQuery(projectId) + return stream +} + +func collectMessages(t *testing.T, logs iter.Seq2[*defangv1.TailResponse, error]) []string { + t.Helper() + var msgs []string + for response, err := range logs { + assert.NoError(t, err) + if err != nil { + break + } + msgs = append(msgs, response.Entries[0].Message) + } + return msgs +} + func TestServerStream_Start(t *testing.T) { type directionType string const ( @@ -125,16 +167,6 @@ func TestServerStream_Start(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := t.Context() - projectId := gcp.ProjectId("test-project-12345") - services := []string{} - restoreServiceName := getServiceNameRestorer(services, gcp.SafeLabelValue, - func(entry *defangv1.TailResponse) string { return entry.Service }, - func(entry *defangv1.TailResponse, name string) *defangv1.TailResponse { - entry.Service = name - return entry - }) - logEntries := makeMockLogEntries(tt.streamSize) // Reverse log entries for tail direction to simulate descending order @@ -144,38 +176,179 @@ func TestServerStream_Start(t *testing.T) { } } - mockGcpLogsClient := &MockGcpLogsClient{ - lister: &MockGcpLoggingLister{ - logEntries: logEntries, - }, - tailer: &MockGcpLoggingTailer{}, - } - - stream := NewServerStream( - ctx, - mockGcpLogsClient, - getLogEntryParser(ctx, mockGcpLogsClient), - restoreServiceName, - ) - stream.query = NewLogQuery(projectId) + stream := newTestStream(t, logEntries, nil) var logs iter.Seq2[*defangv1.TailResponse, error] if tt.direction == head { - logs = stream.Head(tt.limit) + logs = stream.Head(t.Context(), tt.limit) } else { - logs = stream.Tail(tt.limit) + logs = stream.Tail(t.Context(), tt.limit) } - var collectedMessages []string - for response, err := range logs { - assert.NoError(t, err) - if err != nil { - break - } - collectedMessages = append(collectedMessages, response.Entries[0].Message) - } - assert.Equal(t, len(tt.expectedMsgs), len(collectedMessages)) + collectedMessages := collectMessages(t, logs) assert.Equal(t, tt.expectedMsgs, collectedMessages) }) } } + +func TestServerStream_HeadNoLimit(t *testing.T) { + entries := makeMockLogEntries(5) + stream := newTestStream(t, entries, nil) + msgs := collectMessages(t, stream.Head(t.Context(), 0)) + assert.Len(t, msgs, 5) + assert.Equal(t, "Log entry number 0", msgs[0]) + assert.Equal(t, "Log entry number 4", msgs[4]) +} + +func TestServerStream_TailNoLimit(t *testing.T) { + // Descending order entries (4,3,2,1,0) + entries := makeMockLogEntries(5) + for i, j := 0, len(entries)-1; i < j; i, j = i+1, j-1 { + entries[i], entries[j] = entries[j], entries[i] + } + stream := newTestStream(t, entries, nil) + msgs := collectMessages(t, stream.Tail(t.Context(), 0)) + // With limit=0, no reversal — entries come in descending order + assert.Len(t, msgs, 5) + assert.Equal(t, "Log entry number 4", msgs[0]) + assert.Equal(t, "Log entry number 0", msgs[4]) +} + +func TestServerStream_EmptyStream(t *testing.T) { + stream := newTestStream(t, nil, nil) + + t.Run("Head", func(t *testing.T) { + msgs := collectMessages(t, stream.Head(t.Context(), 10)) + assert.Empty(t, msgs) + }) + + t.Run("Tail", func(t *testing.T) { + msgs := collectMessages(t, stream.Tail(t.Context(), 10)) + assert.Empty(t, msgs) + }) +} + +func TestServerStream_Follow(t *testing.T) { + listEntries := makeMockLogEntries(3) + tailEntries := []*loggingpb.LogEntry{ + { + Payload: &loggingpb.LogEntry_TextPayload{TextPayload: "tail entry 0"}, + Timestamp: timestamppb.Now(), + }, + { + Payload: &loggingpb.LogEntry_TextPayload{TextPayload: "tail entry 1"}, + Timestamp: timestamppb.Now(), + }, + } + + stream := newTestStream(t, listEntries, tailEntries) + // Use a start time in the past to trigger historical listing + logs, err := stream.Follow(t.Context(), time.Now().Add(-time.Hour)) + assert.NoError(t, err) + + msgs := collectMessages(t, logs) + assert.Equal(t, []string{ + "Log entry number 0", + "Log entry number 1", + "Log entry number 2", + "tail entry 0", + "tail entry 1", + }, msgs) +} + +func TestServerStream_FollowNoHistory(t *testing.T) { + tailEntries := []*loggingpb.LogEntry{ + { + Payload: &loggingpb.LogEntry_TextPayload{TextPayload: "tail only"}, + Timestamp: timestamppb.Now(), + }, + } + + stream := newTestStream(t, nil, tailEntries) + // Zero start time skips historical listing + logs, err := stream.Follow(t.Context(), time.Time{}) + assert.NoError(t, err) + + msgs := collectMessages(t, logs) + assert.Equal(t, []string{"tail only"}, msgs) +} + +func TestServerStream_ListError(t *testing.T) { + testErr := errors.New("list error") + errorIter := func(yield func([]*loggingpb.LogEntry, error) bool) { + yield(nil, testErr) + } + + ctx := t.Context() + mockClient := &MockGcpLogsClient{} + stream := NewServerStream( + mockClient, + getLogEntryParser(ctx, mockClient), + ) + stream.query = NewLogQuery("test-project") + + // Override the client to return an error iterator + stream.gcpLogsClient = &errorListClient{ + MockGcpLogsClient: *mockClient, + listIter: errorIter, + } + + var gotErr error + for _, err := range stream.Head(ctx, 10) { + if err != nil { + gotErr = err + break + } + } + assert.ErrorIs(t, gotErr, testErr) +} + +type errorListClient struct { + MockGcpLogsClient + listIter iter.Seq2[[]*loggingpb.LogEntry, error] +} + +func (e *errorListClient) ListLogEntries(ctx context.Context, query string, order gcp.Order) (iter.Seq2[[]*loggingpb.LogEntry, error], error) { + return e.listIter, nil +} + +func TestServerStream_TailError(t *testing.T) { + testErr := errors.New("tail error") + errorIter := func(yield func([]*loggingpb.LogEntry, error) bool) { + yield(nil, testErr) + } + + ctx := t.Context() + mockClient := &MockGcpLogsClient{} + stream := NewServerStream( + mockClient, + getLogEntryParser(ctx, mockClient), + ) + stream.query = NewLogQuery("test-project") + + stream.gcpLogsClient = &errorTailClient{ + MockGcpLogsClient: *mockClient, + tailIter: errorIter, + } + + logs, err := stream.Follow(t.Context(), time.Time{}) + assert.NoError(t, err) + + var gotErr error + for _, err := range logs { + if err != nil { + gotErr = err + break + } + } + assert.ErrorIs(t, gotErr, testErr) +} + +type errorTailClient struct { + MockGcpLogsClient + tailIter iter.Seq2[[]*loggingpb.LogEntry, error] +} + +func (e *errorTailClient) TailLogEntries(ctx context.Context, query string) (iter.Seq2[[]*loggingpb.LogEntry, error], error) { + return e.tailIter, nil +} diff --git a/src/pkg/clouds/aws/cw/logs.go b/src/pkg/clouds/aws/cw/logs.go index 5d9673f49..fae9ce35a 100644 --- a/src/pkg/clouds/aws/cw/logs.go +++ b/src/pkg/clouds/aws/cw/logs.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/DefangLabs/defang/src/pkg" "github.com/DefangLabs/defang/src/pkg/clouds/aws" "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" @@ -97,25 +98,6 @@ func TailLogGroup(ctx context.Context, cwClient StartLiveTailAPI, input LogGroup type FilterLogEventsAPIClient = cloudwatchlogs.FilterLogEventsAPIClient -// Flatten converts an iterator of batches into an iterator of individual items. -func Flatten[T any](seq iter.Seq2[[]T, error]) iter.Seq2[T, error] { - return func(yield func(T, error) bool) { - for items, err := range seq { - for _, item := range items { - if !yield(item, nil) { - return - } - } - if err != nil { - var zero T - if !yield(zero, err) { - return - } - } - } - } -} - func QueryLogGroups(ctx context.Context, cwClient FilterLogEventsAPIClient, start, end time.Time, limit int32, logGroups ...LogGroupInput) (iter.Seq2[LogEvent, error], error) { if len(logGroups) == 0 { return nil, errors.New("at least one LogGroupInput is required") @@ -127,7 +109,7 @@ func QueryLogGroups(ctx context.Context, cwClient FilterLogEventsAPIClient, star // This only happens if there's a missing LogGroupARN, in which case we can't proceed at all return nil, err } - merged = MergeLogEvents(merged, Flatten(logSeq)) // Merge sort the log events based on timestamp + merged = MergeLogEvents(merged, pkg.Flatten(logSeq)) // Merge sort the log events based on timestamp if limit > 0 { // take the first/last n events only from the merged stream if start.IsZero() { diff --git a/src/pkg/clouds/gcp/logging.go b/src/pkg/clouds/gcp/logging.go index cff8ce843..f5971a902 100644 --- a/src/pkg/clouds/gcp/logging.go +++ b/src/pkg/clouds/gcp/logging.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io" + "iter" logging "cloud.google.com/go/logging/apiv2" "cloud.google.com/go/logging/apiv2/loggingpb" @@ -12,83 +12,6 @@ import ( "google.golang.org/api/iterator" ) -func (gcp Gcp) NewTailer(ctx context.Context) (Tailer, error) { - client, err := logging.NewClient(ctx) - if err != nil { - return nil, err - } - tleClient, err := client.TailLogEntries(ctx) - if err != nil { - return nil, err - } - t := &gcpLoggingTailer{ - projectId: gcp.ProjectId, - tleClient: tleClient, - client: client, - } - return t, nil -} - -type Tailer interface { - Start(ctx context.Context, query string) error - Next(ctx context.Context) (*loggingpb.LogEntry, error) - Close() error -} - -type gcpLoggingTailer struct { - projectId string - tleClient loggingpb.LoggingServiceV2_TailLogEntriesClient - client *logging.Client - - cache []*loggingpb.LogEntry -} - -func (t *gcpLoggingTailer) Start(ctx context.Context, query string) error { - req := &loggingpb.TailLogEntriesRequest{ - ResourceNames: []string{"projects/" + t.projectId}, - Filter: query, - } - if err := t.tleClient.Send(req); err != nil { - return fmt.Errorf("failed to send tail log entries request: %w", err) - } - return nil -} - -func (t *gcpLoggingTailer) Next(ctx context.Context) (*loggingpb.LogEntry, error) { - if len(t.cache) == 0 { - resp, err := t.tleClient.Recv() - if err != nil { - return nil, err - } - t.cache = resp.GetEntries() - if len(t.cache) == 0 { - return nil, errors.New("no log entries found") - } - } - - entry := t.cache[0] - t.cache = t.cache[1:] - return entry, nil -} - -func (t *gcpLoggingTailer) Close() error { - // TODO: find out how to properly close the client - term.Debugf("Closing log tailer") - e1 := t.tleClient.CloseSend() - term.Debugf("Closing log tailer client") - e2 := t.client.Close() - return errors.Join(e1, e2) -} - -type Lister interface { - Next() (*loggingpb.LogEntry, error) -} - -type gcpLoggingLister struct { - it *logging.LogEntryIterator - client *logging.Client -} - type Order string const ( @@ -96,7 +19,9 @@ const ( OrderAscending Order = "asc" ) -func (gcp Gcp) ListLogEntries(ctx context.Context, query string, order Order) (Lister, error) { +// ListLogEntries returns an iterator over log entries matching the query. +// The underlying client is closed when iteration completes or is stopped. +func (gcp Gcp) ListLogEntries(ctx context.Context, query string, order Order) (iter.Seq2[[]*loggingpb.LogEntry, error], error) { client, err := logging.NewClient(ctx) if err != nil { return nil, err @@ -108,17 +33,66 @@ func (gcp Gcp) ListLogEntries(ctx context.Context, query string, order Order) (L OrderBy: fmt.Sprintf("timestamp %s", order), } it := client.ListLogEntries(ctx, req) - return &gcpLoggingLister{it: it, client: client}, nil + return func(yield func([]*loggingpb.LogEntry, error) bool) { + defer func() { + term.Debugf("Closing log lister client") + client.Close() + }() + for { + entry, err := it.Next() + if err == iterator.Done { + return + } + if err != nil { + yield(nil, err) + return + } + if !yield([]*loggingpb.LogEntry{entry}, nil) { + return + } + } + }, nil } -func (l *gcpLoggingLister) Next() (*loggingpb.LogEntry, error) { - entry, err := l.it.Next() - if err == iterator.Done { - term.Debugf("Closing log lister client") - if err := l.client.Close(); err != nil { - return nil, err - } - return nil, io.EOF +// TailLogEntries establishes a log tail stream and sends the filter request eagerly. +// The returned iterator yields log entries as they arrive. The underlying stream +// and client are closed when iteration completes or is stopped. +func (gcp Gcp) TailLogEntries(ctx context.Context, query string) (iter.Seq2[[]*loggingpb.LogEntry, error], error) { + client, err := logging.NewClient(ctx) + if err != nil { + return nil, err } - return entry, err + tleClient, err := client.TailLogEntries(ctx) + if err != nil { + client.Close() + return nil, err + } + + req := &loggingpb.TailLogEntriesRequest{ + ResourceNames: []string{"projects/" + gcp.ProjectId}, + Filter: query, + } + if err := tleClient.Send(req); err != nil { + tleClient.CloseSend() + client.Close() + return nil, fmt.Errorf("failed to send tail log entries request: %w", err) + } + + return func(yield func([]*loggingpb.LogEntry, error) bool) { + defer func() { + term.Debugf("Closing log tailer") + e1 := tleClient.CloseSend() + term.Debugf("Closing log tailer client") + e2 := client.Close() + if err := errors.Join(e1, e2); err != nil { + term.Debugf("Error closing log tailer: %v", err) + } + }() + for { + resp, err := tleClient.Recv() + if !yield(resp.GetEntries(), err) { + return + } + } + }, nil } diff --git a/src/pkg/test_utils.go b/src/pkg/test_utils.go new file mode 100644 index 000000000..176cf96cc --- /dev/null +++ b/src/pkg/test_utils.go @@ -0,0 +1,48 @@ +package pkg + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + "github.com/pmezard/go-difflib/difflib" +) + +func Compare(actual []byte, goldenFile string) error { + // Replace the absolute path in context to make the golden file portable + absPath, _ := filepath.Abs(goldenFile) + actual = bytes.ReplaceAll(actual, []byte(filepath.Dir(absPath)), []byte{'.'}) + + golden, err := os.ReadFile(goldenFile) + if err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("failed to read golden file: %w", err) + } + return os.WriteFile(goldenFile, actual, 0644) + } else { + if err := Diff(string(actual), string(golden)); err != nil { + return fmt.Errorf("%s %w", goldenFile, err) + } + } + return nil +} + +func Diff(actualRaw, goldenRaw string) error { + if actualRaw == goldenRaw { + return nil + } + + // Show the diff (but only the lines that differ to avoid overwhelming output) + diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(goldenRaw), + B: difflib.SplitLines(actualRaw), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return fmt.Errorf("mismatch:\n%s", diff) +} diff --git a/src/pkg/utils.go b/src/pkg/utils.go index fc3eb98a8..10d427553 100644 --- a/src/pkg/utils.go +++ b/src/pkg/utils.go @@ -1,20 +1,17 @@ package pkg import ( - "bytes" "context" - "fmt" "io" + "iter" "math/rand" "os" - "path/filepath" "regexp" "strconv" "strings" "time" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" - "github.com/pmezard/go-difflib/difflib" ) var ( @@ -159,44 +156,6 @@ func IsValidTime(t time.Time) bool { return t.Year() > 1970 } -func Compare(actual []byte, goldenFile string) error { - // Replace the absolute path in context to make the golden file portable - absPath, _ := filepath.Abs(goldenFile) - actual = bytes.ReplaceAll(actual, []byte(filepath.Dir(absPath)), []byte{'.'}) - - golden, err := os.ReadFile(goldenFile) - if err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("failed to read golden file: %w", err) - } - return os.WriteFile(goldenFile, actual, 0644) - } else { - if err := Diff(string(actual), string(golden)); err != nil { - return fmt.Errorf("%s %w", goldenFile, err) - } - } - return nil -} - -func Diff(actualRaw, goldenRaw string) error { - if actualRaw == goldenRaw { - return nil - } - - // Show the diff (but only the lines that differ to avoid overwhelming output) - diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ - A: difflib.SplitLines(goldenRaw), - B: difflib.SplitLines(actualRaw), - FromFile: "Expected", - FromDate: "", - ToFile: "Actual", - ToDate: "", - Context: 1, - }) - - return fmt.Errorf("mismatch:\n%s", diff) -} - var shellSpecialChars = regexp.MustCompile(`[^\w@%+=:,./-]`) // copied from al.essio.dev/pkg/shellescape // ShellQuote returns a shell-quoted string of the given arguments. @@ -226,3 +185,22 @@ func GcpInEnv() string { env, _ := GetFirstEnv(GCPProjectEnvVars...) return env } + +// Flatten converts an iterator of batches into an iterator of individual items. +func Flatten[T any](seq iter.Seq2[[]T, error]) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + for items, err := range seq { + for _, item := range items { + if !yield(item, nil) { + return + } + } + if err != nil { + var zero T + if !yield(zero, err) { + return + } + } + } + } +}