Skip to content

Commit

Permalink
Gateway reactive config implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kralicky committed Nov 21, 2023
1 parent b8e3360 commit 6f364f9
Show file tree
Hide file tree
Showing 43 changed files with 1,252 additions and 1,142 deletions.
6 changes: 5 additions & 1 deletion pkg/bootstrap/client_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"net"
"runtime"
"sync/atomic"
"time"

. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -65,7 +66,10 @@ var _ = Describe("Client V2", Ordered, Label("unit"), func() {
Certificates: []tls.Certificate{*cert},
})))

server := bootstrap.NewServerV2(store, cert.PrivateKey.(crypto.Signer))
signer := cert.PrivateKey.(crypto.Signer)
ptr := &atomic.Pointer[crypto.Signer]{}
ptr.Store(&signer)
server := bootstrap.NewServerV2(store, ptr)
bootstrapv2.RegisterBootstrapServer(srv, server)

listener, err := net.Listen("tcp4", "127.0.0.1:0")
Expand Down
7 changes: 4 additions & 3 deletions pkg/bootstrap/server_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto"
"fmt"
"strings"
"sync/atomic"

"maps"

Expand All @@ -27,12 +28,12 @@ import (

type ServerV2 struct {
bootstrapv2.UnsafeBootstrapServer
privateKey crypto.Signer
privateKey *atomic.Pointer[crypto.Signer]
storage Storage
clusterIdLocks storage.LockManager
}

func NewServerV2(storage Storage, privateKey crypto.Signer) *ServerV2 {
func NewServerV2(storage Storage, privateKey *atomic.Pointer[crypto.Signer]) *ServerV2 {
return &ServerV2{
privateKey: privateKey,
storage: storage,
Expand Down Expand Up @@ -84,7 +85,7 @@ func (h *ServerV2) Auth(ctx context.Context, authReq *bootstrapv2.BootstrapAuthR
// Remove "Bearer " from the header
bearerToken := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer"))
// Verify the token
payload, err := jws.Verify([]byte(bearerToken), jwa.EdDSA, h.privateKey.Public())
payload, err := jws.Verify([]byte(bearerToken), jwa.EdDSA, (*h.privateKey.Load()).Public())
if err != nil {
return nil, util.StatusError(codes.PermissionDenied)
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/bootstrap/server_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"errors"
"net"
"strings"
"sync/atomic"
"time"

"github.com/lestrrat-go/jwx/jwa"
Expand Down Expand Up @@ -65,13 +66,16 @@ var _ = Describe("Server V2", Ordered, Label("unit"), func() {
Expect(err).NotTo(HaveOccurred())
cert = &crt

ptr := &atomic.Pointer[crypto.Signer]{}
signer := cert.PrivateKey.(crypto.Signer)
ptr.Store(&signer)
srv := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
server := bootstrap.NewServerV2(bootstrap.StorageConfig{
TokenStore: mockTokenStore,
ClusterStore: mockClusterStore,
KeyringStoreBroker: mockKeyringStoreBroker,
LockManagerBroker: inmemory.NewLockManagerBroker(),
}, cert.PrivateKey.(crypto.Signer))
}, ptr)
bootstrapv2.RegisterBootstrapServer(srv, server)

listener := bufconn.Listen(1024 * 1024)
Expand Down
197 changes: 144 additions & 53 deletions pkg/gateway/connections.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,37 @@ package gateway
import (
"context"
"encoding/base64"
"fmt"
"io"
"log/slog"
"os"
"strings"
sync "sync"
"time"

corev1 "github.com/rancher/opni/pkg/apis/core/v1"
"github.com/rancher/opni/pkg/auth/cluster"
"github.com/rancher/opni/pkg/config/reactive"
configv1 "github.com/rancher/opni/pkg/config/v1"
"github.com/rancher/opni/pkg/logger"
"github.com/rancher/opni/pkg/storage"
"github.com/rancher/opni/pkg/storage/lock"
"github.com/rancher/opni/pkg/util"
"github.com/rancher/opni/pkg/util/streams"
"github.com/rancher/opni/pkg/versions"
"github.com/samber/lo"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protopath"
)

// TrackedConnectionListener receives notifications about tracked agent
// connections. Listeners are held by a ConnectionTracker and are invoked one
// after another for each tracked connection event.
type TrackedConnectionListener interface {
// Called when a new agent connection to any gateway instance is tracked.
// The provided context will be canceled when the tracked connection is deleted.
// It is also canceled when a different lease takes over the agent's
// connection and the previously tracked connection is invalidated.
// If a tracked connection is updated, this method will be called again with
// the same context, agentId, and leaseId, but updated instanceInfo.
// Implementations of this method MUST NOT block.
// NOTE(review): calls appear to be made synchronously from the tracker's
// event handler, presumably while its lock is held - confirm.
HandleTrackedConnection(ctx context.Context, agentId string, leaseId string, instanceInfo *corev1.InstanceInfo)
}
Expand Down Expand Up @@ -62,11 +71,11 @@ type activeTrackedInstance struct {
// All gateway instances track the same state and update in real time at the
// speed of the underlying storage backend.
type ConnectionTracker struct {
localInstanceInfo *corev1.InstanceInfo
rootContext context.Context
kv storage.KeyValueStore
lm storage.LockManager
logger *slog.Logger
rootContext context.Context
kv storage.KeyValueStore
lm storage.LockManager
mgr *configv1.GatewayConfigManager
logger *slog.Logger

listenersMu sync.Mutex
connListeners []TrackedConnectionListener
Expand All @@ -75,24 +84,39 @@ type ConnectionTracker struct {
mu sync.RWMutex
activeConnections map[string]*activeTrackedConnection
activeInstances map[string]*activeTrackedInstance

localInstanceInfoMu sync.Mutex
localInstanceInfo *corev1.InstanceInfo
// relayConf reactive.Reactive[*configv1.RelayServerSpec]
// managementConf reactive.Reactive[*configv1.ManagementServerSpec]
// dashboardConf reactive.Reactive[*configv1.DashboardServerSpec]
}

func NewConnectionTracker(
rootContext context.Context,
localInstanceInfo *corev1.InstanceInfo,
mgr *configv1.GatewayConfigManager,
kv storage.KeyValueStore,
lm storage.LockManager,
lg *slog.Logger,
) *ConnectionTracker {
return &ConnectionTracker{
localInstanceInfo: localInstanceInfo,
hostname, _ := os.Hostname()
ct := &ConnectionTracker{
localInstanceInfo: &corev1.InstanceInfo{
Annotations: map[string]string{
"hostname": hostname,
"pid": fmt.Sprint(os.Getpid()),
"version": versions.Version,
},
},
rootContext: rootContext,
mgr: mgr,
kv: kv,
lm: lm,
logger: lg,
activeConnections: make(map[string]*activeTrackedConnection),
activeInstances: make(map[string]*activeTrackedInstance),
}
return ct
}

func NewReadOnlyConnectionTracker(
Expand Down Expand Up @@ -153,14 +177,16 @@ func (ct *ConnectionTracker) AddTrackedInstanceListener(listener TrackedInstance
}

// LocalInstanceInfo returns a snapshot of the instance info describing this
// gateway instance.
//
// The snapshot is taken while holding localInstanceInfoMu, so it does not
// interleave with writers that update the info under the same mutex. The
// returned value is produced by util.ProtoClone, so callers do not share the
// tracker's own message (assuming ProtoClone performs a deep copy).
func (ct *ConnectionTracker) LocalInstanceInfo() *corev1.InstanceInfo {
	mu := &ct.localInstanceInfoMu
	mu.Lock()
	// Deferred so the mutex is released even if cloning panics.
	defer mu.Unlock()

	snapshot := util.ProtoClone(ct.localInstanceInfo)
	return snapshot
}

// Starts the connection tracker. This will block until the context is canceled
// and the underlying kv store watcher is closed.
func (ct *ConnectionTracker) Run(ctx context.Context) error {
var wg sync.WaitGroup
if ct.localInstanceInfo != nil {
if ct.LocalInstanceInfo() != nil {
wg.Add(1)
go func() {
defer wg.Done()
Expand Down Expand Up @@ -199,16 +225,76 @@ func (ct *ConnectionTracker) Run(ctx context.Context) error {
// "$instances". This key can be used to track the set of all instances
// that are currently running.
func (ct *ConnectionTracker) lockInstance(ctx context.Context) {
instanceInfo := &corev1.InstanceInfo{
RelayAddress: ct.localInstanceInfo.GetRelayAddress(),
ManagementAddress: ct.localInstanceInfo.GetManagementAddress(),
GatewayAddress: ct.localInstanceInfo.GetGatewayAddress(),
WebAddress: ct.localInstanceInfo.GetWebAddress(),
}
ctx, ca := context.WithCancel(ctx)
defer ca()

locker := ct.lm.Locker(instancesKey, lock.WithAcquireContext(ctx),
lock.WithInitialValue(base64.StdEncoding.EncodeToString(util.Must(proto.Marshal(ct.LocalInstanceInfo())))))

updateInstanceInfo, cancel := lo.NewDebounce(100*time.Millisecond, func() {
ct.kv.Put(ctx, locker.Key(), util.Must(proto.Marshal(ct.LocalInstanceInfo())))
})
defer cancel()

listenOnce := sync.OnceFunc(func() {
relay := reactive.Message[*configv1.RelayServerSpec](ct.mgr.Reactive(protopath.Path(configv1.ProtoPath().Relay())))
mgmt := reactive.Message[*configv1.ManagementServerSpec](ct.mgr.Reactive(protopath.Path(configv1.ProtoPath().Management())))
server := reactive.Message[*configv1.ServerSpec](ct.mgr.Reactive(protopath.Path(configv1.ProtoPath().Server())))
dashboard := reactive.Message[*configv1.DashboardServerSpec](ct.mgr.Reactive(protopath.Path(configv1.ProtoPath().Dashboard())))

relay.WatchFunc(ctx, func(msg *configv1.RelayServerSpec) {
ct.localInstanceInfoMu.Lock()
defer ct.localInstanceInfoMu.Unlock()
if addr := msg.GetAdvertiseAddress(); addr != "" {
ct.localInstanceInfo.RelayAddress = addr
} else {
ct.logger.Warn("relay advertise address not set; will advertise the listen address")
ct.localInstanceInfo.RelayAddress = msg.GetGrpcListenAddress()
}
updateInstanceInfo()
})

mgmt.WatchFunc(ctx, func(msg *configv1.ManagementServerSpec) {
ct.localInstanceInfoMu.Lock()
defer ct.localInstanceInfoMu.Unlock()
if addr := msg.GetAdvertiseAddress(); addr != "" {
ct.localInstanceInfo.ManagementAddress = addr
} else {
ct.logger.Warn("management advertise address not set; will advertise the listen address")
ct.localInstanceInfo.ManagementAddress = msg.GetGrpcListenAddress()
}
updateInstanceInfo()
})

server.WatchFunc(ctx, func(msg *configv1.ServerSpec) {
ct.localInstanceInfoMu.Lock()
defer ct.localInstanceInfoMu.Unlock()
if addr := msg.GetAdvertiseAddress(); addr != "" {
ct.localInstanceInfo.GatewayAddress = addr
} else {
ct.logger.Warn("gateway advertise address not set; will advertise the listen address")
ct.localInstanceInfo.GatewayAddress = msg.GetGrpcListenAddress()
}
updateInstanceInfo()
})

dashboard.WatchFunc(ctx, func(msg *configv1.DashboardServerSpec) {
ct.localInstanceInfoMu.Lock()
defer ct.localInstanceInfoMu.Unlock()
if addr := msg.GetAdvertiseAddress(); addr != "" {
ct.localInstanceInfo.WebAddress = addr
} else {
ct.logger.Warn("web advertise address not set; will advertise the listen address")
ct.localInstanceInfo.WebAddress = msg.GetHttpListenAddress()
}
updateInstanceInfo()
})
})
for ctx.Err() == nil {
locker := ct.lm.Locker(instancesKey, lock.WithAcquireContext(ctx),
lock.WithInitialValue(base64.StdEncoding.EncodeToString(util.Must(proto.Marshal(instanceInfo)))))
lock.WithInitialValue(base64.StdEncoding.EncodeToString(util.Must(proto.Marshal(ct.LocalInstanceInfo())))))
locker.Lock()
listenOnce()
}
}

Expand All @@ -219,12 +305,7 @@ var instanceInfoKey = instanceInfoKeyType{}
func (ct *ConnectionTracker) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
agentId := cluster.StreamAuthorizedID(ss.Context())
instanceInfo := &corev1.InstanceInfo{
RelayAddress: ct.localInstanceInfo.GetRelayAddress(),
ManagementAddress: ct.localInstanceInfo.GetManagementAddress(),
GatewayAddress: ct.localInstanceInfo.GetGatewayAddress(),
WebAddress: ct.localInstanceInfo.GetWebAddress(),
}
instanceInfo := ct.LocalInstanceInfo()
locker := ct.lm.Locker(agentId,
lock.WithAcquireContext(ss.Context()),
lock.WithInitialValue(base64.StdEncoding.EncodeToString(util.Must(proto.Marshal(instanceInfo)))),
Expand Down Expand Up @@ -295,54 +376,62 @@ func (ct *ConnectionTracker) handleConnectionEventLocked(event storage.WatchEven
"agentId", agentId,
"key", event.Current.Key(),
)
info, err := instanceInfo()
if err != nil {
lg.With(logger.Err(err)).Error("failed to unmarshal instance info")
return
}
isUpdate := false
if conn, ok := ct.activeConnections[agentId]; ok {
if conn.leaseId == leaseId {
// key was updated, not a new connection
lg.Debug("tracked connection updated")
return
}
info, err := instanceInfo()
if err != nil {
lg.With(logger.Err(err)).Error("failed to unmarshal instance info")
return
}
if !info.GetAcquired() {
// a different instance is only attempting to acquire the lock,
// ignore the event
ct.logger.With("agent", agentId, "instance", info.GetRelayAddress()).Debug("observed lock attempt from another instance")
return
}
// a different instance has acquired the lock, invalidate
// the current tracked connection
ct.logger.With("agentId", agentId).Debug("tracked connection invalidated")
conn.cancelTrackingContext()
delete(ct.activeConnections, agentId)
}
info, err := instanceInfo()
if err != nil {
lg.With(logger.Err(err)).Error("failed to unmarshal instance info")
return
if conn.leaseId == leaseId {
// key was updated, not a new connection
isUpdate = true
} else {
// a different instance has acquired the lock, invalidate
// the current tracked connection
ct.logger.With("agentId", agentId).Debug("tracked connection invalidated")
conn.cancelTrackingContext()
delete(ct.activeConnections, agentId)
}
}
if !info.GetAcquired() {
return // ignore unacquired connections
}
if ct.IsLocalInstance(info) {
ct.logger.With("agentId", agentId).Debug("tracking new connection (local)")
} else if isUpdate {
ct.logger.With("agentId", agentId).Debug("tracked connection updated")
} else {
ct.logger.With("agentId", agentId).Debug("tracking new connection")
}
ctx, cancel := context.WithCancel(ct.rootContext)
conn := &activeTrackedConnection{
agentId: agentId,
leaseId: leaseId,
revision: event.Current.Revision(),
instanceInfo: info,
trackingContext: ctx,
cancelTrackingContext: cancel,

var trackingContext context.Context
if !isUpdate {
ctx, cancel := context.WithCancel(ct.rootContext)
trackingContext = ctx
conn := &activeTrackedConnection{
agentId: agentId,
leaseId: leaseId,
revision: event.Current.Revision(),
instanceInfo: info,
trackingContext: ctx,
cancelTrackingContext: cancel,
}
ct.activeConnections[agentId] = conn
} else {
conn := ct.activeConnections[agentId]
conn.revision = event.Current.Revision()
conn.instanceInfo = info
trackingContext = conn.trackingContext
}
ct.activeConnections[agentId] = conn
for _, listener := range ct.connListeners {
listener.HandleTrackedConnection(ctx, agentId, leaseId, info)
listener.HandleTrackedConnection(trackingContext, agentId, leaseId, info)
}
case storage.WatchEventDelete:
agentId, leaseId, ok := decodeKey(event.Previous.Key())
Expand Down Expand Up @@ -389,7 +478,9 @@ func (ct *ConnectionTracker) handleInstanceEventLocked(event storage.WatchEvent[
ct.logger.With(logger.Err(err)).Error("failed to unmarshal instance info")
return
}
if _, ok := ct.activeInstances[leaseId]; ok {
if existing, ok := ct.activeInstances[leaseId]; ok {
ct.logger.With("leaseId", leaseId).Debug("tracked instance updated")
existing.instanceInfo = instanceInfo
return
}
ct.logger.With("leaseId", leaseId).Debug("tracking new instance")
Expand Down
Loading

0 comments on commit 6f364f9

Please sign in to comment.