diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8a826cf91..62393b938 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -253,6 +253,8 @@ jobs: unit_type: - unit-race - unit + - unit dbbackend=postgres + - unit dbbackend=sqlite steps: - name: git checkout uses: actions/checkout@v4 diff --git a/accounts/interface.go b/accounts/interface.go index 11d3efe93..ef51b202a 100644 --- a/accounts/interface.go +++ b/accounts/interface.go @@ -1,7 +1,9 @@ package accounts import ( + "bytes" "context" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -55,6 +57,31 @@ func ParseAccountID(idStr string) (*AccountID, error) { return &id, nil } +// ToInt64 converts an AccountID to its int64 representation. +func (a AccountID) ToInt64() (int64, error) { + var value int64 + buf := bytes.NewReader(a[:]) + if err := binary.Read(buf, byteOrder, &value); err != nil { + return 0, err + } + + return value, nil +} + +// AccountIDFromInt64 converts an int64 to an AccountID. +func AccountIDFromInt64(value int64) (AccountID, error) { + var ( + a = AccountID{} + buf = new(bytes.Buffer) + ) + if err := binary.Write(buf, binary.BigEndian, value); err != nil { + return a, err + } + copy(a[:], buf.Bytes()) + + return a, nil +} + // String returns the string representation of the AccountID. func (a AccountID) String() string { return hex.EncodeToString(a[:]) @@ -225,9 +252,14 @@ type Store interface { AddAccountInvoice(ctx context.Context, id AccountID, hash lntypes.Hash) error - // IncreaseAccountBalance increases the balance of the account with the + // CreditAccount increases the balance of the account with the + // given ID by the given amount. + CreditAccount(ctx context.Context, id AccountID, + amount lnwire.MilliSatoshi) error + + // DebitAccount decreases the balance of the account with the // given ID by the given amount. - IncreaseAccountBalance(ctx context.Context, id AccountID, + DebitAccount(ctx context.Context, id AccountID, amount lnwire.MilliSatoshi) error // UpsertAccountPayment updates or inserts a payment entry for the given diff --git a/accounts/service.go b/accounts/service.go index f84f12500..098c4a8fa 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -563,7 +563,7 @@ func (s *InterceptorService) invoiceUpdate(ctx context.Context, // If we get here, the current account has the invoice associated with // it that was just paid. Credit the amount to the account and update it // in the DB. - err := s.store.IncreaseAccountBalance(ctx, acctID, invoice.AmountPaid) + err := s.store.CreditAccount(ctx, acctID, invoice.AmountPaid) if err != nil { return s.disableAndErrorfUnsafe("error increasing account "+ "balance account: %w", err) diff --git a/accounts/store_kvdb.go b/accounts/store_kvdb.go index 47a7c7ac4..a08be5541 100644 --- a/accounts/store_kvdb.go +++ b/accounts/store_kvdb.go @@ -223,11 +223,11 @@ func (s *BoltStore) AddAccountInvoice(_ context.Context, id AccountID, return s.updateAccount(id, update) } -// IncreaseAccountBalance increases the balance of the account with the given ID +// CreditAccount increases the balance of the account with the given ID // by the given amount. // // NOTE: This is part of the Store interface. -func (s *BoltStore) IncreaseAccountBalance(_ context.Context, id AccountID, +func (s *BoltStore) CreditAccount(_ context.Context, id AccountID, amount lnwire.MilliSatoshi) error { update := func(account *OffChainBalanceAccount) error { @@ -244,6 +244,33 @@ func (s *BoltStore) IncreaseAccountBalance(_ context.Context, id AccountID, return s.updateAccount(id, update) } +// DebitAccount decreases the balance of the account with the given ID +// by the given amount. +// +// NOTE: This is part of the Store interface. +func (s *BoltStore) DebitAccount(_ context.Context, id AccountID, + amount lnwire.MilliSatoshi) error { + + update := func(account *OffChainBalanceAccount) error { + if amount > math.MaxInt64 { + return fmt.Errorf("amount %v exceeds the maximum of %v", + amount, int64(math.MaxInt64)) + } + + if account.CurrentBalance-int64(amount) < 0 { + return fmt.Errorf("cannot debit %v from the account "+ + "balance, as the resulting balance would be "+ + "below 0", int64(amount/1000)) + } + + account.CurrentBalance -= int64(amount) + + return nil + } + + return s.updateAccount(id, update) +} + // UpsertAccountPayment updates or inserts a payment entry for the given // account. Various functional options can be passed to modify the behavior of // the method. The returned boolean is true if the payment was already known diff --git a/accounts/store_sql.go b/accounts/store_sql.go new file mode 100644 index 000000000..d57431e16 --- /dev/null +++ b/accounts/store_sql.go @@ -0,0 +1,731 @@ +package accounts + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + // addIndexName is the name of the key under which we store the last + // known invoice add index in the accounts_indices table. + addIndexName = "last_add_index" + + // settleIndexName is the name of the key under which we store the + // last known invoice settle index in the accounts_indices table. + settleIndexName = "last_settle_index" +) + +// SQLQueries is a subset of the sqlc.Queries interface that can be used +// to interact with accounts related tables. +// +//nolint:lll +type SQLQueries interface { + AddAccountInvoice(ctx context.Context, arg sqlc.AddAccountInvoiceParams) error + CreditAccount(ctx context.Context, arg sqlc.CreditAccountParams) (int64, error) + DebitAccount(ctx context.Context, arg sqlc.DebitAccountParams) (int64, error) + DeleteAccount(ctx context.Context, id int64) error + DeleteAccountPayment(ctx context.Context, arg sqlc.DeleteAccountPaymentParams) error + GetAccount(ctx context.Context, id int64) (sqlc.Account, error) + GetAccountByLabel(ctx context.Context, label sql.NullString) (sqlc.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) + GetAccountIndex(ctx context.Context, name string) (int64, error) + GetAccountPayment(ctx context.Context, arg sqlc.GetAccountPaymentParams) (sqlc.AccountPayment, error) + InsertAccount(ctx context.Context, arg sqlc.InsertAccountParams) (int64, error) + ListAccountInvoices(ctx context.Context, id int64) ([]sqlc.AccountInvoice, error) + ListAccountPayments(ctx context.Context, id int64) ([]sqlc.AccountPayment, error) + ListAllAccounts(ctx context.Context) ([]sqlc.Account, error) + SetAccountIndex(ctx context.Context, arg sqlc.SetAccountIndexParams) error + UpdateAccountBalance(ctx context.Context, arg sqlc.UpdateAccountBalanceParams) (int64, error) + UpdateAccountExpiry(ctx context.Context, arg sqlc.UpdateAccountExpiryParams) (int64, error) + UpdateAccountLastUpdate(ctx context.Context, arg sqlc.UpdateAccountLastUpdateParams) (int64, error) + UpsertAccountPayment(ctx context.Context, arg sqlc.UpsertAccountPaymentParams) error + GetAccountInvoice(ctx context.Context, arg sqlc.GetAccountInvoiceParams) (sqlc.AccountInvoice, error) +} + +// BatchedSQLQueries is a version of the SQLActionQueries that's capable +// of batched database operations. +type BatchedSQLQueries interface { + SQLQueries + + db.BatchedTx[SQLQueries] +} + +// SQLStore represents a storage backend. +type SQLStore struct { + // db is all the higher level queries that the SQLStore has access to + // in order to implement all its CRUD logic. + db BatchedSQLQueries + + // DB represents the underlying database connection. + *sql.DB + + clock clock.Clock +} + +// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries +// storage backend. +func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { + executor := db.NewTransactionExecutor( + sqlDB, func(tx *sql.Tx) SQLQueries { + return sqlDB.WithTx(tx) + }, + ) + + return &SQLStore{ + db: executor, + DB: sqlDB.DB, + clock: clock, + } +} + +// NewAccount creates and persists a new OffChainBalanceAccount with the given +// balance and a randomly chosen ID. If the given label is not empty, then it +// must be unique; if it is not, then ErrLabelAlreadyExists is returned. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) NewAccount(ctx context.Context, balance lnwire.MilliSatoshi, + expirationDate time.Time, label string) (*OffChainBalanceAccount, + error) { + + // Ensure that if a label is set, it can't be mistaken for a hex + // encoded account ID to avoid confusion and make it easier for the CLI + // to distinguish between the two. + var labelVal sql.NullString + if len(label) > 0 { + if _, err := hex.DecodeString(label); err == nil && + len(label) == hex.EncodedLen(AccountIDLen) { + + return nil, fmt.Errorf("the label '%s' is not allowed "+ + "as it can be mistaken for an account ID", + label) + } + + labelVal = sql.NullString{ + String: label, + Valid: true, + } + } + + var ( + writeTxOpts db.QueriesTxOptions + account *OffChainBalanceAccount + ) + err := s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + // First, find a unique alias (this is what the ID was in the + // kvdb implementation of the DB). + alias, err := uniqueRandomAccountAlias(ctx, db) + if err != nil { + return err + } + + if labelVal.Valid { + _, err = db.GetAccountByLabel(ctx, labelVal) + if err == nil { + return ErrLabelAlreadyExists + } else if !errors.Is(err, sql.ErrNoRows) { + return err + } + } + + id, err := db.InsertAccount(ctx, sqlc.InsertAccountParams{ + Type: int16(TypeInitialBalance), + InitialBalanceMsat: int64(balance), + CurrentBalanceMsat: int64(balance), + Expiration: expirationDate, + LastUpdated: s.clock.Now().UTC(), + Label: labelVal, + Alias: alias, + }) + if err != nil { + return fmt.Errorf("inserting account: %w", err) + } + + account, err = getAndMarshalAccount(ctx, db, id) + if err != nil { + return fmt.Errorf("fetching account: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + return account, nil +} + +// getAndMarshalAccount retrieves the account with the given ID. If the account +// cannot be found, then ErrAccNotFound is returned. +func getAndMarshalAccount(ctx context.Context, db SQLQueries, id int64) ( + *OffChainBalanceAccount, error) { + + dbAcct, err := db.GetAccount(ctx, id) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrAccNotFound + } else if err != nil { + return nil, err + } + + return marshalDBAccount(ctx, db, dbAcct) +} + +func marshalDBAccount(ctx context.Context, db SQLQueries, + dbAcct sqlc.Account) (*OffChainBalanceAccount, error) { + + alias, err := AccountIDFromInt64(dbAcct.Alias) + if err != nil { + return nil, err + } + + account := &OffChainBalanceAccount{ + ID: alias, + Type: AccountType(dbAcct.Type), + InitialBalance: lnwire.MilliSatoshi(dbAcct.InitialBalanceMsat), + CurrentBalance: dbAcct.CurrentBalanceMsat, + LastUpdate: dbAcct.LastUpdated.UTC(), + ExpirationDate: dbAcct.Expiration.UTC(), + Invoices: make(AccountInvoices), + Payments: make(AccountPayments), + Label: dbAcct.Label.String, + } + + invoices, err := db.ListAccountInvoices(ctx, dbAcct.ID) + if err != nil { + return nil, err + } + for _, invoice := range invoices { + var hash lntypes.Hash + copy(hash[:], invoice.Hash) + account.Invoices[hash] = struct{}{} + } + + payments, err := db.ListAccountPayments(ctx, dbAcct.ID) + if err != nil { + return nil, err + } + + for _, payment := range payments { + var hash lntypes.Hash + copy(hash[:], payment.Hash) + account.Payments[hash] = &PaymentEntry{ + Status: lnrpc.Payment_PaymentStatus(payment.Status), + FullAmount: lnwire.MilliSatoshi(payment.FullAmountMsat), + } + } + + return account, nil +} + +// uniqueRandomAccountAlias generates a random account alias that is not already +// in use. An account "alias" is a unique 8 byte identifier (which corresponds +// to the AccountID type) that is used to identify accounts in the database. The +// reason for using this alias in addition to the SQL auto-incremented ID is to +// remain backwards compatible with the kvdb implementation of the DB which only +// used the alias. +func uniqueRandomAccountAlias(ctx context.Context, db SQLQueries) (int64, + error) { + + var ( + newAlias AccountID + numTries = 10 + ) + for numTries > 0 { + if _, err := rand.Read(newAlias[:]); err != nil { + return 0, err + } + + newAliasID, err := newAlias.ToInt64() + if err != nil { + return 0, err + } + + _, err = db.GetAccountIDByAlias(ctx, newAliasID) + if errors.Is(err, sql.ErrNoRows) { + // No account found with this new ID, we can use it. + return newAliasID, nil + } else if err != nil { + return 0, err + } + + numTries-- + } + + return 0, fmt.Errorf("couldn't create new account ID") +} + +// AddAccountInvoice adds and invoice hash to the account with the given +// AccountID alias. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) AddAccountInvoice(ctx context.Context, alias AccountID, + hash lntypes.Hash) error { + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + acctID, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + // First check that this invoice does not already exist. + _, err = db.GetAccountInvoice(ctx, sqlc.GetAccountInvoiceParams{ + AccountID: acctID, + Hash: hash[:], + }) + // If it does, there is nothing left to do. + if err == nil { + return nil + } else if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + err = db.AddAccountInvoice(ctx, sqlc.AddAccountInvoiceParams{ + AccountID: acctID, + Hash: hash[:], + }) + if err != nil { + return err + } + + return s.markAccountUpdated(ctx, db, acctID) + }) +} + +func getAccountIDByAlias(ctx context.Context, db SQLQueries, alias AccountID) ( + int64, error) { + + aliasInt, err := alias.ToInt64() + if err != nil { + return 0, fmt.Errorf("error converting account alias into "+ + "int64: %w", err) + } + + acctID, err := db.GetAccountIDByAlias(ctx, aliasInt) + if errors.Is(err, sql.ErrNoRows) { + return 0, ErrAccNotFound + } + + return acctID, err +} + +// markAccountUpdated is a helper that updates the last updated timestamp of +// the account with the given ID. +func (s *SQLStore) markAccountUpdated(ctx context.Context, + db SQLQueries, id int64) error { + + _, err := db.UpdateAccountLastUpdate( + ctx, sqlc.UpdateAccountLastUpdateParams{ + ID: id, + LastUpdated: s.clock.Now().UTC(), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return ErrAccNotFound + } + + return err +} + +// UpdateAccountBalanceAndExpiry updates the balance and/or expiry of an +// account. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) UpdateAccountBalanceAndExpiry(ctx context.Context, + alias AccountID, newBalance fn.Option[int64], + newExpiry fn.Option[time.Time]) error { + + var ( + writeTxOpts db.QueriesTxOptions + ) + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + newBalance.WhenSome(func(i int64) { + _, err = db.UpdateAccountBalance( + ctx, sqlc.UpdateAccountBalanceParams{ + ID: id, + CurrentBalanceMsat: i, + }, + ) + }) + if err != nil { + return err + } + + newExpiry.WhenSome(func(t time.Time) { + _, err = db.UpdateAccountExpiry( + ctx, sqlc.UpdateAccountExpiryParams{ + ID: id, + Expiration: t.UTC(), + }, + ) + }) + if err != nil { + return err + } + + return s.markAccountUpdated(ctx, db, id) + }) +} + +// CreditAccount increases the balance of the account with the given alias by +// the given amount. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) CreditAccount(ctx context.Context, alias AccountID, + amount lnwire.MilliSatoshi) error { + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + _, err = db.CreditAccount( + ctx, sqlc.CreditAccountParams{ + ID: id, + Amount: int64(amount), + }, + ) + if err != nil { + return err + } + + return s.markAccountUpdated(ctx, db, id) + }) +} + +// DebitAccount decreases the balance of the account with the given alias by +// the given amount. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) DebitAccount(ctx context.Context, alias AccountID, + amount lnwire.MilliSatoshi) error { + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + id, err = db.DebitAccount( + ctx, sqlc.DebitAccountParams{ + ID: id, + Amount: int64(amount), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("cannot debit %v from the account "+ + "balance, as the resulting balance would be "+ + "below 0", int64(amount/1000)) + } else if err != nil { + return err + } + + return s.markAccountUpdated(ctx, db, id) + }) +} + +// Account retrieves an account from the SQL store and un-marshals it. If the +// account cannot be found, then ErrAccNotFound is returned. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) Account(ctx context.Context, alias AccountID) ( + *OffChainBalanceAccount, error) { + + var ( + readTxOpts = db.NewQueryReadTx() + account *OffChainBalanceAccount + ) + err := s.db.ExecTx(ctx, &readTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + account, err = getAndMarshalAccount(ctx, db, id) + return err + }) + + return account, err +} + +// Accounts retrieves all accounts from the SQL store and un-marshals them. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) Accounts(ctx context.Context) ([]*OffChainBalanceAccount, + error) { + + var ( + readTxOpts = db.NewQueryReadTx() + accounts []*OffChainBalanceAccount + ) + err := s.db.ExecTx(ctx, &readTxOpts, func(db SQLQueries) error { + dbAccounts, err := db.ListAllAccounts(ctx) + if err != nil { + return err + } + + accounts = make([]*OffChainBalanceAccount, len(dbAccounts)) + for i, dbAccount := range dbAccounts { + account, err := marshalDBAccount(ctx, db, dbAccount) + if err != nil { + return err + } + + accounts[i] = account + } + + return nil + }) + + return accounts, err +} + +// RemoveAccount finds an account by its ID and removes it from the DB. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) RemoveAccount(ctx context.Context, alias AccountID) error { + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + return db.DeleteAccount(ctx, id) + }) +} + +// UpsertAccountPayment updates or inserts a payment entry for the given +// account. Various functional options can be passed to modify the behavior of +// the method. The returned boolean is true if the payment was already known +// before the update. This is to be treated as a best-effort indication if an +// error is also returned since the method may error before the boolean can be +// set correctly. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) UpsertAccountPayment(ctx context.Context, alias AccountID, + hash lntypes.Hash, fullAmount lnwire.MilliSatoshi, + status lnrpc.Payment_PaymentStatus, + options ...UpsertPaymentOption) (bool, error) { + + opts := newUpsertPaymentOption() + for _, o := range options { + o(opts) + } + + var ( + writeTxOpts db.QueriesTxOptions + known bool + ) + return known, s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + payment, err := db.GetAccountPayment( + ctx, sqlc.GetAccountPaymentParams{ + AccountID: id, + Hash: hash[:], + }, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + known = err == nil + + if known { + currStatus := lnrpc.Payment_PaymentStatus( + payment.Status, + ) + if opts.errIfAlreadySucceeded && + successState(currStatus) { + + return ErrAlreadySucceeded + } + + // If the errIfAlreadyPending option is set, we return + // an error if the payment is already in-flight or + // succeeded. + if opts.errIfAlreadyPending && + currStatus != lnrpc.Payment_FAILED { + + return fmt.Errorf("payment with hash %s is "+ + "already in flight or succeeded "+ + "(status %v)", hash, currStatus) + } + + if opts.usePendingAmount { + fullAmount = lnwire.MilliSatoshi( + payment.FullAmountMsat, + ) + } + } else if opts.errIfUnknown { + return ErrPaymentNotAssociated + } + + err = db.UpsertAccountPayment( + ctx, sqlc.UpsertAccountPaymentParams{ + AccountID: id, + Hash: hash[:], + Status: int16(status), + FullAmountMsat: int64(fullAmount), + }, + ) + if err != nil { + return err + } + + if opts.debitAccount { + acct, err := db.GetAccount(ctx, id) + if errors.Is(err, sql.ErrNoRows) { + return ErrAccNotFound + } else if err != nil { + return err + } + + newBalance := acct.CurrentBalanceMsat - + int64(fullAmount) + + _, err = db.UpdateAccountBalance( + ctx, sqlc.UpdateAccountBalanceParams{ + ID: id, + CurrentBalanceMsat: newBalance, + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return ErrAccNotFound + } else if err != nil { + return err + } + } + + return s.markAccountUpdated(ctx, db, id) + }) +} + +// DeleteAccountPayment removes a payment entry from the account with the given +// ID. It will return an error if the payment is not associated with the +// account. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) DeleteAccountPayment(ctx context.Context, alias AccountID, + hash lntypes.Hash) error { + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := getAccountIDByAlias(ctx, db, alias) + if err != nil { + return err + } + + _, err = db.GetAccountPayment( + ctx, sqlc.GetAccountPaymentParams{ + AccountID: id, + Hash: hash[:], + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("payment with hash %s is not "+ + "associated with this account: %w", hash, + ErrPaymentNotAssociated) + } else if err != nil { + return err + } + + err = db.DeleteAccountPayment( + ctx, sqlc.DeleteAccountPaymentParams{ + AccountID: id, + Hash: hash[:], + }, + ) + if err != nil { + return err + } + + return s.markAccountUpdated(ctx, db, id) + }) +} + +// LastIndexes returns the last invoice add and settle index or +// ErrNoInvoiceIndexKnown if no indexes are known yet. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) LastIndexes(ctx context.Context) (uint64, uint64, error) { + var ( + readTxOpts = db.NewQueryReadTx() + addIndex, settleIndex int64 + ) + err := s.db.ExecTx(ctx, &readTxOpts, func(db SQLQueries) error { + var err error + addIndex, err = db.GetAccountIndex(ctx, addIndexName) + if errors.Is(err, sql.ErrNoRows) { + return ErrNoInvoiceIndexKnown + } else if err != nil { + return err + } + + settleIndex, err = db.GetAccountIndex(ctx, settleIndexName) + if errors.Is(err, sql.ErrNoRows) { + return ErrNoInvoiceIndexKnown + } + + return err + }) + + return uint64(addIndex), uint64(settleIndex), err +} + +// StoreLastIndexes stores the last invoice add and settle index. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) StoreLastIndexes(ctx context.Context, addIndex, + settleIndex uint64) error { + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + err := db.SetAccountIndex(ctx, sqlc.SetAccountIndexParams{ + Name: addIndexName, + Value: int64(addIndex), + }) + if err != nil { + return err + } + + return db.SetAccountIndex(ctx, sqlc.SetAccountIndexParams{ + Name: settleIndexName, + Value: int64(settleIndex), + }) + }) +} + +// Close closes the underlying store. +// +// NOTE: This is part of the Store interface. +func (s *SQLStore) Close() error { + return s.DB.Close() +} + +// A compile-time check to ensure that SQLStore implements the Store interface. +var _ Store = (*SQLStore)(nil) diff --git a/accounts/store_test.go b/accounts/store_test.go index 3d44df3e7..00084d48b 100644 --- a/accounts/store_test.go +++ b/accounts/store_test.go @@ -2,6 +2,7 @@ package accounts import ( "context" + "github.com/lightningnetwork/lnd/lnwire" "testing" "time" @@ -71,6 +72,14 @@ func TestAccountStore(t *testing.T) { ) require.NoError(t, err) + // Adjust the account balance by first crediting 10000, and then + // debiting 5000. + err = store.CreditAccount(ctx, acct1.ID, lnwire.MilliSatoshi(10000)) + require.NoError(t, err) + + err = store.DebitAccount(ctx, acct1.ID, lnwire.MilliSatoshi(5000)) + require.NoError(t, err) + // Update the in-memory account so that we can compare it with the // account we get from the store. acct1.CurrentBalance = -500 @@ -85,11 +94,30 @@ func TestAccountStore(t *testing.T) { } acct1.Invoices[lntypes.Hash{12, 34, 56, 78}] = struct{}{} acct1.Invoices[lntypes.Hash{34, 56, 78, 90}] = struct{}{} + acct1.CurrentBalance += 10000 + acct1.CurrentBalance -= 5000 dbAccount, err = store.Account(ctx, acct1.ID) require.NoError(t, err) assertEqualAccounts(t, acct1, dbAccount) + // Test that adjusting the balance to exactly 0 should work, while + // adjusting the balance to below 0 should fail. + err = store.DebitAccount( + ctx, acct1.ID, lnwire.MilliSatoshi(acct1.CurrentBalance), + ) + require.NoError(t, err) + + acct1.CurrentBalance = 0 + + dbAccount, err = store.Account(ctx, acct1.ID) + require.NoError(t, err) + assertEqualAccounts(t, acct1, dbAccount) + + // Adjusting the value to below 0 should fail. + err = store.DebitAccount(ctx, acct1.ID, lnwire.MilliSatoshi(1)) + require.ErrorContains(t, err, "balance would be below 0") + // Sleep just a tiny bit to make sure we are never too quick to measure // the expiry, even though the time is nanosecond scale and writing to // the store and reading again should take at least a couple of @@ -130,8 +158,8 @@ func assertEqualAccounts(t *testing.T, expected, actual.LastUpdate = time.Time{} require.Equal(t, expected, actual) - require.Equal(t, expectedExpiry.UnixNano(), actualExpiry.UnixNano()) - require.Equal(t, expectedUpdate.UnixNano(), actualUpdate.UnixNano()) + require.Equal(t, expectedExpiry.Unix(), actualExpiry.Unix()) + require.Equal(t, expectedUpdate.Unix(), actualUpdate.Unix()) // Restore the old values to not influence the tests. expected.ExpirationDate = expectedExpiry @@ -168,7 +196,7 @@ func TestAccountUpdateMethods(t *testing.T) { require.NoError(t, err) require.EqualValues(t, balance, dbAcct.CurrentBalance) require.WithinDuration( - t, expiry, dbAcct.ExpirationDate, 0, + t, expiry, dbAcct.ExpirationDate, time.Second, ) } @@ -262,12 +290,12 @@ func TestAccountUpdateMethods(t *testing.T) { assertInvoices(hash1, hash2) }) - t.Run("IncreaseAccountBalance", func(t *testing.T) { + t.Run("CreditAccount", func(t *testing.T) { store := NewTestDB(t, clock.NewTestClock(time.Now())) // Increasing the balance of an account that doesn't exist // should error out. - err := store.IncreaseAccountBalance(ctx, AccountID{}, 100) + err := store.CreditAccount(ctx, AccountID{}, 100) require.ErrorIs(t, err, ErrAccNotFound) acct, err := store.NewAccount(ctx, 123, time.Time{}, "foo") @@ -284,7 +312,7 @@ func TestAccountUpdateMethods(t *testing.T) { // Increase the balance by 100 and assert that the new balance // is 223. - err = store.IncreaseAccountBalance(ctx, acct.ID, 100) + err = store.CreditAccount(ctx, acct.ID, 100) require.NoError(t, err) assertBalance(223) diff --git a/accounts/test_kvdb.go b/accounts/test_kvdb.go index 224a8912c..99c3e2ae6 100644 --- a/accounts/test_kvdb.go +++ b/accounts/test_kvdb.go @@ -1,3 +1,5 @@ +//go:build !test_db_sqlite && !test_db_postgres + package accounts import ( diff --git a/accounts/test_postgres.go b/accounts/test_postgres.go new file mode 100644 index 000000000..013c18a04 --- /dev/null +++ b/accounts/test_postgres.go @@ -0,0 +1,28 @@ +//go:build test_db_postgres && !test_db_sqlite + +package accounts + +import ( + "errors" + "testing" + + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightningnetwork/lnd/clock" +) + +// ErrDBClosed is an error that is returned when a database operation is +// performed on a closed database. +var ErrDBClosed = errors.New("database is closed") + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { + return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) *SQLStore { + + return NewSQLStore(db.NewTestPostgresDB(t).BaseDB, clock) +} diff --git a/accounts/test_sqlite.go b/accounts/test_sqlite.go new file mode 100644 index 000000000..07319268d --- /dev/null +++ b/accounts/test_sqlite.go @@ -0,0 +1,30 @@ +//go:build test_db_sqlite && !test_db_postgres + +package accounts + +import ( + "errors" + "testing" + + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightningnetwork/lnd/clock" +) + +// ErrDBClosed is an error that is returned when a database operation is +// performed on a closed database. +var ErrDBClosed = errors.New("database is closed") + +// NewTestDB is a helper function that creates an BBolt database for testing. +func NewTestDB(t *testing.T, clock clock.Clock) *SQLStore { + return NewSQLStore(db.NewTestSqliteDB(t).BaseDB, clock) +} + +// NewTestDBFromPath is a helper function that creates a new BoltStore with a +// connection to an existing BBolt database for testing. +func NewTestDBFromPath(t *testing.T, dbPath string, + clock clock.Clock) *SQLStore { + + return NewSQLStore( + db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock, + ) +} diff --git a/db/sqlc/accounts.sql.go b/db/sqlc/accounts.sql.go index 4deefdb88..396aaaa46 100644 --- a/db/sqlc/accounts.sql.go +++ b/db/sqlc/accounts.sql.go @@ -26,6 +26,45 @@ func (q *Queries) AddAccountInvoice(ctx context.Context, arg AddAccountInvoicePa return err } +const creditAccount = `-- name: CreditAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat + $1 +WHERE id = $2 +RETURNING id +` + +type CreditAccountParams struct { + Amount int64 + ID int64 +} + +func (q *Queries) CreditAccount(ctx context.Context, arg CreditAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, creditAccount, arg.Amount, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + +const debitAccount = `-- name: DebitAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat - $1 +WHERE id = $2 +AND current_balance_msat >= $1 +RETURNING id +` + +type DebitAccountParams struct { + Amount int64 + ID int64 +} + +func (q *Queries) DebitAccount(ctx context.Context, arg DebitAccountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, debitAccount, arg.Amount, arg.ID) + var id int64 + err := row.Scan(&id) + return id, err +} + const deleteAccount = `-- name: DeleteAccount :exec DELETE FROM accounts WHERE id = $1 diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index b0265c596..859fce0a2 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -11,6 +11,8 @@ import ( type Querier interface { AddAccountInvoice(ctx context.Context, arg AddAccountInvoiceParams) error + CreditAccount(ctx context.Context, arg CreditAccountParams) (int64, error) + DebitAccount(ctx context.Context, arg DebitAccountParams) (int64, error) DeleteAccount(ctx context.Context, id int64) error DeleteAccountPayment(ctx context.Context, arg DeleteAccountPaymentParams) error GetAccount(ctx context.Context, id int64) (Account, error) diff --git a/db/sqlc/queries/accounts.sql b/db/sqlc/queries/accounts.sql index 637a49727..caa7aac48 100644 --- a/db/sqlc/queries/accounts.sql +++ b/db/sqlc/queries/accounts.sql @@ -9,6 +9,19 @@ SET current_balance_msat = $1 WHERE id = $2 RETURNING id; +-- name: CreditAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat + sqlc.arg(amount) +WHERE id = sqlc.arg(id) +RETURNING id; + +-- name: DebitAccount :one +UPDATE accounts +SET current_balance_msat = current_balance_msat - sqlc.arg(amount) +WHERE id = sqlc.arg(id) +AND current_balance_msat >= sqlc.arg(amount) +RETURNING id; + -- name: UpdateAccountExpiry :one UPDATE accounts SET expiration = $1 diff --git a/make/testing_flags.mk b/make/testing_flags.mk index 7687a9431..0370bc29f 100644 --- a/make/testing_flags.mk +++ b/make/testing_flags.mk @@ -24,6 +24,16 @@ UNIT_TARGETED = yes GOLIST = echo '$(PKG)/$(pkg)' endif +# Add the build tag for running unit tests against a postgres DB. +ifeq ($(dbbackend),postgres) +DEV_TAGS += test_db_postgres +endif + +# Add the build tag for running unit tests against a sqlite DB. +ifeq ($(dbbackend),sqlite) +DEV_TAGS += test_db_sqlite +endif + # Add any additional tags that are passed in to make. ifneq ($(tags),) DEV_TAGS += ${tags}