@@ -494,3 +494,64 @@ func (r BillingTransactionRepository) GetBalanceForRange(ctx context.Context, ac
494494 }
495495 return amount , nil
496496}
497+
498+ func (r BillingTransactionRepository ) getCreditBalanceExcludingSource (ctx context.Context , tx * sqlx.Tx , accountID string ,
499+ start * time.Time , end * time.Time , excludeSource string ) (* int64 , error ) {
500+ stmt := dialect .Select (goqu .SUM ("amount" )).From (TABLE_BILLING_TRANSACTIONS ).Where (goqu.Ex {
501+ "account_id" : accountID ,
502+ "type" : credit .CreditType ,
503+ }).Where (goqu .C ("source" ).Neq (excludeSource ))
504+ if start != nil {
505+ stmt = stmt .Where (goqu.Ex {
506+ "created_at" : goqu.Op {"gte" : * start },
507+ })
508+ }
509+ if end != nil {
510+ stmt = stmt .Where (goqu.Ex {
511+ "created_at" : goqu.Op {"lt" : * end },
512+ })
513+ }
514+ query , params , err := stmt .ToSQL ()
515+ if err != nil {
516+ return nil , fmt .Errorf ("%w: %s" , parseErr , err )
517+ }
518+
519+ var creditBalance * int64
520+ if err = r .dbc .WithTimeout (ctx , TABLE_BILLING_TRANSACTIONS , "GetCreditBalanceExcludingSource" , func (ctx context.Context ) error {
521+ return tx .QueryRowxContext (ctx , query , params ... ).Scan (& creditBalance )
522+ }); err != nil {
523+ return nil , fmt .Errorf ("%w: %s" , dbErr , err )
524+ }
525+ return creditBalance , nil
526+ }
527+
528+ // GetBalanceForRangeWithoutOverdraft returns the balance excluding credit transactions
529+ // with source 'system.overdraft'. This prevents reconciliation credits from inflating
530+ // the balance when calculating credit overdraft invoices.
531+ func (r BillingTransactionRepository ) GetBalanceForRangeWithoutOverdraft (ctx context.Context , accountID string , start time.Time ,
532+ end time.Time ) (int64 , error ) {
533+ var amount int64
534+ if err := r .dbc .WithTxn (ctx , sql.TxOptions {
535+ Isolation : sql .LevelSerializable ,
536+ }, func (tx * sqlx.Tx ) error {
537+ debitBalance , err := r .getDebitBalance (ctx , tx , accountID , & start , & end )
538+ if err != nil {
539+ return fmt .Errorf ("failed to get debit balance: %w" , err )
540+ }
541+ creditBalance , err := r .getCreditBalanceExcludingSource (ctx , tx , accountID , & start , & end , credit .SourceSystemOverdraftEvent )
542+ if err != nil {
543+ return fmt .Errorf ("failed to get credit balance: %w" , err )
544+ }
545+ if creditBalance == nil {
546+ creditBalance = new (int64 )
547+ }
548+ if debitBalance == nil {
549+ debitBalance = new (int64 )
550+ }
551+ amount = * creditBalance - * debitBalance
552+ return nil
553+ }); err != nil {
554+ return 0 , fmt .Errorf ("failed to get balance: %w" , err )
555+ }
556+ return amount , nil
557+ }
0 commit comments