@@ -13,6 +13,7 @@ import (
1313 "github.com/openmeterio/openmeter/api"
1414 "github.com/openmeterio/openmeter/openmeter/app"
1515 "github.com/openmeterio/openmeter/openmeter/billing"
16+ "github.com/openmeterio/openmeter/openmeter/customer"
1617 "github.com/openmeterio/openmeter/openmeter/ent/db"
1718 "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice"
1819 "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline"
@@ -606,33 +607,56 @@ func (a *adapter) UpdateInvoice(ctx context.Context, in billing.UpdateInvoiceAda
606607
607608func (a * adapter ) GetInvoiceOwnership (ctx context.Context , in billing.GetInvoiceOwnershipAdapterInput ) (billing.GetOwnershipAdapterResponse , error ) {
608609 if err := in .Validate (); err != nil {
609- return billing. GetOwnershipAdapterResponse {} , billing.ValidationError {
610+ return nil , billing.ValidationError {
610611 Err : err ,
611612 }
612613 }
613614
614615 return entutils .TransactingRepo (ctx , a , func (ctx context.Context , tx * adapter ) (billing.GetOwnershipAdapterResponse , error ) {
615- dbInvoice , err := tx .db .BillingInvoice .Query ().
616- Where (billinginvoice .ID (in .ID )).
617- Where (billinginvoice .Namespace (in .Namespace )).
618- First (ctx )
616+ dbInvoices , err := tx .db .BillingInvoice .Query ().
617+ Where (
618+ billinginvoice .IDIn (
619+ lo .Map (
620+ in .InvoiceIDs ,
621+ func (invoiceID billing.InvoiceID , _ int ) string {
622+ return invoiceID .ID
623+ },
624+ )... ,
625+ ),
626+ ).
627+ All (ctx )
619628 if err != nil {
620- if db .IsNotFound (err ) {
621- return billing.GetOwnershipAdapterResponse {}, billing.NotFoundError {
622- Entity : billing .EntityInvoice ,
623- ID : in .ID ,
624- Err : err ,
629+ return nil , err
630+ }
631+
632+ invoiceToCustomerID := lo .SliceToMap (dbInvoices , func (dbInvoice * db.BillingInvoice ) (billing.InvoiceID , customer.CustomerID ) {
633+ return billing.InvoiceID {
634+ Namespace : dbInvoice .Namespace ,
635+ ID : dbInvoice .ID ,
636+ }, customer.CustomerID {
637+ Namespace : dbInvoice .Namespace ,
638+ ID : dbInvoice .CustomerID ,
625639 }
640+ })
641+
642+ // Let's validate if we got all the invoices (and most importantly look up invoices with
643+ // namespaceID, to prevent looking up invoices with different than expected namespace ID)
644+ var notFoundErrs []error
645+ for _ , invoiceID := range in .InvoiceIDs {
646+ if _ , found := invoiceToCustomerID [invoiceID ]; ! found {
647+ notFoundErrs = append (notFoundErrs , billing.NotFoundError {
648+ Entity : billing .EntityInvoice ,
649+ ID : invoiceID .ID ,
650+ Err : fmt .Errorf ("invoice not found: %s" , invoiceID .ID ),
651+ })
626652 }
653+ }
627654
628- return billing.GetOwnershipAdapterResponse {}, err
655+ if len (notFoundErrs ) > 0 {
656+ return nil , errors .Join (notFoundErrs ... )
629657 }
630658
631- return billing.GetOwnershipAdapterResponse {
632- Namespace : dbInvoice .Namespace ,
633- InvoiceID : dbInvoice .ID ,
634- CustomerID : dbInvoice .CustomerID ,
635- }, nil
659+ return invoiceToCustomerID , nil
636660 })
637661}
638662
0 commit comments