Skip to content

Commit

Permalink
Merge pull request #5524 from hashicorp/ddebko-backport-db-fix
Browse files Browse the repository at this point in the history
backport fixes for the database transactions issue
  • Loading branch information
ddebko authored Feb 7, 2025
2 parents f225d09 + 7cf3e98 commit a46631a
Show file tree
Hide file tree
Showing 28 changed files with 147 additions and 35 deletions.
2 changes: 1 addition & 1 deletion internal/auth/ldap/repository_auth_method_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r *Repository) CreateAuthMethod(ctx context.Context, am *AuthMethod, opt .
return nil, errors.Wrap(ctx, err, op)
}

dbWrapper, err := r.kms.GetWrapper(context.Background(), am.ScopeId, kms.KeyPurposeDatabase)
dbWrapper, err := r.kms.GetWrapper(ctx, am.ScopeId, kms.KeyPurposeDatabase)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
}
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/oidc/repository_auth_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (r *Repository) upsertAccount(ctx context.Context, am *AuthMethod, IdTokenC
var rowCnt int
for rows.Next() {
rowCnt += 1
err = r.reader.ScanRows(ctx, rows, &result)
err = reader.ScanRows(ctx, rows, &result)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to scan rows for account"))
}
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/oidc/repository_auth_method_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (r *Repository) CreateAuthMethod(ctx context.Context, am *AuthMethod, opt .
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper"))
}

databaseWrapper, err := r.kms.GetWrapper(context.Background(), am.ScopeId, kms.KeyPurposeDatabase)
databaseWrapper, err := r.kms.GetWrapper(ctx, am.ScopeId, kms.KeyPurposeDatabase)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
}
Expand Down
5 changes: 3 additions & 2 deletions internal/auth/oidc/repository_managed_group_members.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/util"
)

