Skip to content

Commit 7b39a25

Browse files
committed
feat: track OpenRouter usage
1 parent 2e593f7 commit 7b39a25

13 files changed

Lines changed: 4571 additions & 963 deletions

File tree

internal/api/chat/create_conversation_message_stream_v2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ func (s *ChatServerV2) CreateConversationMessageStream(
281281
APIKey: settings.OpenAIAPIKey,
282282
}
283283

284-
openaiChatHistory, inappChatHistory, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
284+
openaiChatHistory, inappChatHistory, _, err := s.aiClientV2.ChatCompletionStreamV2(ctx, stream, conversation.UserID, conversation.ProjectID, conversation.ID.Hex(), modelSlug, conversation.OpenaiChatHistoryCompletion, llmProvider)
285285
if err != nil {
286286
return s.sendStreamError(stream, err)
287287
}

internal/models/usage.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package models
2+
3+
import (
4+
"time"
5+
6+
"go.mongodb.org/mongo-driver/v2/bson"
7+
)
8+
9+
// Usage tracks cost per user, per project, per model, per hour.
// Each document represents one hour bucket of usage; writers upsert into a
// bucket keyed by (user_id, project_id, model_slug, hour_bucket) and
// increment Cost in place.
type Usage struct {
	ID         bson.ObjectID `bson:"_id"`
	UserID     bson.ObjectID `bson:"user_id"`
	ProjectID  string        `bson:"project_id"`
	ModelSlug  string        `bson:"model_slug"`
	HourBucket bson.DateTime `bson:"hour_bucket"` // Timestamp truncated to the hour
	Cost       float64       `bson:"cost"`        // Cost in USD
	UpdatedAt  bson.DateTime `bson:"updated_at"`
}

// CollectionName returns the name of the MongoDB collection that stores
// Usage documents.
func (u Usage) CollectionName() string {
	return "usages"
}
24+
25+
// TruncateToHour truncates a time to the start of its hour.
26+
func TruncateToHour(t time.Time) time.Time {
27+
return t.Truncate(time.Hour)
28+
}

