Skip to content

Commit 4e7b2be

Browse files
authored
fix(billing): exclude overdraft credits from credit overdraft balance (#1502)
1 parent 607e429 commit 4e7b2be

4 files changed

Lines changed: 129 additions & 1 deletion

File tree

billing/credit/mocks/transaction_repository.go

Lines changed: 59 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

billing/credit/service.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ type TransactionRepository interface {
1919
List(ctx context.Context, flt Filter) ([]Transaction, error)
2020
GetByID(ctx context.Context, id string) (Transaction, error)
2121
GetBalanceForRange(ctx context.Context, accountID string, start time.Time, end time.Time) (int64, error)
22+
GetBalanceForRangeWithoutOverdraft(ctx context.Context, accountID string, start time.Time, end time.Time) (int64, error)
2223
}
2324

2425
type CustomerRepository interface {
@@ -164,6 +165,12 @@ func (s Service) GetBalanceForRange(ctx context.Context, accountID string, start
164165
return s.transactionRepository.GetBalanceForRange(ctx, accountID, start, end)
165166
}
166167

168+
// GetBalanceForRangeWithoutOverdraft returns the balance for the given accountID within the given time range
169+
// excluding credit transactions sourced from overdraft reconciliation
170+
func (s Service) GetBalanceForRangeWithoutOverdraft(ctx context.Context, accountID string, start time.Time, end time.Time) (int64, error) {
171+
return s.transactionRepository.GetBalanceForRangeWithoutOverdraft(ctx, accountID, start, end)
172+
}
173+
167174
func (s Service) GetByID(ctx context.Context, id string) (Transaction, error) {
168175
return s.transactionRepository.GetByID(ctx, id)
169176
}

billing/invoice/service.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ type CustomerService interface {
5555
type CreditService interface {
5656
Add(ctx context.Context, cred credit.Credit) error
5757
GetBalanceForRange(ctx context.Context, accountID string, start time.Time, end time.Time) (int64, error)
58+
GetBalanceForRangeWithoutOverdraft(ctx context.Context, accountID string, start time.Time, end time.Time) (int64, error)
5859
}
5960

6061
type ProductService interface {
@@ -522,7 +523,7 @@ func (s *Service) GenerateForCredits(ctx context.Context) error {
522523
continue
523524
}
524525

525-
balance, err := s.creditService.GetBalanceForRange(ctx, c.ID, startRange, endRange)
526+
balance, err := s.creditService.GetBalanceForRangeWithoutOverdraft(ctx, c.ID, startRange, endRange)
526527
if err != nil {
527528
errs = append(errs, fmt.Errorf("failed to get balance for customer %s: %w", c.ID, err))
528529
continue

internal/store/postgres/billing_transactions_repository.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,64 @@ func (r BillingTransactionRepository) GetBalanceForRange(ctx context.Context, ac
494494
}
495495
return amount, nil
496496
}
497+
498+
func (r BillingTransactionRepository) getCreditBalanceExcludingSource(ctx context.Context, tx *sqlx.Tx, accountID string,
499+
start *time.Time, end *time.Time, excludeSource string) (*int64, error) {
500+
stmt := dialect.Select(goqu.SUM("amount")).From(TABLE_BILLING_TRANSACTIONS).Where(goqu.Ex{
501+
"account_id": accountID,
502+
"type": credit.CreditType,
503+
}).Where(goqu.C("source").Neq(excludeSource))
504+
if start != nil {
505+
stmt = stmt.Where(goqu.Ex{
506+
"created_at": goqu.Op{"gte": *start},
507+
})
508+
}
509+
if end != nil {
510+
stmt = stmt.Where(goqu.Ex{
511+
"created_at": goqu.Op{"lt": *end},
512+
})
513+
}
514+
query, params, err := stmt.ToSQL()
515+
if err != nil {
516+
return nil, fmt.Errorf("%w: %s", parseErr, err)
517+
}
518+
519+
var creditBalance *int64
520+
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "GetCreditBalanceExcludingSource", func(ctx context.Context) error {
521+
return tx.QueryRowxContext(ctx, query, params...).Scan(&creditBalance)
522+
}); err != nil {
523+
return nil, fmt.Errorf("%w: %s", dbErr, err)
524+
}
525+
return creditBalance, nil
526+
}
527+
528+
// GetBalanceForRangeWithoutOverdraft returns the balance excluding credit transactions
529+
// with source 'system.overdraft'. This prevents reconciliation credits from inflating
530+
// the balance when calculating credit overdraft invoices.
531+
func (r BillingTransactionRepository) GetBalanceForRangeWithoutOverdraft(ctx context.Context, accountID string, start time.Time,
532+
end time.Time) (int64, error) {
533+
var amount int64
534+
if err := r.dbc.WithTxn(ctx, sql.TxOptions{
535+
Isolation: sql.LevelSerializable,
536+
}, func(tx *sqlx.Tx) error {
537+
debitBalance, err := r.getDebitBalance(ctx, tx, accountID, &start, &end)
538+
if err != nil {
539+
return fmt.Errorf("failed to get debit balance: %w", err)
540+
}
541+
creditBalance, err := r.getCreditBalanceExcludingSource(ctx, tx, accountID, &start, &end, credit.SourceSystemOverdraftEvent)
542+
if err != nil {
543+
return fmt.Errorf("failed to get credit balance: %w", err)
544+
}
545+
if creditBalance == nil {
546+
creditBalance = new(int64)
547+
}
548+
if debitBalance == nil {
549+
debitBalance = new(int64)
550+
}
551+
amount = *creditBalance - *debitBalance
552+
return nil
553+
}); err != nil {
554+
return 0, fmt.Errorf("failed to get balance: %w", err)
555+
}
556+
return amount, nil
557+
}

0 commit comments

Comments
 (0)