From c3bd1f013c4a733fa8206e77a0bfc4219c9b6c36 Mon Sep 17 00:00:00 2001 From: Matee Ullah Malik Date: Mon, 25 Aug 2025 17:57:11 +0500 Subject: [PATCH] feat: implement gRPC connection with keepalive and retry interceptor --- pkg/lumera/connection.go | 66 ++++++++++-- pkg/lumera/connection_test.go | 195 ++++++++++++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 10 deletions(-) create mode 100644 pkg/lumera/connection_test.go diff --git a/pkg/lumera/connection.go b/pkg/lumera/connection.go index 5947c132..b75fec6a 100644 --- a/pkg/lumera/connection.go +++ b/pkg/lumera/connection.go @@ -11,11 +11,20 @@ import ( "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" ) const ( - defaultTimeout = 30 * time.Second defaultLumeraPort = "9090" + + keepaliveTime = 30 * time.Second + keepaliveTimeout = 10 * time.Second + retryDelay = 2 * time.Second + maxRetryDelay = 30 * time.Second + maxRetries = 5 + backoffFactor = 2 ) // Connection defines the interface for a client connection. 
@@ -45,15 +52,7 @@ func newGRPCConnection(ctx context.Context, rawAddr string) (Connection, error) creds = insecure.NewCredentials() } - dialCtx, cancel := context.WithTimeout(ctx, defaultTimeout) - defer cancel() - - conn, err := grpc.DialContext( - dialCtx, - hostPort, - grpc.WithTransportCredentials(creds), - grpc.WithBlock(), - ) + conn, err := createGRPCConnection(ctx, hostPort, creds) if err != nil { return nil, fmt.Errorf("failed to connect to gRPC server: %w", err) } @@ -109,6 +108,53 @@ func normaliseAddr(raw string) (hostPort string, useTLS bool, serverName string, return net.JoinHostPort(host, port), false, host, nil } +// createGRPCConnection creates a gRPC connection with keepalive and retry interceptor +func createGRPCConnection(ctx context.Context, hostPort string, creds credentials.TransportCredentials) (*grpc.ClientConn, error) { + _ = ctx // Keeping this for api compatibility + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(creds), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: keepaliveTime, + Timeout: keepaliveTimeout, + PermitWithoutStream: true, + }), + grpc.WithUnaryInterceptor(retryInterceptor), + } + + return grpc.NewClient(hostPort, opts...) +} + +// retryInterceptor retries failed calls with exponential backoff +func retryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + delay := retryDelay + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + err := invoker(ctx, method, req, reply, cc, opts...) 
+ if err == nil { + return nil + } + + lastErr = err + + // Don't wait after the last attempt + if attempt < maxRetries-1 { + select { + case <-time.After(delay): + // Exponential backoff: 2s → 4s → 8s → 16s → 30s (capped) + delay *= backoffFactor + if delay > maxRetryDelay { + delay = maxRetryDelay + } + case <-ctx.Done(): + return ctx.Err() + } + } + } + + return lastErr // Return the last error after all retries exhausted +} + // Close closes the gRPC connection. func (c *grpcConnection) Close() error { if c.conn != nil { diff --git a/pkg/lumera/connection_test.go b/pkg/lumera/connection_test.go new file mode 100644 index 00000000..ffacda3f --- /dev/null +++ b/pkg/lumera/connection_test.go @@ -0,0 +1,195 @@ +package lumera + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestNormaliseAddr(t *testing.T) { + tests := []struct { + name string + input string + expectedHost string + expectedTLS bool + expectedServer string + expectError bool + }{ + { + name: "https scheme", + input: "https://grpc.testnet.lumera.io", + expectedHost: "grpc.testnet.lumera.io:443", + expectedTLS: true, + expectedServer: "grpc.testnet.lumera.io", + expectError: false, + }, + { + name: "grpcs scheme with port", + input: "grpcs://grpc.node9x.com:7443", + expectedHost: "grpc.node9x.com:7443", + expectedTLS: true, + expectedServer: "grpc.node9x.com", + expectError: false, + }, + { + name: "host with port 443", + input: "grpc.node9x.com:443", + expectedHost: "grpc.node9x.com:443", + expectedTLS: true, + expectedServer: "grpc.node9x.com", + expectError: false, + }, + { + name: "host with custom port", + input: "grpc.node9x.com:9090", + expectedHost: "grpc.node9x.com:9090", + expectedTLS: false, + expectedServer: "grpc.node9x.com", + expectError: false, + }, + { + name: "host without port", + input: "grpc.testnet.lumera.io", + expectedHost: "grpc.testnet.lumera.io:9090", + expectedTLS: false, + 
expectedServer: "grpc.testnet.lumera.io", + expectError: false, + }, + { + name: "invalid scheme", + input: "ftp://invalid.com", + expectedHost: "", + expectedTLS: false, + expectedServer: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hostPort, useTLS, serverName, err := normaliseAddr(tt.input) + + if tt.expectError && err == nil { + t.Errorf("normaliseAddr(%s) expected error, got nil", tt.input) + return + } + + if !tt.expectError && err != nil { + t.Errorf("normaliseAddr(%s) unexpected error: %v", tt.input, err) + return + } + + if !tt.expectError { + if hostPort != tt.expectedHost { + t.Errorf("normaliseAddr(%s) hostPort = %s, want %s", tt.input, hostPort, tt.expectedHost) + } + if useTLS != tt.expectedTLS { + t.Errorf("normaliseAddr(%s) useTLS = %v, want %v", tt.input, useTLS, tt.expectedTLS) + } + if serverName != tt.expectedServer { + t.Errorf("normaliseAddr(%s) serverName = %s, want %s", tt.input, serverName, tt.expectedServer) + } + } + }) + } +} + +func TestGrpcConnectionMethods(t *testing.T) { + // Test with nil connection + conn := &grpcConnection{conn: nil} + + // Close should not panic with nil connection + err := conn.Close() + if err != nil { + t.Errorf("Close() with nil connection should return nil, got %v", err) + } + + // GetConn should return nil + grpcConn := conn.GetConn() + if grpcConn != nil { + t.Errorf("GetConn() with nil connection should return nil, got %v", grpcConn) + } +} + +func TestConnectionConstants(t *testing.T) { + // Test that our constants are reasonable + if keepaliveTime < 10*time.Second { + t.Errorf("keepaliveTime too short: %v", keepaliveTime) + } + + if keepaliveTimeout >= keepaliveTime { + t.Errorf("keepaliveTimeout should be less than keepaliveTime: %v >= %v", keepaliveTimeout, keepaliveTime) + } + + if retryDelay < 100*time.Millisecond { + t.Errorf("retryDelay too short: %v", retryDelay) + } + + if maxRetryDelay <= retryDelay { + t.Errorf("maxRetryDelay should 
be greater than retryDelay: %v <= %v", maxRetryDelay, retryDelay) + } + + if maxRetries < 1 { + t.Errorf("maxRetries should be at least 1: %v", maxRetries) + } + + if backoffFactor < 1 { + t.Errorf("backoffFactor should be at least 1: %v", backoffFactor) + } +} + +func TestRetryInterceptorSuccess(t *testing.T) { + attempts := 0 + + // Mock invoker that fails twice, then succeeds + mockInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + attempts++ + if attempts < 3 { + return status.Error(codes.DeadlineExceeded, "simulated timeout") + } + return nil // Success on 3rd attempt + } + + // Call retry interceptor + err := retryInterceptor(context.Background(), "/test", nil, nil, nil, mockInvoker) + + // Should succeed + if err != nil { + t.Errorf("Expected success after retries, got error: %v", err) + } + + if attempts != 3 { + t.Errorf("Expected 3 attempts, got %d", attempts) + } +} + +func TestRetryInterceptorContextCancellation(t *testing.T) { + attempts := 0 + + // Mock invoker that always fails + mockInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + attempts++ + return status.Error(codes.Unavailable, "simulated failure") + } + + // Context that cancels quickly + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Call retry interceptor + err := retryInterceptor(ctx, "/test", nil, nil, nil, mockInvoker) + + // Should return context error + if err != context.DeadlineExceeded { + t.Errorf("Expected context deadline exceeded, got: %v", err) + } + + // Should have made at least one attempt + if attempts < 1 { + t.Errorf("Expected at least 1 attempt, got %d", attempts) + } +}