internal/services/toolkit/client/client_v2.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type AIClientV2 struct {
2020

2121
reverseCommentService *services.ReverseCommentService
2222
projectService *services.ProjectService
23+
usageService *services.UsageService
2324
cfg *cfg.Cfg
2425
logger *logger.Logger
2526
}
@@ -60,6 +61,7 @@ func NewAIClientV2(
6061

6162
reverseCommentService *services.ReverseCommentService,
6263
projectService *services.ProjectService,
64+
usageService *services.UsageService,
6365
cfg *cfg.Cfg,
6466
logger *logger.Logger,
6567
) *AIClientV2 {
@@ -107,6 +109,7 @@ func NewAIClientV2(
107109

108110
reverseCommentService: reverseCommentService,
109111
projectService: projectService,
112+
usageService: usageService,
110113
cfg: cfg,
111114
logger: logger,
112115
}

internal/services/toolkit/client/completion_v2.go

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,18 @@ import (
66
"paperdebugger/internal/models"
77
"paperdebugger/internal/services/toolkit/handler"
88
chatv2 "paperdebugger/pkg/gen/api/chat/v2"
9+
"strconv"
910
"strings"
1011

1112
"github.com/openai/openai-go/v3"
13+
"go.mongodb.org/mongo-driver/v2/bson"
1214
)
1315

16+
// UsageCost holds cost information from a completion.
// Cost is in USD, accumulated from the "cost" extra field of the provider's
// usage chunks (OpenRouter reports this; other providers may not — then the
// value stays 0).
type UsageCost struct {
	Cost float64
}
20+
1421
// define []openai.ChatCompletionMessageParamUnion as OpenAIChatHistory
1522

1623
// ChatCompletion orchestrates a chat completion process with a language model (e.g., GPT), handling tool calls and message history management.
@@ -24,13 +31,14 @@ import (
2431
// Returns:
2532
// 1. The full chat history sent to the language model (including any tool call results).
2633
// 2. The incremental chat history visible to the user (including tool call results and assistant responses).
27-
// 3. An error, if any occurred during the process.
28-
func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) {
29-
openaiChatHistory, inappChatHistory, err := a.ChatCompletionStreamV2(ctx, nil, "", modelSlug, messages, llmProvider)
34+
// 3. Cost information (in USD).
35+
// 4. An error, if any occurred during the process.
36+
func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, UsageCost, error) {
37+
openaiChatHistory, inappChatHistory, usage, err := a.ChatCompletionStreamV2(ctx, nil, bson.ObjectID{}, "", "", modelSlug, messages, llmProvider)
3038
if err != nil {
31-
return nil, nil, err
39+
return nil, nil, UsageCost{}, err
3240
}
33-
return openaiChatHistory, inappChatHistory, nil
41+
return openaiChatHistory, inappChatHistory, usage, nil
3442
}
3543

3644
// ChatCompletionStream orchestrates a streaming chat completion process with a language model (e.g., GPT), handling tool calls, message history management, and real-time streaming of responses to the client.
@@ -46,17 +54,19 @@ func (a *AIClientV2) ChatCompletionV2(ctx context.Context, modelSlug string, mes
4654
// Returns: (same as ChatCompletion)
4755
// 1. The full chat history sent to the language model (including any tool call results).
4856
// 2. The incremental chat history visible to the user (including tool call results and assistant responses).
49-
// 3. An error, if any occurred during the process. (However, in the streaming mode, the error is not returned, but sending by callbackStream)
57+
// 3. Cost information (in USD, accumulated across all calls).
58+
// 4. An error, if any occurred during the process. (However, in the streaming mode, the error is not returned, but sending by callbackStream)
5059
//
5160
// This function works as follows: (same as ChatCompletion)
5261
// - It initializes the chat history for the language model and the user, and sets up a stream handler for real-time updates.
5362
// - It repeatedly sends the current chat history to the language model, receives streaming responses, and forwards them to the client as they arrive.
5463
// - If tool calls are required, it handles them and appends the results to the chat history, then continues the loop.
5564
// - If no tool calls are needed, it appends the assistant's response and exits the loop.
56-
// - Finally, it returns the updated chat histories and any error encountered.
57-
func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, error) {
65+
// - Finally, it returns the updated chat histories, accumulated cost, and any error encountered.
66+
func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream chatv2.ChatService_CreateConversationMessageStreamServer, userID bson.ObjectID, projectID string, conversationId string, modelSlug string, messages OpenAIChatHistory, llmProvider *models.LLMProviderConfig) (OpenAIChatHistory, AppChatHistory, UsageCost, error) {
5867
openaiChatHistory := messages
5968
inappChatHistory := AppChatHistory{}
69+
usage := UsageCost{}
6070

6171
streamHandler := handler.NewStreamHandlerV2(callbackStream, conversationId, modelSlug)
6272

@@ -77,6 +87,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
7787
answer_content := ""
7888
answer_content_id := ""
7989
has_sent_part_begin := false
90+
has_finished := false
8091
tool_info := map[int]map[string]string{}
8192
toolCalls := []openai.FinishedChatCompletionToolCall{}
8293
handleReasoning := func(raw string) (string, bool) {
@@ -92,12 +103,18 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
92103
}
93104

94105
for stream.Next() {
95-
// time.Sleep(5000 * time.Millisecond) // DEBUG POINT: change this to test in a slow mode
96106
chunk := stream.Current()
97107

108+
// Capture cost from any chunk that has usage data (OpenRouter sends usage in a separate chunk after FinishReason)
109+
if chunk.Usage.PromptTokens > 0 || chunk.Usage.CompletionTokens > 0 {
110+
if costField, ok := chunk.Usage.JSON.ExtraFields["cost"]; ok {
111+
if cost, err := strconv.ParseFloat(costField.Raw(), 64); err == nil {
112+
usage.Cost += cost
113+
}
114+
}
115+
}
116+
98117
if len(chunk.Choices) == 0 {
99-
// Handle usage information
100-
// fmt.Printf("Usage: %+v\n", chunk.Usage)
101118
continue
102119
}
103120

@@ -180,17 +197,15 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
180197
}
181198
}
182199

183-
if chunk.Choices[0].FinishReason != "" {
184-
// fmt.Printf("FinishReason: %s\n", chunk.Choices[0].FinishReason)
185-
// answer_content += chunk.Choices[0].Delta.Content
186-
// fmt.Printf("answer_content: %s\n", answer_content)
200+
if chunk.Choices[0].FinishReason != "" && !has_finished {
187201
streamHandler.HandleTextDoneItem(chunk, answer_content, reasoning_content)
188-
break
202+
has_finished = true
203+
// Don't break - continue reading to capture the usage chunk that comes after
189204
}
190205
}
191206

192207
if err := stream.Err(); err != nil {
193-
return nil, nil, err
208+
return nil, nil, UsageCost{}, err
194209
}
195210

196211
if answer_content != "" {
@@ -200,7 +215,7 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
200215
// Execute the calls (if any), return incremental data
201216
openaiToolHistory, inappToolHistory, err := a.toolCallHandler.HandleToolCallsV2(ctx, toolCalls, streamHandler)
202217
if err != nil {
203-
return nil, nil, err
218+
return nil, nil, UsageCost{}, err
204219
}
205220

206221
// // Record the tool call results
@@ -213,5 +228,12 @@ func (a *AIClientV2) ChatCompletionStreamV2(ctx context.Context, callbackStream
213228
}
214229
}
215230

216-
return openaiChatHistory, inappChatHistory, nil
231+
// Track cost if userID is provided and user is not using their own API key (BYOK)
232+
if !userID.IsZero() && !llmProvider.IsCustom() {
233+
if err := a.usageService.TrackUsage(ctx, userID, projectID, modelSlug, usage.Cost); err != nil {
234+
a.logger.Error("Failed to track usage", "error", err)
235+
}
236+
}
237+
238+
return openaiChatHistory, inappChatHistory, usage, nil
217239
}

