From 0b095056c88e14c95e8e3fdb114fe8d64accf57e Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 6 Feb 2024 12:38:10 -0500 Subject: [PATCH] Pass our configured random reader into nodee functions --- .../daemon/cluster/handlers/option_test.go | 38 +++++++++++++++++++ internal/daemon/cluster/handlers/options.go | 17 ++++++++- internal/daemon/cluster/handlers/testing.go | 3 +- .../upstream_message_service_controller.go | 10 +++-- .../handlers/upstream_message_service_test.go | 2 +- .../upstream_message_service_worker.go | 7 ++-- internal/daemon/controller/controller.go | 3 +- internal/daemon/controller/handler.go | 10 ++++- .../handlers/workers/worker_service.go | 17 +++++++-- .../handlers/workers/worker_service_test.go | 35 ++++++++--------- internal/daemon/controller/listeners.go | 1 + .../daemon/controller/rpc_registration.go | 3 +- internal/daemon/worker/auth_rotation.go | 1 + .../daemon/worker/controller_connection.go | 1 + internal/daemon/worker/listeners.go | 3 +- internal/daemon/worker/rpc_registration.go | 1 + internal/daemon/worker/worker.go | 2 +- internal/server/job/options.go | 11 ++++++ internal/server/job/rotate_roots_job.go | 6 ++- internal/server/options.go | 11 ++++++ internal/server/options_test.go | 8 ++++ internal/server/repository.go | 11 +++++- internal/server/repository_worker.go | 17 +++++++-- internal/server/repository_workerauth.go | 2 +- 24 files changed, 176 insertions(+), 44 deletions(-) create mode 100644 internal/daemon/cluster/handlers/option_test.go diff --git a/internal/daemon/cluster/handlers/option_test.go b/internal/daemon/cluster/handlers/option_test.go new file mode 100644 index 0000000000..6c0b99c35c --- /dev/null +++ b/internal/daemon/cluster/handlers/option_test.go @@ -0,0 +1,38 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package handlers + +import ( + "context" + "crypto/rand" + "io" + "testing" + + "github.com/hashicorp/nodeenrollment/storage/inmem" + "github.com/hashicorp/nodeenrollment/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_GetOpts provides unit tests for GetOpts and all the options +func Test_GetOpts(t *testing.T) { + t.Parallel() + + t.Run("withKeyProducer", func(t *testing.T) { + storage, err := inmem.New(context.Background()) + require.NoError(t, err) + creds, err := types.NewNodeCredentials(context.Background(), storage) + require.NoError(t, err) + opts := getOpts(WithKeyProducer(creds)) + testOpts := getDefaultOptions() + assert.Equal(t, nil, testOpts.withKeyProducer) + assert.Equal(t, creds, opts.withKeyProducer) + }) + t.Run("WithRandomReader", func(t *testing.T) { + opts := getOpts(WithRandomReader(io.LimitReader(nil, 0))) + testOpts := getDefaultOptions() + assert.Equal(t, rand.Reader, testOpts.withRandomReader) + assert.NotEqual(t, rand.Reader, opts.withRandomReader) + }) +} diff --git a/internal/daemon/cluster/handlers/options.go b/internal/daemon/cluster/handlers/options.go index ece877774c..81c9f88779 100644 --- a/internal/daemon/cluster/handlers/options.go +++ b/internal/daemon/cluster/handlers/options.go @@ -4,6 +4,9 @@ package handlers import ( + "crypto/rand" + "io" + "github.com/hashicorp/nodeenrollment" ) @@ -21,11 +24,14 @@ type Option func(*options) // options = how options are represented type options struct { - withKeyProducer nodeenrollment.X25519KeyProducer + withKeyProducer nodeenrollment.X25519KeyProducer + withRandomReader io.Reader } func getDefaultOptions() options { - return options{} + return options{ + withRandomReader: rand.Reader, + } } // WithKeyProducer provides an option types.NodeInformation @@ -34,3 +40,10 @@ func WithKeyProducer(nodeInfo nodeenrollment.X25519KeyProducer) Option { o.withKeyProducer = nodeInfo } } + +// WithRandomReader provides an option to specify a specific random source +func WithRandomReader(with io.Reader) Option { + return func(o *options) { + o.withRandomReader = with + } +} diff --git a/internal/daemon/cluster/handlers/testing.go b/internal/daemon/cluster/handlers/testing.go index ded05ae5da..b74563c5cd 100644 --- a/internal/daemon/cluster/handlers/testing.go +++ b/internal/daemon/cluster/handlers/testing.go @@ -5,6 +5,7 @@ package handlers import ( "context" + "crypto/rand" "net" "sync" "testing" @@ -58,7 +59,7 @@ func TestUpstreamService(t *testing.T) (UpstreamMessageServiceClientProducer, *t require.NoError(t, err) // start an upstream controller - testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage) + testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage, rand.Reader) require.NoError(t, err) require.NotNil(t, testController) diff --git a/internal/daemon/cluster/handlers/upstream_message_service_controller.go b/internal/daemon/cluster/handlers/upstream_message_service_controller.go index 9626ff0793..152c3b7942 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_controller.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_controller.go @@ -6,6 +6,7 @@ package handlers import ( "context" "fmt" + "io" "sync" "github.com/hashicorp/boundary/internal/errors" @@ -79,7 +80,8 @@ func getUpstreamMessageHandler(ctx context.Context, msgType pbs.MsgType) (Upstre // UpstreamMessageServiceServer for OSS controllers type controllerUpstreamMessageServiceServer struct { pbs.UnimplementedUpstreamMessageServiceServer - storage nodeenrollment.Storage + storage nodeenrollment.Storage + randReader io.Reader } var _ pbs.UpstreamMessageServiceServer = (*controllerUpstreamMessageServiceServer)(nil) @@ -90,6 +92,7 @@ var _ pbs.UpstreamMessageServiceServer = (*controllerUpstreamMessageServiceServe func NewControllerUpstreamMessageServiceServer( ctx context.Context, storage nodeenrollment.Storage, + randReader io.Reader, ) (pbs.UpstreamMessageServiceServer, error) { const op = "handlers.NewControllerUpstreamMessageServiceServer" switch { @@ -98,7 +101,8 @@ func NewControllerUpstreamMessageServiceServer( } return &controllerUpstreamMessageServiceServer{ - storage: storage, + storage: storage, + randReader: randReader, }, nil } @@ -186,7 +190,7 @@ func (s *controllerUpstreamMessageServiceServer) UpstreamMessage(ctx context.Con }, }, nil default: - ct, err := nodeenrollment.EncryptMessage(ctx, respMsg, nodeInfo) + ct, err := nodeenrollment.EncryptMessage(ctx, respMsg, nodeInfo, nodeenrollment.WithRandomReader(s.randReader)) if err != nil { return nil, status.Errorf(codes.Internal, "%s: error encrypting response: %v", op, err) } diff --git a/internal/daemon/cluster/handlers/upstream_message_service_test.go b/internal/daemon/cluster/handlers/upstream_message_service_test.go index 34bb1382ec..40a167a4d9 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_test.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_test.go @@ -123,7 +123,7 @@ func Test_controllerUpstreamMessageServiceServer_UpstreamMessage(t *testing.T) { nodeInfo, err := types.LoadNodeInformation(testCtx, initStorage, initKeyId) require.NoError(t, err) // define a test controller - testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage) + testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage, nil) require.NoError(t, err) require.NotNil(t, testController) diff --git a/internal/daemon/cluster/handlers/upstream_message_service_worker.go b/internal/daemon/cluster/handlers/upstream_message_service_worker.go index 809b90356c..bc6468789c 100644 --- a/internal/daemon/cluster/handlers/upstream_message_service_worker.go +++ b/internal/daemon/cluster/handlers/upstream_message_service_worker.go @@ -6,6 +6,7 @@ package handlers import ( "context" "fmt" + "io" "sync" "github.com/hashicorp/boundary/internal/errors" @@ -153,7 +154,7 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ if opts.withKeyProducer == nil { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing node information required for encrypting unwrap keys message") } - req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg) + req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg, opts.withRandomReader) if err != nil { return nil, errors.Wrap(ctx, err, op) } @@ -202,9 +203,9 @@ func ptMsg(ctx context.Context, msgType pbs.MsgType, msg proto.Message) (*pbs.Up }, nil } -func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) { +func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message, randomReader io.Reader) (*pbs.UpstreamMessageRequest, error) { const op = "handlers.encryptMsg" - ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource) + ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource, nodeenrollment.WithRandomReader(randomReader)) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error encrypting upstream message")) } diff --git a/internal/daemon/controller/controller.go b/internal/daemon/controller/controller.go index cd096c36af..f8f987fc83 100644 --- a/internal/daemon/controller/controller.go +++ b/internal/daemon/controller/controller.go @@ -419,7 +419,7 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { return host.NewCatalogRepository(ctx, dbase, dbase) } c.ServersRepoFn = func() (*server.Repository, error) { - return server.NewRepository(ctx, dbase, dbase, c.kms) + return server.NewRepository(ctx, dbase, dbase, c.kms, server.WithRandomReader(c.conf.SecureRandomReader)) } c.OidcRepoFn = func() (*oidc.Repository, error) { return oidc.NewRepository(ctx, dbase, dbase, c.kms) @@ -596,6 +596,7 @@ func (c *Controller) registerJobs() error { serverJobOpts = append(serverJobOpts, serversjob.WithCertificateLifetime(c.conf.TestOverrideWorkerAuthCaCertificateLifetime), serversjob.WithRotationFrequency(c.conf.TestOverrideWorkerAuthCaCertificateLifetime/2), + serversjob.WithRandomReader(c.conf.SecureRandomReader), ) } if err := serversjob.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, serverJobOpts...); err != nil { diff --git a/internal/daemon/controller/handler.go b/internal/daemon/controller/handler.go index 299198a02a..e813a045bb 100644 --- a/internal/daemon/controller/handler.go +++ b/internal/daemon/controller/handler.go @@ -325,8 +325,14 @@ func (c *Controller) registerGrpcServices(s *grpc.Server) error { services.RegisterCredentialLibraryServiceServer(s, cl) } if _, ok := currentServices[services.WorkerService_ServiceDesc.ServiceName]; !ok { - ws, err := workers.NewService(c.baseContext, c.ServersRepoFn, c.IamRepoFn, c.WorkerAuthRepoStorageFn, - c.downstreamWorkers) + ws, err := workers.NewService( + c.baseContext, + c.ServersRepoFn, + c.IamRepoFn, + c.WorkerAuthRepoStorageFn, + c.downstreamWorkers, + c.conf.SecureRandomReader, + ) if err != nil { return fmt.Errorf("failed to create worker handler service: %w", err) } diff --git a/internal/daemon/controller/handlers/workers/worker_service.go b/internal/daemon/controller/handlers/workers/worker_service.go index 16a1d154c5..2a2eb1da41 100644 --- a/internal/daemon/controller/handlers/workers/worker_service.go +++ b/internal/daemon/controller/handlers/workers/worker_service.go @@ -9,6 +9,7 @@ import ( "encoding/hex" stderrors "errors" "fmt" + "io" "strings" "github.com/hashicorp/boundary/globals" @@ -29,6 +30,7 @@ import ( "github.com/hashicorp/boundary/internal/util" pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/workers" "github.com/hashicorp/go-secure-stdlib/strutil" + "github.com/hashicorp/nodeenrollment" "github.com/hashicorp/nodeenrollment/types" "github.com/mr-tron/base58" "google.golang.org/grpc/codes" @@ -97,13 +99,20 @@ type Service struct { workerAuthFn common.WorkerAuthRepoStorageFactory iamRepoFn common.IamRepoFactory downstreams common.Downstreamers + + randReader io.Reader } var _ pbs.WorkerServiceServer = (*Service)(nil) // NewService returns a worker service which handles worker related requests to boundary. -func NewService(ctx context.Context, repo common.ServersRepoFactory, iamRepoFn common.IamRepoFactory, - workerAuthFn common.WorkerAuthRepoStorageFactory, ds common.Downstreamers, +func NewService( + ctx context.Context, + repo common.ServersRepoFactory, + iamRepoFn common.IamRepoFactory, + workerAuthFn common.WorkerAuthRepoStorageFactory, + ds common.Downstreamers, + randReader io.Reader, ) (Service, error) { const op = "workers.NewService" if repo == nil { @@ -115,7 +124,7 @@ func NewService(ctx context.Context, repo common.ServersRepoFactory, iamRepoFn c if workerAuthFn == nil { return Service{}, errors.New(ctx, errors.InvalidParameter, op, "missing worker auth repository") } - return Service{repoFn: repo, iamRepoFn: iamRepoFn, workerAuthFn: workerAuthFn, downstreams: ds}, nil + return Service{repoFn: repo, iamRepoFn: iamRepoFn, workerAuthFn: workerAuthFn, downstreams: ds, randReader: randReader}, nil } // ListWorkers implements the interface pbs.WorkerServiceServer. @@ -565,7 +574,7 @@ func (s Service) ReinitializeCertificateAuthority(ctx context.Context, req *pbs. return nil, err } - rootCerts, err := server.ReinitializeRoots(ctx, repo) + rootCerts, err := server.ReinitializeRoots(ctx, repo, nodeenrollment.WithRandomReader(s.randReader)) if err != nil { return nil, err } diff --git a/internal/daemon/controller/handlers/workers/worker_service_test.go b/internal/daemon/controller/handlers/workers/worker_service_test.go index 8ff9e092dc..90dd7d5d8b 100644 --- a/internal/daemon/controller/handlers/workers/worker_service_test.go +++ b/internal/daemon/controller/handlers/workers/worker_service_test.go @@ -5,6 +5,7 @@ package workers import ( "context" + "crypto/rand" "fmt" "sort" "strings" @@ -284,7 +285,7 @@ func TestGet(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil) + s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err, "Couldn't create new worker service.") got, err := s.GetWorker(auth.DisabledAuthTestContext(iamRepoFn, tc.scopeId), tc.req) @@ -424,7 +425,7 @@ func TestList(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil) + s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(err, "Couldn't create new worker service.") // Test with a non-anon user @@ -484,7 +485,7 @@ func TestDelete(t *testing.T) { return workerAuthRepo, nil } - s, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + s, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") wUnmanaged := server.TestKmsWorker(t, conn, wrap, server.WithWorkerTags(&server.Tag{ @@ -614,7 +615,7 @@ func TestUpdate(t *testing.T) { Id: wkr.GetPublicId(), } } - workerService, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + workerService, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err) expectedScope := &scopes.ScopeInfo{Id: scope.Global.String(), Type: scope.Global.String(), Name: scope.Global.String(), Description: "Global Scope"} @@ -1103,7 +1104,7 @@ func TestUpdate_DeprecatedKMS(t *testing.T) { toMerge := &pbs.UpdateWorkerRequest{ Id: wkr.GetPublicId(), } - workerService, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + workerService, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err) cases := []struct { @@ -1188,7 +1189,7 @@ func TestUpdate_BadVersion(t *testing.T) { return repo, nil } - workerService, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + workerService, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err, "Failed to create a new host set service.") wkr := server.TestPkiWorker(t, conn, wrapper) @@ -1226,7 +1227,7 @@ func TestCreateWorkerLed(t *testing.T) { return workerAuthRepo, nil } - testSrv, err := NewService(testCtx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + testSrv, err := NewService(testCtx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") // Get an initial set of authorized node credentials @@ -1478,7 +1479,7 @@ func TestCreateWorkerLed(t *testing.T) { repoFn := func() (*server.Repository, error) { return server.NewRepository(testCtx, rw, &db.Db{}, testKms) } - testSrv, err := NewService(testCtx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + testSrv, err := NewService(testCtx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") return testSrv }(), @@ -1511,7 +1512,7 @@ func TestCreateWorkerLed(t *testing.T) { return server.NewRepository(testCtx, rw, rw, testKms) } } - testSrv, err := NewService(testCtx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + testSrv, err := NewService(testCtx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") return testSrv }(), @@ -1611,7 +1612,7 @@ func TestCreateControllerLed(t *testing.T) { return rootStorage, nil } - testSrv, err := NewService(testCtx, repoFn, iamRepoFn, authRepoFn, nil) + testSrv, err := NewService(testCtx, repoFn, iamRepoFn, authRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") // Get an initial set of authorized node credentials @@ -1825,7 +1826,7 @@ func TestCreateControllerLed(t *testing.T) { repoFn := func() (*server.Repository, error) { return server.NewRepository(testCtx, rw, &db.Db{}, testKms) } - testSrv, err := NewService(testCtx, repoFn, iamRepoFn, authRepoFn, nil) + testSrv, err := NewService(testCtx, repoFn, iamRepoFn, authRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") return testSrv }(), @@ -1857,7 +1858,7 @@ func TestCreateControllerLed(t *testing.T) { return server.NewRepository(testCtx, rw, rw, testKms) } } - testSrv, err := NewService(testCtx, repoFn, iamRepoFn, authRepoFn, nil) + testSrv, err := NewService(testCtx, repoFn, iamRepoFn, authRepoFn, nil, rand.Reader) require.NoError(t, err, "Error when getting new worker service.") return testSrv }(), @@ -1959,7 +1960,7 @@ func TestService_AddWorkerTags(t *testing.T) { workerAuthRepoFn := func() (*server.WorkerAuthRepositoryStorage, error) { return workerAuthRepo, nil } - s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil) + s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(err) worker := server.TestKmsWorker(t, conn, wrapper) @@ -2119,7 +2120,7 @@ func TestService_SetWorkerTags(t *testing.T) { workerAuthRepoFn := func() (*server.WorkerAuthRepositoryStorage, error) { return workerAuthRepo, nil } - s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil) + s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(err) worker := server.TestKmsWorker(t, conn, wrapper) @@ -2282,7 +2283,7 @@ func TestService_RemoveWorkerTags(t *testing.T) { workerAuthRepoFn := func() (*server.WorkerAuthRepositoryStorage, error) { return workerAuthRepo, nil } - s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil) + s, err := NewService(context.Background(), repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(err) worker := server.TestKmsWorker(t, conn, wrapper) @@ -2481,7 +2482,7 @@ func TestReadCertificateAuthority(t *testing.T) { _, err = rotation.RotateRootCertificates(ctx, workerAuthRepo) require.NoError(err) - testSrv, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + testSrv, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(err, "Error when getting new worker service.") tests := []struct { @@ -2557,7 +2558,7 @@ func TestReinitializeCertificateAuthority(t *testing.T) { _, err = rotation.RotateRootCertificates(ctx, workerAuthRepo) require.NoError(err) - testSrv, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil) + testSrv, err := NewService(ctx, repoFn, iamRepoFn, workerAuthRepoFn, nil, rand.Reader) require.NoError(err, "Error when getting new worker service.") tests := []struct { diff --git a/internal/daemon/controller/listeners.go b/internal/daemon/controller/listeners.go index a008e8feb7..6043833792 100644 --- a/internal/daemon/controller/listeners.go +++ b/internal/daemon/controller/listeners.go @@ -169,6 +169,7 @@ func (c *Controller) configureForCluster(ln *base.ServerListener) (func(), error Options: []nodee.Option{ nodee.WithLogger(eventLogger), nodee.WithRegistrationWrapper(wrapperToUse), + nodee.WithRandomReader(c.conf.SecureRandomReader), }, }) if err != nil { diff --git a/internal/daemon/controller/rpc_registration.go b/internal/daemon/controller/rpc_registration.go index 8e9da95eaa..4af11f4474 100644 --- a/internal/daemon/controller/rpc_registration.go +++ b/internal/daemon/controller/rpc_registration.go @@ -104,6 +104,7 @@ func registerControllerMultihopService(ctx context.Context, c *Controller, serve workerAuthStorage, true, nil, + nodeenrollment.WithRandomReader(c.conf.SecureRandomReader), ) if err != nil { return fmt.Errorf("%s: error creating multihop service handler: %w", op, err) @@ -132,7 +133,7 @@ func registerControllerUpstreamMessageService(ctx context.Context, c *Controller return fmt.Errorf("%s: worker auth repository storage func is unset", op) } - upstreamMsgService, err := handlers.NewControllerUpstreamMessageServiceServer(ctx, workerAuthStorage) + upstreamMsgService, err := handlers.NewControllerUpstreamMessageServiceServer(ctx, workerAuthStorage, c.conf.SecureRandomReader) if err != nil { return fmt.Errorf("%s: error creating upstream message service handler: %w", op, err) } diff --git a/internal/daemon/worker/auth_rotation.go b/internal/daemon/worker/auth_rotation.go index b1e9758424..ab25219e8d 100644 --- a/internal/daemon/worker/auth_rotation.go +++ b/internal/daemon/worker/auth_rotation.go @@ -279,6 +279,7 @@ func rotateWorkerAuth(ctx context.Context, w *Worker, currentNodeCreds *types.No ctx, w.WorkerAuthStorage, fetchResp, + randReaderOpt, nodeenrollment.WithStorageWrapper(w.conf.WorkerAuthStorageKms), ) if err != nil { diff --git a/internal/daemon/worker/controller_connection.go b/internal/daemon/worker/controller_connection.go index 9321b6f410..3fe1103fdd 100644 --- a/internal/daemon/worker/controller_connection.go +++ b/internal/daemon/worker/controller_connection.go @@ -117,6 +117,7 @@ func (w *Worker) upstreamDialerFunc(extraAlpnProtos ...string) func(context.Cont nodeenrollment.WithRegistrationWrapper(w.conf.WorkerAuthKms), nodeenrollment.WithWrappingRegistrationFlowApplicationSpecificParams(st), nodeenrollment.WithExtraAlpnProtos(extraAlpnProtos), + nodeenrollment.WithRandomReader(w.conf.SecureRandomReader), // If the activation token hasn't been populated, this won't do // anything, and it won't do anything if it's already been used nodeenrollment.WithActivationToken(w.conf.RawConfig.Worker.ControllerGeneratedActivationToken), diff --git a/internal/daemon/worker/listeners.go b/internal/daemon/worker/listeners.go index 9b46520206..a54150a78d 100644 --- a/internal/daemon/worker/listeners.go +++ b/internal/daemon/worker/listeners.go @@ -148,7 +148,7 @@ func (w *Worker) configureForWorker(ln *base.ServerListener, logger *log.Logger, if err != nil { return nil, temperror.New(fmt.Errorf("error deriving node credentials key id in multi-hop fetch function: %w", err)) } - req.RewrappedWrappingRegistrationFlowInfo, err = nodee.EncryptMessage(ctx, regInfo, nodeCreds) + req.RewrappedWrappingRegistrationFlowInfo, err = nodee.EncryptMessage(ctx, regInfo, nodeCreds, opt...) if err != nil { return nil, temperror.New(fmt.Errorf("error rewrapping registration information in multi-hop fetch function: %w", err)) } @@ -211,6 +211,7 @@ func (w *Worker) configureForWorker(ln *base.ServerListener, logger *log.Logger, nodee.WithStorageWrapper(w.conf.WorkerAuthStorageKms), nodee.WithRegistrationWrapper(wrapperToUse), nodee.WithLogger(eventLogger), + nodee.WithRandomReader(w.conf.SecureRandomReader), }, }) if err != nil { diff --git a/internal/daemon/worker/rpc_registration.go b/internal/daemon/worker/rpc_registration.go index cdad537d11..f5b39169f2 100644 --- a/internal/daemon/worker/rpc_registration.go +++ b/internal/daemon/worker/rpc_registration.go @@ -59,6 +59,7 @@ func registerWorkerMultihopService(ctx context.Context, w *Worker, server *grpc. w.WorkerAuthStorage, false, w.controllerMultihopConn, + nodeenrollment.WithRandomReader(w.conf.SecureRandomReader), ) if err != nil { return fmt.Errorf("%s: error creating multihop service handler: %w", op, err) diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index f32aca4907..7098c8139c 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -803,5 +803,5 @@ func (w *Worker) SendUpstreamMessage(ctx context.Context, m proto.Message) (prot return nil, errors.Wrap(ctx, err, op) } clientProducer := w.controllerUpstreamMsgConn.Load() - return handlers.SendUpstreamMessage(ctx, *clientProducer, initKeyId, m, handlers.WithKeyProducer(nodeCreds)) + return handlers.SendUpstreamMessage(ctx, *clientProducer, initKeyId, m, handlers.WithKeyProducer(nodeCreds), handlers.WithRandomReader(w.conf.SecureRandomReader)) } diff --git a/internal/server/job/options.go b/internal/server/job/options.go index efdcaab6ea..a842f6bedd 100644 --- a/internal/server/job/options.go +++ b/internal/server/job/options.go @@ -4,6 +4,8 @@ package servers import ( + "crypto/rand" + "io" "time" ) @@ -25,11 +27,13 @@ type Option func(*options) type options struct { withRotationFrequency time.Duration withCertificateLifetime time.Duration + withRandomReader io.Reader } func getDefaultOptions() options { return options{ withRotationFrequency: defaultRotationFrequency, + withRandomReader: rand.Reader, } } @@ -49,3 +53,10 @@ func WithCertificateLifetime(with time.Duration) Option { o.withCertificateLifetime = with } } + +// WithRandomReader provides a way to specify a random reader +func WithRandomReader(with io.Reader) Option { + return func(o *options) { + o.withRandomReader = with + } +} diff --git a/internal/server/job/rotate_roots_job.go b/internal/server/job/rotate_roots_job.go index d0d4d23774..dda635d708 100644 --- a/internal/server/job/rotate_roots_job.go +++ b/internal/server/job/rotate_roots_job.go @@ -5,6 +5,7 @@ package servers import ( "context" + "io" "time" "github.com/hashicorp/boundary/internal/db" @@ -26,6 +27,8 @@ type rotateRootsJob struct { rotateFrequency time.Duration certificateLifetime time.Duration + + randReader io.Reader } // newRotateRootsJob instantiates the rotate roots job. @@ -52,6 +55,7 @@ func newRotateRootsJob(ctx context.Context, r db.Reader, w db.Writer, kms *kms.K totalRotates: 0, rotateFrequency: opts.withRotationFrequency, certificateLifetime: opts.withCertificateLifetime, + randReader: opts.withRandomReader, }, nil } @@ -81,7 +85,7 @@ func (r *rotateRootsJob) Status() scheduler.JobStatus { func (r *rotateRootsJob) Run(ctx context.Context) error { const op = "server.(rotateRootsJob).Run" - _, err := server.RotateRoots(ctx, r.workerAuthRepo, nodeenrollment.WithCertificateLifetime(r.certificateLifetime)) + _, err := server.RotateRoots(ctx, r.workerAuthRepo, nodeenrollment.WithCertificateLifetime(r.certificateLifetime), nodeenrollment.WithRandomReader(r.randReader)) if err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/server/options.go b/internal/server/options.go index 85022f8430..35b91d2ecc 100644 --- a/internal/server/options.go +++ b/internal/server/options.go @@ -5,6 +5,8 @@ package server import ( "context" + "crypto/rand" + "io" "time" "github.com/hashicorp/boundary/version" @@ -53,12 +55,14 @@ type options struct { withFeature version.Feature withDirectlyConnected bool withWorkerPool []string + withRandomReader io.Reader } func getDefaultOptions() options { return options{ withNewIdFunc: newWorkerId, withOperationalState: ActiveOperationalState.String(), + withRandomReader: rand.Reader, } } @@ -267,3 +271,10 @@ func WithWorkerPool(workerIds []string) Option { o.withWorkerPool = workerIds } } + +// WithRandomReader provides a random reader. +func WithRandomReader(with io.Reader) Option { + return func(o *options) { + o.withRandomReader = with + } +} diff --git a/internal/server/options_test.go b/internal/server/options_test.go index c1ba421a57..16adc5f23e 100644 --- a/internal/server/options_test.go +++ b/internal/server/options_test.go @@ -5,6 +5,8 @@ package server import ( "context" + "crypto/rand" + "io" "reflect" "runtime" "testing" @@ -243,4 +245,10 @@ func Test_GetOpts(t *testing.T) { testOpts.withNewIdFunc = nil assert.Equal(t, opts, testOpts) }) + t.Run("WithRandomReader", func(t *testing.T) { + opts := GetOpts(WithRandomReader(io.LimitReader(nil, 0))) + testOpts := getDefaultOptions() + assert.Equal(t, rand.Reader, testOpts.withRandomReader) + assert.NotEqual(t, rand.Reader, opts.withRandomReader) + }) } diff --git a/internal/server/repository.go b/internal/server/repository.go index c740783014..3be02cac66 100644 --- a/internal/server/repository.go +++ b/internal/server/repository.go @@ -5,6 +5,7 @@ package server import ( "context" + "io" "reflect" "time" @@ -26,10 +27,15 @@ type Repository struct { kms *kms.Kms // defaultLimit provides a default for limiting the number of results returned from the repo defaultLimit int + randomReader io.Reader } -// NewRepository creates a new server Repository. Supports the options: WithLimit -// which sets a default limit on results returned by repo operations. +// NewRepository creates a new server Repository. +// +// Supported options: +// +// * WithLimit: default limit on results returned by repo operations +// * RandomReader: Specify a specific random source func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repository, error) { const op = "server.NewRepository" if r == nil { @@ -52,6 +58,7 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, writer: w, kms: kms, defaultLimit: opts.withLimit, + randomReader: opts.withRandomReader, }, nil } diff --git a/internal/server/repository_worker.go b/internal/server/repository_worker.go index 9612b18b16..04e9cdcad8 100644 --- a/internal/server/repository_worker.go +++ b/internal/server/repository_worker.go @@ -637,7 +637,13 @@ func (r *Repository) CreateWorker(ctx context.Context, worker *Worker, opt ...Op if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to create worker auth repository")) } - nodeInfo, err := registration.AuthorizeNode(ctx, workerAuthRepo, opts.WithFetchNodeCredentialsRequest, nodeenrollment.WithSkipStorage(true)) + nodeInfo, err := registration.AuthorizeNode( + ctx, + workerAuthRepo, + opts.WithFetchNodeCredentialsRequest, + nodeenrollment.WithSkipStorage(true), + nodeenrollment.WithRandomReader(r.randomReader), + ) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to authorize node")) } @@ -647,9 +653,14 @@ func (r *Repository) CreateWorker(ctx context.Context, worker *Worker, opt ...Op } case opts.WithCreateControllerLedActivationToken: - tokenId, activationToken, err := registration.CreateServerLedActivationToken(ctx, nil, &types.ServerLedRegistrationRequest{}, + tokenId, activationToken, err := registration.CreateServerLedActivationToken( + ctx, + nil, + &types.ServerLedRegistrationRequest{}, nodeenrollment.WithSkipStorage(true), - nodeenrollment.WithState(state)) + nodeenrollment.WithState(state), + nodeenrollment.WithRandomReader(r.randomReader), + ) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to create controller-led activation token")) } diff --git a/internal/server/repository_workerauth.go b/internal/server/repository_workerauth.go index 05e12ec782..8ecdaa89ad 100644 --- a/internal/server/repository_workerauth.go +++ b/internal/server/repository_workerauth.go @@ -80,7 +80,6 @@ func (r *WorkerAuthRepositoryStorage) Store(ctx context.Context, msg nodeenrollm switch t := msg.(type) { case *types.NodeInformation: // Encrypt the private key - if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(read db.Reader, w db.Writer) error { return StoreNodeInformationTx(ctx, read, w, r.kms, scope.Global.String(), t) }); err != nil { @@ -92,6 +91,7 @@ func (r *WorkerAuthRepositoryStorage) Store(ctx context.Context, msg nodeenrollm if err != nil { return errors.Wrap(ctx, err, op) } + default: return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("message type %T not supported for Store", msg)) }