// SetManagedGroupMemberships will set the managed groups for the given account
Expand Down Expand Up @@ -207,7 +208,7 @@ func (r *Repository) ListManagedGroupMembershipsByMember(ctx context.Context, wi
limit = opts.withLimit
}
reader := r.reader
if opts.withReader != nil {
if !util.IsNil(opts.withReader) {
reader = opts.withReader
}
var mgs []*ManagedGroupMemberAccount
Expand All @@ -232,7 +233,7 @@ func (r *Repository) ListManagedGroupMembershipsByGroup(ctx context.Context, wit
limit = opts.withLimit
}
reader := r.reader
if opts.withReader != nil {
if !util.IsNil(opts.withReader) {
reader = opts.withReader
}
var mgs []*ManagedGroupMemberAccount
Expand Down
2 changes: 1 addition & 1 deletion internal/auth/repository_auth_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (amr *AuthMethodRepository) ListDeletedIds(ctx context.Context, since time.
var deletedAuthMethodIDs []string
var transactionTimestamp time.Time
if _, err := amr.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rows, err := amr.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
if err != nil {
return errors.Wrap(ctx, err, op)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/credential/repository_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (s *StoreRepository) ListDeletedIds(ctx context.Context, since time.Time) (
var deletedStoreIDs []string
var transactionTimestamp time.Time
if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
if err != nil {
return errors.Wrap(ctx, err, op)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/credential/vault/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func nextRenewal(ctx context.Context, j scheduler.Job) (time.Duration, error) {
return 0, errors.New(ctx, errors.Unknown, op, "unknown job")
}

rows, err := r.Query(context.Background(), query, nil)
rows, err := r.Query(ctx, query, nil)
if err != nil {
return 0, errors.Wrap(ctx, err, op)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/credential/vault/vault_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func newToken(ctx context.Context, storeId string, token TokenSecret, accessor [
accessorCopy := make([]byte, len(accessor))
copy(accessorCopy, accessor)

hmac, err := crypto.HmacSha256WithPrk(context.Background(), tokenCopy, accessorCopy)
hmac, err := crypto.HmacSha256WithPrk(ctx, tokenCopy, accessorCopy)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt))
}
Expand Down
2 changes: 1 addition & 1 deletion internal/daemon/controller/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ func wrapHandlerWithCallbackInterceptor(h http.Handler, c *Controller) http.Hand

if strings.HasSuffix(req.URL.Path, "oidc:authenticate") {
if s, ok := values["state"].(string); ok {
stateWrapper, err := oidc.UnwrapMessage(context.Background(), s)
stateWrapper, err := oidc.UnwrapMessage(ctx, s)
if err != nil {
event.WriteError(ctx, op, err, event.WithInfoMsg("error marshaling state"))
w.WriteHeader(http.StatusInternalServerError)
Expand Down
20 changes: 20 additions & 0 deletions internal/host/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package host
import (
"errors"

"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/hashicorp/boundary/internal/util"
)

// GetOpts - iterate the inbound Options and return a struct
Expand All @@ -26,6 +28,8 @@ type Option func(*options) error
// options = how options are represented
type options struct {
WithLimit int
WithReader db.Reader
WithWriter db.Writer
WithOrderByCreateTime bool
Ascending bool
WithStartPageAfterItem pagination.Item
Expand Down Expand Up @@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option {
return nil
}
}

// WithReaderWriter is used to share the same database reader
// and writer when executing sql within a transaction.
func WithReaderWriter(r db.Reader, w db.Writer) Option {
return func(o *options) error {
if util.IsNil(r) {
return errors.New("reader cannot be nil")
}
if util.IsNil(w) {
return errors.New("writer cannot be nil")
}
o.WithReader = r
o.WithWriter = w
return nil
}
}
19 changes: 19 additions & 0 deletions internal/host/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,23 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(opts.WithStartPageAfterItem.GetPublicId(), "s_1")
assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
})
t.Run("WithReaderWriter", func(t *testing.T) {
t.Parallel()
t.Run("nil writer", func(t *testing.T) {
t.Parallel()
_, err := GetOpts(WithReaderWriter(&db.Db{}, nil))
require.Error(t, err)
})
t.Run("nil reader", func(t *testing.T) {
t.Parallel()
_, err := GetOpts(WithReaderWriter(nil, &db.Db{}))
require.Error(t, err)
})
reader := &db.Db{}
writer := &db.Db{}
opts, err := GetOpts(WithReaderWriter(reader, writer))
require.NoError(t, err)
assert.Equal(t, reader, opts.WithReader)
assert.Equal(t, writer, opts.WithWriter)
})
}
2 changes: 1 addition & 1 deletion internal/host/plugin/job_set_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func nextSync(ctx context.Context, j scheduler.Job) (time.Duration, error) {
return 0, errors.New(ctx, errors.Unknown, op, "unknown job")
}

rows, err := r.Query(context.Background(), query, []any{setSyncJobRunInterval})
rows, err := r.Query(ctx, query, []any{setSyncJobRunInterval})
if err != nil {
return 0, errors.Wrap(ctx, err, op)
}
Expand Down
12 changes: 12 additions & 0 deletions internal/host/plugin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package plugin

import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/pagination"
"google.golang.org/protobuf/types/known/structpb"
)
Expand Down Expand Up @@ -38,6 +39,8 @@ type options struct {
withSecretsHmac []byte
withStartPageAfterItem pagination.Item
withWorkerFilter string
WithReader db.Reader
withWriter db.Writer
}

func getDefaultOptions() options {
Expand Down Expand Up @@ -162,3 +165,12 @@ func WithWorkerFilter(wf string) Option {
o.withWorkerFilter = wf
}
}

// WithReaderWriter is used to share the same database reader
// and writer when executing sql within a transaction.
func WithReaderWriter(r db.Reader, w db.Writer) Option {
return func(o *options) {
o.WithReader = r
o.withWriter = w
}
}
8 changes: 8 additions & 0 deletions internal/host/plugin/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -113,4 +114,11 @@ func Test_GetOpts(t *testing.T) {
testOpts.withWorkerFilter = `"test" in "/tags/type"`
assert.Equal(t, opts, testOpts)
})
t.Run("WithReaderWriter", func(t *testing.T) {
reader := &db.Db{}
writer := &db.Db{}
opts := getOpts(WithReaderWriter(reader, writer))
assert.Equal(t, reader, opts.WithReader)
assert.Equal(t, writer, opts.withWriter)
})
}
14 changes: 10 additions & 4 deletions internal/host/plugin/repository_host_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/host"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/libs/patchstruct"
"github.com/hashicorp/boundary/internal/oplog"
Expand Down Expand Up @@ -404,7 +405,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
func(read db.Reader, w db.Writer) error {
msgs := make([]*oplog.Message, 0, 3)
ticket, err := w.GetTicket(ctx, newCatalog)
if err != nil {
Expand Down Expand Up @@ -528,7 +529,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
if needSetSync {
// We also need to mark all host sets in this catalog to be
// synced as well.
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId)
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId, host.WithReaderWriter(read, w))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get sets for host catalog"))
}
Expand Down Expand Up @@ -713,14 +714,19 @@ func (r *Repository) getCatalog(ctx context.Context, id string) (*HostCatalog, *
return c, p, nil
}

func (r *Repository) getPlugin(ctx context.Context, plgId string) (*plg.Plugin, error) {
func (r *Repository) getPlugin(ctx context.Context, plgId string, opts ...Option) (*plg.Plugin, error) {
const op = "plugin.(Repository).getPlugin"
if plgId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "no plugin id")
}
opt := getOpts(opts...)
reader := r.reader
if !util.IsNil(opt.WithReader) {
reader = opt.WithReader
}
plg := plg.NewPlugin()
plg.PublicId = plgId
if err := r.reader.LookupByPublicId(ctx, plg); err != nil {
if err := reader.LookupByPublicId(ctx, plg); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to get host plugin with id %q", plgId)))
}
return plg, nil
Expand Down
13 changes: 11 additions & 2 deletions internal/host/plugin/repository_host_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,15 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
limit = opts.WithLimit
}

reader := r.reader
writer := r.writer
if !util.IsNil(opts.WithReader) {
reader = opts.WithReader
}
if !util.IsNil(opts.WithWriter) {
writer = opts.WithWriter
}

args := make([]any, 0, 1)
var where string

Expand All @@ -825,7 +834,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
}

var aggHostSets []*hostSetAgg
if err := r.reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
if err := reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", publicId)))
}

