Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions pkg/lumera/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
)

// Defaults and tuning parameters for the gRPC client connection.
const (
// defaultTimeout is a 30-second timeout.
// NOTE(review): no remaining use is visible in this excerpt — confirm it is still referenced.
defaultTimeout = 30 * time.Second
// defaultLumeraPort is presumably the port used when an address carries none
// (the normaliseAddr test table expects ":9090" in that case) — confirm against normaliseAddr.
defaultLumeraPort = "9090"

// keepaliveTime is passed as keepalive.ClientParameters.Time when dialing.
keepaliveTime = 30 * time.Second
// keepaliveTimeout is passed as keepalive.ClientParameters.Timeout when dialing.
keepaliveTimeout = 10 * time.Second
// retryDelay is the wait before the first retry in retryInterceptor.
retryDelay = 2 * time.Second
// maxRetryDelay is the upper bound applied to the growing retry delay.
maxRetryDelay = 30 * time.Second
// maxRetries bounds the total number of attempts made by retryInterceptor
// (the initial call is counted as one of them).
maxRetries = 5
// backoffFactor multiplies the retry delay after each wait.
backoffFactor = 2
)

// Connection defines the interface for a client connection.
Expand Down Expand Up @@ -45,15 +52,7 @@ func newGRPCConnection(ctx context.Context, rawAddr string) (Connection, error)
creds = insecure.NewCredentials()
}

dialCtx, cancel := context.WithTimeout(ctx, defaultTimeout)
defer cancel()

conn, err := grpc.DialContext(
dialCtx,
hostPort,
grpc.WithTransportCredentials(creds),
grpc.WithBlock(),
)
conn, err := createGRPCConnection(ctx, hostPort, creds)
if err != nil {
return nil, fmt.Errorf("failed to connect to gRPC server: %w", err)
}
Expand Down Expand Up @@ -109,6 +108,53 @@ func normaliseAddr(raw string) (hostPort string, useTLS bool, serverName string,
return net.JoinHostPort(host, port), false, host, nil
}

// createGRPCConnection creates a gRPC connection with keepalive and retry interceptor
func createGRPCConnection(ctx context.Context, hostPort string, creds credentials.TransportCredentials) (*grpc.ClientConn, error) {
_ = ctx // Keeping this for api compatibility
Comment thread
mateeullahmalik marked this conversation as resolved.
opts := []grpc.DialOption{
grpc.WithTransportCredentials(creds),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: keepaliveTime,
Timeout: keepaliveTimeout,
PermitWithoutStream: true,
}),
grpc.WithUnaryInterceptor(retryInterceptor),
}

return grpc.NewClient(hostPort, opts...)
}

// retryInterceptor is a unary client interceptor that re-invokes a failed call
// with exponential backoff.
//
// The call is attempted at most maxRetries times in total. After every failed
// attempt except the last one it waits for the current delay, which starts at
// retryDelay, is multiplied by backoffFactor after each wait and is capped at
// maxRetryDelay. With the current constants the waits are 2s, 4s, 8s and 16s.
//
// It returns nil as soon as an attempt succeeds, ctx.Err() if ctx is done
// while waiting between attempts, and otherwise the error of the final attempt.
//
// NOTE(review): every error is retried regardless of its gRPC status code and
// the same request is sent again each time — confirm that all unary methods
// used over this connection are safe to repeat.
func retryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	delay := retryDelay
	var lastErr error

	for attempt := 0; attempt < maxRetries; attempt++ {
		err := invoker(ctx, method, req, reply, cc, opts...)
		if err == nil {
			return nil
		}
		lastErr = err

		// Don't wait after the last attempt.
		if attempt == maxRetries-1 {
			break
		}

		// Use an explicit timer instead of time.After so it can be stopped
		// right away when ctx ends, rather than lingering until it fires.
		timer := time.NewTimer(delay)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		}

		// Grow the delay for the next wait, never beyond maxRetryDelay.
		delay *= backoffFactor
		if delay > maxRetryDelay {
			delay = maxRetryDelay
		}
	}

	return lastErr // All attempts failed: surface the most recent error.
}