internal/services/toolkit/client/get_citation_keys.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ func (a *AIClientV2) GetCitationKeys(ctx context.Context, sentence string, userI
241241
// Bibliography is placed at the start of the prompt to leverage prompt caching
242242
message := fmt.Sprintf("Bibliography: %s\nSentence: %s\nBased on the sentence and bibliography, suggest only the most relevant citation keys separated by commas with no spaces (e.g. key1,key2). Be selective and only include citations that are directly relevant. Avoid suggesting more than 3 citations. If no relevant citations are found, return '%s'.", bibliography, sentence, emptyCitation)
243243

244-
_, resp, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{
244+
_, resp, _, err := a.ChatCompletionV2(ctx, "gpt-5.2", OpenAIChatHistory{
245245
openai.SystemMessage("You are a helpful assistant that suggests relevant citation keys."),
246246
openai.UserMessage(message),
247247
}, llmProvider)

internal/services/toolkit/client/get_citation_keys_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ func setupTestClient(t *testing.T) (*client.AIClientV2, *services.ProjectService
2525
}
2626

2727
projectService := services.NewProjectService(dbInstance, cfg.GetCfg(), logger.GetLogger())
28+
usageService := services.NewUsageService(dbInstance, cfg.GetCfg(), logger.GetLogger())
2829
aiClient := client.NewAIClientV2(
2930
dbInstance,
3031
&services.ReverseCommentService{},
3132
projectService,
33+
usageService,
3234
cfg.GetCfg(),
3335
logger.GetLogger(),
3436
)

internal/services/toolkit/client/get_conversation_title_v2.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (a *AIClientV2) GetConversationTitleV2(ctx context.Context, inappChatHistor
2929
message := strings.Join(messages, "\n")
3030
message = fmt.Sprintf("%s\nBased on above conversation, generate a short, clear, and descriptive title that summarizes the main topic or purpose of the discussion. The title should be concise, specific, and use natural language. Avoid vague or generic titles. Use abbreviation and short words if possible. Use 3-5 words if possible. Give me the title only, no other text including any other words.", message)
3131

32-
_, resp, err := a.ChatCompletionV2(ctx, "gpt-5-nano", OpenAIChatHistory{
32+
_, resp, _, err := a.ChatCompletionV2(ctx, "gpt-5-nano", OpenAIChatHistory{
3333
openai.SystemMessage("You are a helpful assistant that generates a title for a conversation."),
3434
openai.UserMessage(message),
3535
}, llmProvider)

internal/services/toolkit/client/utils_v2.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2)
7474
Tools: toolRegistry.GetTools(),
7575
ParallelToolCalls: openaiv3.Bool(true),
7676
Store: openaiv3.Bool(false),
77+
StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{
78+
IncludeUsage: openaiv3.Bool(true),
79+
},
7780
}
7881
}
7982
}
@@ -85,6 +88,9 @@ func getDefaultParamsV2(modelSlug string, toolRegistry *registry.ToolRegistryV2)
8588
Tools: toolRegistry.GetTools(), // Tool registration is managed centrally by the registry
8689
ParallelToolCalls: openaiv3.Bool(true),
8790
Store: openaiv3.Bool(false), // Must set to false, because we are construct our own chat history.
91+
StreamOptions: openaiv3.ChatCompletionStreamOptionsParam{
92+
IncludeUsage: openaiv3.Bool(true),
93+
},
8894
}
8995
}
9096