Expand All @@ -844,7 +853,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
}
var plg *plugin.Plugin
if plgId != "" {
plg, err = r.getPlugin(ctx, plgId)
plg, err = r.getPlugin(ctx, plgId, WithReaderWriter(reader, writer))
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/host/repository_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (s *CatalogRepository) ListDeletedIds(ctx context.Context, since time.Time)
var deletedCatalogIDs []string
var transactionTimestamp time.Time
if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
if err != nil {
return errors.Wrap(ctx, err, op)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/host/static/repository_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (r *Repository) UpdateHost(ctx context.Context, projectId string, h *Host,
var rowsUpdated int
var returnedHost *Host
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
func(r db.Reader, w db.Writer) error {
returnedHost = h.clone()
var err error
rowsUpdated, err = w.Update(ctx, returnedHost, dbMask, nullFields,
Expand All @@ -183,7 +183,7 @@ func (r *Repository) UpdateHost(ctx context.Context, projectId string, h *Host,
ha := &hostAgg{
PublicId: h.PublicId,
}
if err := r.reader.LookupByPublicId(ctx, ha); err != nil {
if err := r.LookupByPublicId(ctx, ha); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to lookup host after update"))
}
returnedHost.SetIds = ha.getSetIds()
Expand Down
5 changes: 3 additions & 2 deletions internal/iam/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/internal/util"
)

var ErrMetadataScopeNotFound = errors.New(context.Background(), errors.RecordNotFound, "iam", "scope not found for metadata", errors.WithoutEvent())
Expand Down Expand Up @@ -65,7 +66,7 @@ func (r *Repository) list(ctx context.Context, resources any, where string, args
limit = opts.withLimit
}
reader := r.reader
if opts.withReader != nil {
if !util.IsNil(opts.withReader) {
reader = opts.withReader
}
return reader.SearchWhere(ctx, resources, where, args, db.WithLimit(limit))
Expand Down Expand Up @@ -150,7 +151,7 @@ func (r *Repository) update(ctx context.Context, resource Resource, version uint
reader := r.reader
writer := r.writer
needFreshReaderWriter := true
if opts.withReader != nil && opts.withWriter != nil {
if !util.IsNil(opts.withReader) && !util.IsNil(opts.withWriter) {
reader = opts.withReader
writer = opts.withWriter
if !writer.IsTx(ctx) {
Expand Down
3 changes: 2 additions & 1 deletion internal/iam/repository_grant_scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/util"
)

// AddRoleGrantScopes will add role grant scopes associated with the role ID in
Expand Down Expand Up @@ -235,7 +236,7 @@ func (r *Repository) SetRoleGrantScopes(ctx context.Context, roleId string, role
writer := r.writer
needFreshReaderWriter := true
opts := getOpts(opt...)
if opts.withReader != nil && opts.withWriter != nil {
if !util.IsNil(opts.withReader) && !util.IsNil(opts.withWriter) {
reader = opts.withReader
writer = opts.withWriter
needFreshReaderWriter = false
Expand Down
5 changes: 3 additions & 2 deletions internal/iam/repository_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/go-dbw"
)

Expand Down Expand Up @@ -193,7 +194,7 @@ func (r *Repository) LookupRole(ctx context.Context, withPublicId string, opt ..
}

var err error
if opts.withReader != nil && opts.withWriter != nil {
if !util.IsNil(opts.withReader) && !util.IsNil(opts.withWriter) {
if !opts.withWriter.IsTx(ctx) {
return nil, nil, nil, nil, errors.New(ctx, errors.Internal, op, "writer is not in transaction")
}
Expand Down Expand Up @@ -325,7 +326,7 @@ func (r *Repository) queryRoles(ctx context.Context, whereClause string, args []
for _, retRole := range retRoles {
roleIds = append(roleIds, retRole.PublicId)
}
retRoleGrantScopes, err = r.ListRoleGrantScopes(ctx, roleIds)
retRoleGrantScopes, err = r.ListRoleGrantScopes(ctx, roleIds, WithReaderWriter(rd, w))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to query role grant scopes"))
}
Expand Down
Loading

0 comments on commit a46631a

Please sign in to comment.