// Close closes the gRPC connection.
func (c *grpcConnection) Close() error {
if c.conn != nil {
Expand Down
195 changes: 195 additions & 0 deletions pkg/lumera/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package lumera

import (
"context"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// TestNormaliseAddr runs normaliseAddr over a table of address forms (explicit
// TLS schemes, bare host:port, bare host, unsupported scheme) and compares the
// returned host:port, TLS flag and server name with the expected values.
func TestNormaliseAddr(t *testing.T) {
	type addrCase struct {
		label      string
		raw        string
		wantAddr   string
		wantTLS    bool
		wantServer string
		wantErr    bool
	}

	cases := []addrCase{
		{
			label:      "https scheme",
			raw:        "https://grpc.testnet.lumera.io",
			wantAddr:   "grpc.testnet.lumera.io:443",
			wantTLS:    true,
			wantServer: "grpc.testnet.lumera.io",
		},
		{
			label:      "grpcs scheme with port",
			raw:        "grpcs://grpc.node9x.com:7443",
			wantAddr:   "grpc.node9x.com:7443",
			wantTLS:    true,
			wantServer: "grpc.node9x.com",
		},
		{
			label:      "host with port 443",
			raw:        "grpc.node9x.com:443",
			wantAddr:   "grpc.node9x.com:443",
			wantTLS:    true,
			wantServer: "grpc.node9x.com",
		},
		{
			label:      "host with custom port",
			raw:        "grpc.node9x.com:9090",
			wantAddr:   "grpc.node9x.com:9090",
			wantServer: "grpc.node9x.com",
		},
		{
			label:      "host without port",
			raw:        "grpc.testnet.lumera.io",
			wantAddr:   "grpc.testnet.lumera.io:9090",
			wantServer: "grpc.testnet.lumera.io",
		},
		{
			label:   "invalid scheme",
			raw:     "ftp://invalid.com",
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			gotAddr, gotTLS, gotServer, err := normaliseAddr(tc.raw)

			// Settle the error expectation first; the returned values are
			// only inspected for cases that must succeed.
			switch {
			case tc.wantErr && err == nil:
				t.Errorf("normaliseAddr(%s) expected error, got nil", tc.raw)
				return
			case tc.wantErr:
				return
			case err != nil:
				t.Errorf("normaliseAddr(%s) unexpected error: %v", tc.raw, err)
				return
			}

			if gotAddr != tc.wantAddr {
				t.Errorf("normaliseAddr(%s) hostPort = %s, want %s", tc.raw, gotAddr, tc.wantAddr)
			}
			if gotTLS != tc.wantTLS {
				t.Errorf("normaliseAddr(%s) useTLS = %v, want %v", tc.raw, gotTLS, tc.wantTLS)
			}
			if gotServer != tc.wantServer {
				t.Errorf("normaliseAddr(%s) serverName = %s, want %s", tc.raw, gotServer, tc.wantServer)
			}
		})
	}
}

// TestGrpcConnectionMethods exercises Close and GetConn on a grpcConnection
// that holds no underlying client connection.
func TestGrpcConnectionMethods(t *testing.T) {
	// Zero value: the wrapped conn is nil.
	empty := &grpcConnection{}

	// Closing must be a harmless no-op that reports no error.
	if closeErr := empty.Close(); closeErr != nil {
		t.Errorf("Close() with nil connection should return nil, got %v", closeErr)
	}

	// The accessor must hand back the nil connection unchanged.
	if got := empty.GetConn(); got != nil {
		t.Errorf("GetConn() with nil connection should return nil, got %v", got)
	}
}

// TestConnectionConstants sanity-checks the keepalive and retry tuning
// constants against minimum values and against each other.
func TestConnectionConstants(t *testing.T) {
	// expect records a failure with the given message when ok is false.
	expect := func(ok bool, format string, args ...interface{}) {
		t.Helper()
		if !ok {
			t.Errorf(format, args...)
		}
	}

	expect(keepaliveTime >= 10*time.Second,
		"keepaliveTime too short: %v", keepaliveTime)
	expect(keepaliveTimeout < keepaliveTime,
		"keepaliveTimeout should be less than keepaliveTime: %v >= %v", keepaliveTimeout, keepaliveTime)
	expect(retryDelay >= 100*time.Millisecond,
		"retryDelay too short: %v", retryDelay)
	expect(maxRetryDelay > retryDelay,
		"maxRetryDelay should be greater than retryDelay: %v <= %v", maxRetryDelay, retryDelay)
	expect(maxRetries >= 1,
		"maxRetries should be at least 1: %v", maxRetries)
	expect(backoffFactor >= 1,
		"backoffFactor should be at least 1: %v", backoffFactor)
}

// TestRetryInterceptorSuccess verifies that retryInterceptor keeps invoking a
// failing call and reports success once the call finally succeeds.
// The two simulated failures go through the interceptor's real backoff waits,
// so this test takes several seconds of wall-clock time.
func TestRetryInterceptorSuccess(t *testing.T) {
	const succeedOn = 3
	var calls int

	// Fake invoker: times out until the third call, which succeeds.
	flaky := func(_ context.Context, _ string, _, _ interface{}, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
		calls++
		if calls >= succeedOn {
			return nil
		}
		return status.Error(codes.DeadlineExceeded, "simulated timeout")
	}

	if err := retryInterceptor(context.Background(), "/test", nil, nil, nil, flaky); err != nil {
		t.Errorf("Expected success after retries, got error: %v", err)
	}

	if calls != succeedOn {
		t.Errorf("Expected 3 attempts, got %d", calls)
	}
}

// TestRetryInterceptorContextCancellation verifies that retryInterceptor stops
// retrying and returns the context's error when the context expires while it
// is waiting between attempts.
func TestRetryInterceptorContextCancellation(t *testing.T) {
	var calls int

	// Fake invoker that never succeeds.
	alwaysUnavailable := func(_ context.Context, _ string, _, _ interface{}, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
		calls++
		return status.Error(codes.Unavailable, "simulated failure")
	}

	// The deadline is far shorter than the first backoff wait.
	shortCtx, stop := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer stop()

	got := retryInterceptor(shortCtx, "/test", nil, nil, nil, alwaysUnavailable)

	// The context error must be surfaced, not the invoker's error.
	if got != context.DeadlineExceeded {
		t.Errorf("Expected context deadline exceeded, got: %v", got)
	}

	// The call must have been tried before the deadline hit.
	if calls < 1 {
		t.Errorf("Expected at least 1 attempt, got %d", calls)
	}
}