Skip to content

Commit

Permalink
Pass our configured random reader into nodee functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai committed Feb 21, 2024
1 parent 2497078 commit 0b09505
Show file tree
Hide file tree
Showing 24 changed files with 176 additions and 44 deletions.
38 changes: 38 additions & 0 deletions internal/daemon/cluster/handlers/option_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package handlers

import (
"context"
"crypto/rand"
"io"
"testing"

"github.com/hashicorp/nodeenrollment/storage/inmem"
"github.com/hashicorp/nodeenrollment/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// Test_GetOpts provides unit tests for GetOpts and all the options
func Test_GetOpts(t *testing.T) {
t.Parallel()

t.Run("withKeyProducer", func(t *testing.T) {
storage, err := inmem.New(context.Background())
require.NoError(t, err)
creds, err := types.NewNodeCredentials(context.Background(), storage)
require.NoError(t, err)
opts := getOpts(WithKeyProducer(creds))
testOpts := getDefaultOptions()
assert.Equal(t, nil, testOpts.withKeyProducer)
assert.Equal(t, creds, opts.withKeyProducer)
})
t.Run("WithRandomReader", func(t *testing.T) {
opts := getOpts(WithRandomReader(io.LimitReader(nil, 0)))
testOpts := getDefaultOptions()
assert.Equal(t, rand.Reader, testOpts.withRandomReader)
assert.NotEqual(t, rand.Reader, opts.withRandomReader)
})
}
17 changes: 15 additions & 2 deletions internal/daemon/cluster/handlers/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package handlers

import (
"crypto/rand"
"io"

"github.com/hashicorp/nodeenrollment"
)

Expand All @@ -21,11 +24,14 @@ type Option func(*options)

// options = how options are represented
type options struct {
withKeyProducer nodeenrollment.X25519KeyProducer
withKeyProducer nodeenrollment.X25519KeyProducer
withRandomReader io.Reader
}

func getDefaultOptions() options {
return options{}
return options{
withRandomReader: rand.Reader,
}
}

// WithKeyProducer provides an option types.NodeInformation
Expand All @@ -34,3 +40,10 @@ func WithKeyProducer(nodeInfo nodeenrollment.X25519KeyProducer) Option {
o.withKeyProducer = nodeInfo
}
}

// WithRandomReader provides an option to specify a specific random source
func WithRandomReader(with io.Reader) Option {
return func(o *options) {
o.withRandomReader = with
}
}
3 changes: 2 additions & 1 deletion internal/daemon/cluster/handlers/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package handlers

