Skip to content

Commit bee8db5

Browse files
authored
feat: user cost (#126)
#58 <img height="300" alt="image" src="https://github.com/user-attachments/assets/ae4a2b1d-4384-40b5-bb80-c286cf9a5e34" />
1 parent 2e593f7 commit bee8db5

32 files changed

Lines changed: 2975 additions & 25 deletions

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
5555
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
5656
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
5757
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
58+
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
59+
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
5860
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
5961
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
6062
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -160,6 +162,8 @@ golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sU
160162
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0=
161163
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4=
162164
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
165+
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
166+
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
163167
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
164168
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
165169
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -188,6 +192,8 @@ golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
188192
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
189193
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
190194
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
195+
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
196+
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
191197
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
192198
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
193199
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=

internal/api/chat/create_conversation_message_stream_v2.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
281281
APIKey: settings.OpenAIAPIKey,
282282
}
283283

284-
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
284+
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
285285
if err != nil {
286286
return s.sendStreamError(stream, err)
287287
}
@@ -307,7 +307,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
307307
for i, bsonMsg := range conversation.InappChatHistory {
308308
protoMessages[i] = mapper.BSONToChatMessageV2(bsonMsg)
309309
}
310-
title, err := s.aiClientV2.GetConversationTitleV2(ctx, protoMessages, llmProvider)
310+
title, err := s.aiClientV2.GetConversationTitleV2(ctx, conversation.UserID, protoMessages, llmProvider)
311311
if err != nil {
312312
s.logger.Error("Failed to get conversation title", "error", err, "conversationID", conversation.ID.Hex())
313313
return

internal/api/grpc.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
1616
commentv1 "paperdebugger/pkg/gen/api/comment/v1"
1717
projectv1 "paperdebugger/pkg/gen/api/project/v1"
18+
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
1819
userv1 "paperdebugger/pkg/gen/api/user/v1"
1920

2021
// "github.com/grpc-ecosystem/go-grpc-middleware"
@@ -106,6 +107,7 @@ func NewGrpcServer(
106107
userServer userv1.UserServiceServer,
107108
projectServer projectv1.ProjectServiceServer,
108109
commentServer commentv1.CommentServiceServer,
110+
usageServer usagev1.UsageServiceServer,
109111
) *GrpcServer {
110112
grpcServer := &GrpcServer{}
111113
grpcServer.userService = userService
@@ -121,5 +123,6 @@ func NewGrpcServer(
121123
userv1.RegisterUserServiceServer(grpcServer.Server, userServer)
122124
projectv1.RegisterProjectServiceServer(grpcServer.Server, projectServer)
123125
commentv1.RegisterCommentServiceServer(grpcServer.Server, commentServer)
126+
usagev1.RegisterUsageServiceServer(grpcServer.Server, usageServer)
124127
return grpcServer
125128
}

internal/api/server.go

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,23 @@ import (
66
"fmt"
77
"net"
88
"net/http"
9+
"os"
10+
"os/signal"
911
"strings"
12+
"syscall"
1013

1114
"paperdebugger/internal/libs/logger"
1215
"paperdebugger/internal/libs/metadatautil"
1316
"paperdebugger/internal/libs/shared"
17+
"paperdebugger/internal/services"
18+
aiclient "paperdebugger/internal/services/toolkit/client"
1419
authv1 "paperdebugger/pkg/gen/api/auth/v1"
1520
chatv1 "paperdebugger/pkg/gen/api/chat/v1"
1621
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
1722
commentv1 "paperdebugger/pkg/gen/api/comment/v1"
1823
projectv1 "paperdebugger/pkg/gen/api/project/v1"
1924
sharedv1 "paperdebugger/pkg/gen/api/shared/v1"
25+
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
2026
userv1 "paperdebugger/pkg/gen/api/user/v1"
2127

2228
"github.com/gin-gonic/gin"
@@ -30,25 +36,37 @@ import (
3036
)
3137

3238
type Server struct {
33-
grpcServer *GrpcServer
34-
ginServer *GinServer
39+
grpcServer *GrpcServer
40+
ginServer *GinServer
41+
pricingService *services.PricingService
42+
aiClientV2 *aiclient.AIClientV2
3543

3644
logger *logger.Logger
3745
}
3846

3947
func NewServer(
4048
grpcServer *GrpcServer,
4149
ginServer *GinServer,
50+
pricingService *services.PricingService,
51+
aiClientV2 *aiclient.AIClientV2,
4252
logger *logger.Logger,
4353
) *Server {
4454
return &Server{
45-
grpcServer: grpcServer,
46-
ginServer: ginServer,
47-
logger: logger,
55+
grpcServer: grpcServer,
56+
ginServer: ginServer,
57+
pricingService: pricingService,
58+
aiClientV2: aiClientV2,
59+
logger: logger,
4860
}
4961
}
5062

5163
func (s *Server) Run(addr string) {
64+
// Start the pricing updater in the background
65+
ctx, cancel := context.WithCancel(context.Background())
66+
defer cancel()
67+
68+
s.pricingService.StartPriceUpdater(ctx)
69+
5270
listener, err := net.Listen("tcp", ":0")
5371
if err != nil {
5472
s.logger.Fatalf("failed to start grpc server listener: %v", err)
@@ -105,6 +123,22 @@ func (s *Server) Run(addr string) {
105123
s.logger.Fatalf("failed to register comment service grpc gateway: %v", err)
106124
return
107125
}
126+
err = usagev1.RegisterUsageServiceHandler(context.Background(), mux, client)
127+
if err != nil {
128+
s.logger.Fatalf("failed to register usage service grpc gateway: %v", err)
129+
return
130+
}
131+
132+
// Set up signal handling for graceful shutdown
133+
sigChan := make(chan os.Signal, 1)
134+
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
135+
136+
go func() {
137+
<-sigChan
138+
s.logger.Info("[PAPERDEBUGGER] received shutdown signal, shutting down gracefully...")
139+
s.Shutdown()
140+
os.Exit(0)
141+
}()
108142

109143
s.logger.Infof("[PAPERDEBUGGER] http server listening on %s", addr)
110144
s.ginServer.Any("/_pd/api/*path", func(c *gin.Context) { mux.ServeHTTP(c.Writer, c.Request) })
@@ -114,6 +148,16 @@ func (s *Server) Run(addr string) {
114148
}
115149
}
116150

151+
// Shutdown gracefully shuts down all server components.
152+
func (s *Server) Shutdown() {
153+
s.logger.Info("[PAPERDEBUGGER] shutting down AI client (draining usage records)...")
154+
s.aiClientV2.Shutdown()
155+
s.logger.Info("[PAPERDEBUGGER] AI client shutdown complete")
156+
157+
s.grpcServer.GracefulStop()
158+
s.logger.Info("[PAPERDEBUGGER] gRPC server shutdown complete")
159+
}
160+
117161
func (s *Server) metadataAnnotator() func(ctx context.Context, req *http.Request) metadata.MD {
118162
return func(ctx context.Context, req *http.Request) metadata.MD {
119163
md := metadata.New(map[string]string{})
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package usage
2+
3+
import (
4+
"context"
5+
6+
"paperdebugger/internal/libs/contextutil"
7+
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
8+
9+
"google.golang.org/protobuf/types/known/timestamppb"
10+
)
11+
12+
func (s *UsageServer) GetSessionUsage(
13+
ctx context.Context,
14+
req *usagev1.GetSessionUsageRequest,
15+
) (*usagev1.GetSessionUsageResponse, error) {
16+
actor, err := contextutil.GetActor(ctx)
17+
if err != nil {
18+
return nil, err
19+
}
20+
21+
// Get session with costs already calculated by the service layer
22+
session, err := s.usageService.GetActiveSessionWithCosts(ctx, actor.ID)
23+
if err != nil {
24+
return nil, err
25+
}
26+
27+
if session == nil {
28+
return &usagev1.GetSessionUsageResponse{
29+
Session: nil,
30+
}, nil
31+
}
32+
33+
return &usagev1.GetSessionUsageResponse{
34+
Session: &usagev1.SessionUsage{
35+
SessionExpiry: timestamppb.New(session.SessionExpiry),
36+
Models: convertModelsToProto(session.Models),
37+
TotalCostUsd: session.TotalCostUSD,
38+
},
39+
}, nil
40+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package usage
2+
3+
import (
4+
"context"
5+
6+
"paperdebugger/internal/libs/contextutil"
7+
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
8+
)
9+
10+
func (s *UsageServer) GetWeeklyUsage(
11+
ctx context.Context,
12+
req *usagev1.GetWeeklyUsageRequest,
13+
) (*usagev1.GetWeeklyUsageResponse, error) {
14+
actor, err := contextutil.GetActor(ctx)
15+
if err != nil {
16+
return nil, err
17+
}
18+
19+
// Get weekly stats with costs already calculated by the service layer
20+
stats, err := s.usageService.GetWeeklyUsageWithCosts(ctx, actor.ID)
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
return &usagev1.GetWeeklyUsageResponse{
26+
Usage: &usagev1.WeeklyUsage{
27+
Models: convertModelsToProto(stats.Models),
28+
SessionCount: stats.SessionCount,
29+
TotalCostUsd: stats.TotalCostUSD,
30+
},
31+
}, nil
32+
}

internal/api/usage/server.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package usage
2+
3+
import (
4+
"paperdebugger/internal/libs/logger"
5+
"paperdebugger/internal/services"
6+
usagev1 "paperdebugger/pkg/gen/api/usage/v1"
7+
)
8+
9+
type UsageServer struct {
10+
usagev1.UnimplementedUsageServiceServer
11+
12+
usageService *services.UsageService
13+
logger *logger.Logger
14+
}
15+
16+
func NewUsageServer(
17+
usageService *services.UsageService,
18+
logger *logger.Logger,
19+
) usagev1.UsageServiceServer {
20+
return &UsageServer{
21+
usageService: usageService,
22+
logger: logger,
23+
}
24+
}
25+
26+
// convertModelsToProto converts ModelUsageStats to proto format.
27+
// Costs are already calculated by the service layer.
28+
func convertModelsToProto(models map[string]*services.ModelUsageStats) map[string]*usagev1.ModelTokens {
29+
protoModels := make(map[string]*usagev1.ModelTokens, len(models))
30+
31+
for modelName, stats := range models {
32+
protoModels[modelName] = &usagev1.ModelTokens{
33+
PromptTokens: stats.PromptTokens,
34+
CompletionTokens: stats.CompletionTokens,
35+
TotalTokens: stats.TotalTokens,
36+
RequestCount: stats.RequestCount,
37+
CostUsd: stats.CostUSD,
38+
}
39+
}
40+
41+
return protoModels
42+
}

internal/libs/db/db.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"paperdebugger/internal/libs/cfg"
88
"paperdebugger/internal/libs/logger"
9+
"paperdebugger/internal/models"
910

1011
"go.mongodb.org/mongo-driver/v2/bson"
1112
"go.mongodb.org/mongo-driver/v2/mongo"
@@ -43,5 +44,47 @@ func NewDB(cfg *cfg.Cfg, logger *logger.Logger) (*DB, error) {
4344
}
4445

4546
logger.Info("[MONGO] initialized")
46-
return &DB{Client: client, cfg: cfg, logger: logger}, nil
47+
48+
db := &DB{Client: client, cfg: cfg, logger: logger}
49+
db.ensureIndexes()
50+
return db, nil
51+
}
52+
53+
// ensureIndexes creates necessary indexes for the database collections.
54+
func (db *DB) ensureIndexes() {
55+
sessions := db.Database("paperdebugger").Collection((models.LLMSession{}).CollectionName())
56+
57+
// TTL index: auto-delete sessions after 30 days past their expiry time
58+
_, err := sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{
59+
Keys: bson.D{{Key: "session_expiry", Value: 1}},
60+
Options: options.Index().SetExpireAfterSeconds(30 * 24 * 60 * 60),
61+
})
62+
if err != nil {
63+
db.logger.Error("Failed to create TTL index on llm_sessions", "error", err)
64+
}
65+
66+
// Compound index for efficient active session lookups
67+
_, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{
68+
Keys: bson.D{
69+
{Key: "user_id", Value: 1},
70+
{Key: "session_expiry", Value: -1},
71+
},
72+
})
73+
if err != nil {
74+
db.logger.Error("Failed to create compound index on llm_sessions", "error", err)
75+
}
76+
77+
// Unique compound index for session creation and queries.
78+
// session_start is rounded to the second, so concurrent requests within the same
79+
// second will conflict, triggering duplicate key handling in RecordUsage.
80+
_, err = sessions.Indexes().CreateOne(context.Background(), mongo.IndexModel{
81+
Keys: bson.D{
82+
{Key: "user_id", Value: 1},
83+
{Key: "session_start", Value: -1},
84+
},
85+
Options: options.Index().SetUnique(true),
86+
})
87+
if err != nil {
88+
db.logger.Error("Failed to create session_start index on llm_sessions", "error", err)
89+
}
4790
}

internal/models/model_pricing.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package models
2+
3+
import (
4+
"time"
5+
6+
"go.mongodb.org/mongo-driver/v2/bson"
7+
)
8+
9+
// ModelPricing stores the pricing information for an LLM model.
10+
// Prices are in USD per token.
11+
type ModelPricing struct {
12+
ID bson.ObjectID `bson:"_id"`
13+
ModelID string `bson:"model_id"` // e.g., "openai/gpt-4"
14+
ModelSlug string `bson:"model_slug"` // e.g., "gpt-4" (short name used in our app)
15+
Name string `bson:"name"` // e.g., "OpenAI: GPT-4"
16+
PromptPrice float64 `bson:"prompt_price"` // USD per token
17+
CompletionPrice float64 `bson:"completion_price"` // USD per token
18+
UpdatedAt time.Time `bson:"updated_at"`
19+
}
20+
21+
func (m ModelPricing) CollectionName() string {
22+
return "model_pricing"
23+
}

internal/models/usage.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package models
2+
3+
import "go.mongodb.org/mongo-driver/v2/bson"
4+
5+
// ModelTokens holds the token and request counters recorded for one model.
type ModelTokens struct {
	// PromptTokens counts prompt tokens.
	PromptTokens int64 `bson:"prompt_tokens"`
	// CompletionTokens counts completion tokens.
	CompletionTokens int64 `bson:"completion_tokens"`
	// TotalTokens is the stored total token count.
	TotalTokens int64 `bson:"total_tokens"`
	// RequestCount counts requests.
	RequestCount int64 `bson:"request_count"`
}
12+
13+
// LLMSession represents a user's session for tracking LLM usage and token counts.
14+
// Tokens are stored per model in the Models map.
15+
type LLMSession struct {
16+
ID bson.ObjectID `bson:"_id"`
17+
UserID bson.ObjectID `bson:"user_id"`
18+
SessionStart bson.DateTime `bson:"session_start"`
19+
SessionExpiry bson.DateTime `bson:"session_expiry"`
20+
Models map[string]*ModelTokens `bson:"models"`
21+
}
22+
23+
func (s LLMSession) CollectionName() string {
24+
return "llm_sessions"
25+
}

0 commit comments

Comments
 (0)