internal/services/usage.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package services
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"paperdebugger/internal/libs/cfg"
8+
"paperdebugger/internal/libs/db"
9+
"paperdebugger/internal/libs/logger"
10+
"paperdebugger/internal/models"
11+
12+
"go.mongodb.org/mongo-driver/v2/bson"
13+
"go.mongodb.org/mongo-driver/v2/mongo"
14+
"go.mongodb.org/mongo-driver/v2/mongo/options"
15+
)
16+
17+
// UsageService persists per-user, per-project, per-model usage cost into
// MongoDB, bucketed by hour (see models.Usage).
type UsageService struct {
	BaseService
	// usageCollection is the "usages" collection (models.Usage.CollectionName).
	usageCollection *mongo.Collection
}
21+
22+
func NewUsageService(db *db.DB, cfg *cfg.Cfg, logger *logger.Logger) *UsageService {
23+
base := NewBaseService(db, cfg, logger)
24+
return &UsageService{
25+
BaseService: base,
26+
usageCollection: base.db.Collection((models.Usage{}).CollectionName()),
27+
}
28+
}
29+
30+
// TrackUsage increments cost for a user/project/model/hour bucket.
31+
// Uses upsert to create or update the usage record atomically.
32+
func (s *UsageService) TrackUsage(ctx context.Context, userID bson.ObjectID, projectID string, modelSlug string, cost float64) error {
33+
if cost == 0 {
34+
return nil
35+
}
36+
37+
now := time.Now()
38+
hourBucket := models.TruncateToHour(now)
39+
40+
filter := bson.M{
41+
"user_id": userID,
42+
"project_id": projectID,
43+
"model_slug": modelSlug,
44+
"hour_bucket": bson.NewDateTimeFromTime(hourBucket),
45+
}
46+
47+
update := bson.M{
48+
"$inc": bson.M{
49+
"cost": cost,
50+
},
51+
"$set": bson.M{
52+
"updated_at": bson.NewDateTimeFromTime(now),
53+
},
54+
"$setOnInsert": bson.M{
55+
"_id": bson.NewObjectID(),
56+
},
57+
}
58+
59+
opts := options.UpdateOne().SetUpsert(true)
60+
_, err := s.usageCollection.UpdateOne(ctx, filter, update, opts)
61+
return err
62+
}

internal/wire.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ var Set = wire.NewSet(
4343
services.NewProjectService,
4444
services.NewPromptService,
4545
services.NewOAuthService,
46+
services.NewUsageService,
4647

4748
cfg.GetCfg,
4849
logger.GetLogger,

0 commit comments

Comments (0)