import (
"context"
"crypto/rand"
"net"
"sync"
"testing"
Expand Down Expand Up @@ -58,7 +59,7 @@ func TestUpstreamService(t *testing.T) (UpstreamMessageServiceClientProducer, *t
require.NoError(t, err)

// start an upstream controller
testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage)
testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage, rand.Reader)
require.NoError(t, err)
require.NotNil(t, testController)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package handlers
import (
"context"
"fmt"
"io"
"sync"

"github.com/hashicorp/boundary/internal/errors"
Expand Down Expand Up @@ -79,7 +80,8 @@ func getUpstreamMessageHandler(ctx context.Context, msgType pbs.MsgType) (Upstre
// UpstreamMessageServiceServer for OSS controllers
type controllerUpstreamMessageServiceServer struct {
pbs.UnimplementedUpstreamMessageServiceServer
storage nodeenrollment.Storage
storage nodeenrollment.Storage
randReader io.Reader
}

var _ pbs.UpstreamMessageServiceServer = (*controllerUpstreamMessageServiceServer)(nil)
Expand All @@ -90,6 +92,7 @@ var _ pbs.UpstreamMessageServiceServer = (*controllerUpstreamMessageServiceServe
func NewControllerUpstreamMessageServiceServer(
ctx context.Context,
storage nodeenrollment.Storage,
randReader io.Reader,
) (pbs.UpstreamMessageServiceServer, error) {
const op = "handlers.NewControllerUpstreamMessageServiceServer"
switch {
Expand All @@ -98,7 +101,8 @@ func NewControllerUpstreamMessageServiceServer(
}

return &controllerUpstreamMessageServiceServer{
storage: storage,
storage: storage,
randReader: randReader,
}, nil
}

Expand Down Expand Up @@ -186,7 +190,7 @@ func (s *controllerUpstreamMessageServiceServer) UpstreamMessage(ctx context.Con
},
}, nil
default:
ct, err := nodeenrollment.EncryptMessage(ctx, respMsg, nodeInfo)
ct, err := nodeenrollment.EncryptMessage(ctx, respMsg, nodeInfo, nodeenrollment.WithRandomReader(s.randReader))
if err != nil {
return nil, status.Errorf(codes.Internal, "%s: error encrypting response: %v", op, err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func Test_controllerUpstreamMessageServiceServer_UpstreamMessage(t *testing.T) {
nodeInfo, err := types.LoadNodeInformation(testCtx, initStorage, initKeyId)
require.NoError(t, err)
// define a test controller
testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage)
testController, err := NewControllerUpstreamMessageServiceServer(testCtx, initStorage, nil)
require.NoError(t, err)
require.NotNil(t, testController)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package handlers
import (
"context"
"fmt"
"io"
"sync"

"github.com/hashicorp/boundary/internal/errors"
Expand Down Expand Up @@ -153,7 +154,7 @@ func SendUpstreamMessage(ctx context.Context, clientProducer UpstreamMessageServ
if opts.withKeyProducer == nil {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing node information required for encrypting unwrap keys message")
}
req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg)
req, err = ctMsg(ctx, opts.withKeyProducer, msgType, msg, opts.withRandomReader)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
Expand Down Expand Up @@ -202,9 +203,9 @@ func ptMsg(ctx context.Context, msgType pbs.MsgType, msg proto.Message) (*pbs.Up
}, nil
}

func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message) (*pbs.UpstreamMessageRequest, error) {
func ctMsg(ctx context.Context, keySource nodeenrollment.X25519KeyProducer, msgType pbs.MsgType, msg proto.Message, randomReader io.Reader) (*pbs.UpstreamMessageRequest, error) {
const op = "handlers.encryptMsg"
ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource)
ct, err := nodeenrollment.EncryptMessage(ctx, msg, keySource, nodeenrollment.WithRandomReader(randomReader))
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error encrypting upstream message"))
}
Expand Down
3 changes: 2 additions & 1 deletion internal/daemon/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func New(ctx context.Context, conf *Config) (*Controller, error) {
return host.NewCatalogRepository(ctx, dbase, dbase)
}
c.ServersRepoFn = func() (*server.Repository, error) {
return server.NewRepository(ctx, dbase, dbase, c.kms)
return server.NewRepository(ctx, dbase, dbase, c.kms, server.WithRandomReader(c.conf.SecureRandomReader))
}
c.OidcRepoFn = func() (*oidc.Repository, error) {
return oidc.NewRepository(ctx, dbase, dbase, c.kms)
Expand Down Expand Up @@ -596,6 +596,7 @@ func (c *Controller) registerJobs() error {
serverJobOpts = append(serverJobOpts,
serversjob.WithCertificateLifetime(c.conf.TestOverrideWorkerAuthCaCertificateLifetime),
serversjob.WithRotationFrequency(c.conf.TestOverrideWorkerAuthCaCertificateLifetime/2),
serversjob.WithRandomReader(c.conf.SecureRandomReader),
)
}
if err := serversjob.RegisterJobs(c.baseContext, c.scheduler, rw, rw, c.kms, serverJobOpts...); err != nil {
Expand Down
10 changes: 8 additions & 2 deletions internal/daemon/controller/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,14 @@ func (c *Controller) registerGrpcServices(s *grpc.Server) error {
services.RegisterCredentialLibraryServiceServer(s, cl)
}
if _, ok := currentServices[services.WorkerService_ServiceDesc.ServiceName]; !ok {
ws, err := workers.NewService(c.baseContext, c.ServersRepoFn, c.IamRepoFn, c.WorkerAuthRepoStorageFn,
c.downstreamWorkers)
ws, err := workers.NewService(
c.baseContext,
c.ServersRepoFn,
c.IamRepoFn,
c.WorkerAuthRepoStorageFn,
c.downstreamWorkers,
c.conf.SecureRandomReader,
)
if err != nil {
return fmt.Errorf("failed to create worker handler service: %w", err)
}
Expand Down
17 changes: 13 additions & 4 deletions internal/daemon/controller/handlers/workers/worker_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"encoding/hex"
stderrors "errors"
"fmt"
"io"
"strings"

"github.com/hashicorp/boundary/globals"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/hashicorp/boundary/internal/util"
pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/workers"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/types"
"github.com/mr-tron/base58"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -97,13 +99,20 @@ type Service struct {
workerAuthFn common.WorkerAuthRepoStorageFactory
iamRepoFn common.IamRepoFactory
downstreams common.Downstreamers

randReader io.Reader
}

var _ pbs.WorkerServiceServer = (*Service)(nil)

// NewService returns a worker service which handles worker related requests to boundary.
func NewService(ctx context.Context, repo common.ServersRepoFactory, iamRepoFn common.IamRepoFactory,
workerAuthFn common.WorkerAuthRepoStorageFactory, ds common.Downstreamers,
func NewService(
ctx context.Context,
repo common.ServersRepoFactory,
iamRepoFn common.IamRepoFactory,
workerAuthFn common.WorkerAuthRepoStorageFactory,
ds common.Downstreamers,
randReader io.Reader,
) (Service, error) {
const op = "workers.NewService"
if repo == nil {
Expand All @@ -115,7 +124,7 @@ func NewService(ctx context.Context, repo common.ServersRepoFactory, iamRepoFn c
if workerAuthFn == nil {
return Service{}, errors.New(ctx, errors.InvalidParameter, op, "missing worker auth repository")
}
return Service{repoFn: repo, iamRepoFn: iamRepoFn, workerAuthFn: workerAuthFn, downstreams: ds}, nil
return Service{repoFn: repo, iamRepoFn: iamRepoFn, workerAuthFn: workerAuthFn, downstreams: ds, randReader: randReader}, nil
}

// ListWorkers implements the interface pbs.WorkerServiceServer.
Expand Down Expand Up @@ -565,7 +574,7 @@ func (s Service) ReinitializeCertificateAuthority(ctx context.Context, req *pbs.
return nil, err
}

rootCerts, err := server.ReinitializeRoots(ctx, repo)
rootCerts, err := server.ReinitializeRoots(ctx, repo, nodeenrollment.WithRandomReader(s.randReader))
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 0b09505

Please sign in to comment.