Skip to content

Commit 7935b41

Browse files
committed
Expose provider token usage on responses
1 parent 701d7e6 commit 7935b41

10 files changed

Lines changed: 585 additions & 25 deletions

Sources/AnyLanguageModel/LanguageModelSession.swift

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ public final class LanguageModelSession: @unchecked Sendable {
174174
public struct Response<Content>: Sendable where Content: Generable, Content: Sendable {
175175
public let content: Content
176176
public let rawContent: GeneratedContent
177+
public let usage: LanguageModelUsage?
177178
public let transcriptEntries: ArraySlice<Transcript.Entry>
178179

179180
/// Creates a response value from generated content and transcript entries.
@@ -184,10 +185,12 @@ public final class LanguageModelSession: @unchecked Sendable {
184185
public init(
185186
content: Content,
186187
rawContent: GeneratedContent,
188+
usage: LanguageModelUsage? = nil,
187189
transcriptEntries: ArraySlice<Transcript.Entry>
188190
) {
189191
self.content = content
190192
self.rawContent = rawContent
193+
self.usage = usage
191194
self.transcriptEntries = transcriptEntries
192195
}
193196
}
@@ -801,8 +804,12 @@ extension LanguageModelSession {
801804
/// - Parameters:
802805
/// - content: The complete response content.
803806
/// - rawContent: The raw content produced by the model.
804-
public init(content: Content, rawContent: GeneratedContent) {
805-
self.fallbackSnapshot = Snapshot(content: content.asPartiallyGenerated(), rawContent: rawContent)
807+
public init(content: Content, rawContent: GeneratedContent, usage: LanguageModelUsage? = nil) {
808+
self.fallbackSnapshot = Snapshot(
809+
content: content.asPartiallyGenerated(),
810+
rawContent: rawContent,
811+
usage: usage
812+
)
806813
self.streaming = nil
807814
}
808815

@@ -817,14 +824,20 @@ extension LanguageModelSession {
817824
public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable {
818825
public var content: Content.PartiallyGenerated
819826
public var rawContent: GeneratedContent
827+
public var usage: LanguageModelUsage?
820828

821829
/// Creates a snapshot from partially generated content and raw content.
822830
/// - Parameters:
823831
/// - content: The partially generated content.
824832
/// - rawContent: The raw content produced by the model.
825-
public init(content: Content.PartiallyGenerated, rawContent: GeneratedContent) {
833+
public init(
834+
content: Content.PartiallyGenerated,
835+
rawContent: GeneratedContent,
836+
usage: LanguageModelUsage? = nil
837+
) {
826838
self.content = content
827839
self.rawContent = rawContent
840+
self.usage = usage
828841
}
829842
}
830843
}
@@ -887,6 +900,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence {
887900
return LanguageModelSession.Response(
888901
content: finalContent,
889902
rawContent: last.rawContent,
903+
usage: last.usage,
890904
transcriptEntries: []
891905
)
892906
}
@@ -902,6 +916,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence {
902916
return LanguageModelSession.Response(
903917
content: finalContent,
904918
rawContent: fallbackSnapshot.rawContent,
919+
usage: fallbackSnapshot.usage,
905920
transcriptEntries: []
906921
)
907922
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
/// Provider-reported token usage for a language model response.
///
/// Values are optional because not every provider or endpoint reports every field.
/// When a single `respond` call internally triggers multiple provider requests
/// (for example while resolving tool calls), the returned usage is aggregated
/// across those underlying requests.
public struct LanguageModelUsage: Hashable, Codable, Sendable {
    /// Tokens consumed by the request input or prompt.
    public var inputTokens: Int?

    /// Tokens generated in the response.
    public var outputTokens: Int?

    /// Total tokens reported for the request.
    public var totalTokens: Int?

    /// Tokens spent on reasoning or thinking, when reported separately.
    public var reasoningTokens: Int?

    /// Input tokens served from a prompt cache, when reported separately.
    public var cachedInputTokens: Int?

    /// Input tokens written into a prompt cache, when reported separately.
    public var cacheCreationInputTokens: Int?

    /// Creates a usage value; any field a provider did not report stays `nil`.
    public init(
        inputTokens: Int? = nil,
        outputTokens: Int? = nil,
        totalTokens: Int? = nil,
        reasoningTokens: Int? = nil,
        cachedInputTokens: Int? = nil,
        cacheCreationInputTokens: Int? = nil
    ) {
        self.inputTokens = inputTokens
        self.outputTokens = outputTokens
        self.totalTokens = totalTokens
        self.reasoningTokens = reasoningTokens
        self.cachedInputTokens = cachedInputTokens
        self.cacheCreationInputTokens = cacheCreationInputTokens
    }
}

extension LanguageModelUsage {
    /// Every token counter as a writable key path, so the helpers below can
    /// treat the six fields uniformly instead of repeating one line per field.
    private static var tokenKeyPaths: [WritableKeyPath<LanguageModelUsage, Int?>] {
        [
            \.inputTokens,
            \.outputTokens,
            \.totalTokens,
            \.reasoningTokens,
            \.cachedInputTokens,
            \.cacheCreationInputTokens,
        ]
    }

    /// `true` when no counter carries a value.
    var isEmpty: Bool {
        Self.tokenKeyPaths.allSatisfy { self[keyPath: $0] == nil }
    }

    /// `nil` when the value carries no information at all, otherwise `self`.
    /// Lets callers collapse an all-`nil` usage into "no usage reported".
    var normalized: Self? {
        guard !isEmpty else { return nil }
        return self
    }

    /// Field-wise addition. A `nil` operand contributes nothing, and a field
    /// remains `nil` only when it is `nil` on both sides.
    mutating func add(_ other: Self?) {
        guard let other else { return }
        for path in Self.tokenKeyPaths {
            guard let increment = other[keyPath: path] else { continue }
            self[keyPath: path] = (self[keyPath: path] ?? 0) + increment
        }
    }

    /// Field-wise overwrite. Any value present on `other` replaces the
    /// corresponding field; `nil` fields on `other` leave `self` untouched.
    mutating func merge(_ other: Self?) {
        guard let other else { return }
        for path in Self.tokenKeyPaths {
            if let value = other[keyPath: path] {
                self[keyPath: path] = value
            }
        }
    }
}

Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ public struct AnthropicLanguageModel: LanguageModel {
345345
)
346346

347347
var entries: [Transcript.Entry] = []
348+
let usage = message.usage?.languageModelUsage
348349

349350
// Handle tool calls, if present
350351
let toolUses: [AnthropicToolUse] = message.content.compactMap { block in
@@ -363,6 +364,7 @@ public struct AnthropicLanguageModel: LanguageModel {
363364
return LanguageModelSession.Response(
364365
content: empty.content,
365366
rawContent: empty.rawContent,
367+
usage: usage,
366368
transcriptEntries: ArraySlice(entries)
367369
)
368370
case .invocations(let invocations):
@@ -386,6 +388,7 @@ public struct AnthropicLanguageModel: LanguageModel {
386388
return LanguageModelSession.Response(
387389
content: text as! Content,
388390
rawContent: GeneratedContent(text),
391+
usage: usage,
389392
transcriptEntries: ArraySlice(entries)
390393
)
391394
}
@@ -395,6 +398,7 @@ public struct AnthropicLanguageModel: LanguageModel {
395398
return LanguageModelSession.Response(
396399
content: content,
397400
rawContent: rawContent,
401+
usage: usage,
398402
transcriptEntries: ArraySlice(entries)
399403
)
400404
}
@@ -445,30 +449,48 @@ public struct AnthropicLanguageModel: LanguageModel {
445449

446450
var accumulatedText = ""
447451
let expectsStructuredResponse = type != String.self
452+
var latestUsage = LanguageModelUsage()
453+
var lastSnapshot: LanguageModelSession.ResponseStream<Content>.Snapshot?
448454

449455
for try await event in events {
450456
switch event {
457+
case .messageStart(let start):
458+
latestUsage.merge(start.message.usage?.languageModelUsage)
451459
case .contentBlockDelta(let delta):
452460
if case .textDelta(let textDelta) = delta.delta {
453461
accumulatedText += textDelta.text
454462

455463
if expectsStructuredResponse {
456-
if let snapshot: LanguageModelSession.ResponseStream<Content>.Snapshot =
464+
if var snapshot: LanguageModelSession.ResponseStream<Content>.Snapshot =
457465
try? partialSnapshot(from: accumulatedText)
458466
{
467+
snapshot.usage = latestUsage.normalized
468+
lastSnapshot = snapshot
459469
continuation.yield(snapshot)
460470
}
461471
} else {
462472
let raw = GeneratedContent(accumulatedText)
463473
let content: Content.PartiallyGenerated = (accumulatedText as! Content)
464474
.asPartiallyGenerated()
465-
continuation.yield(.init(content: content, rawContent: raw))
475+
let snapshot = LanguageModelSession.ResponseStream<Content>.Snapshot(
476+
content: content,
477+
rawContent: raw,
478+
usage: latestUsage.normalized
479+
)
480+
lastSnapshot = snapshot
481+
continuation.yield(snapshot)
466482
}
467483
}
484+
case .messageDelta(let delta):
485+
latestUsage.merge(delta.usage?.languageModelUsage)
468486
case .messageStop:
487+
if var lastSnapshot, lastSnapshot.usage != latestUsage.normalized {
488+
lastSnapshot.usage = latestUsage.normalized
489+
continuation.yield(lastSnapshot)
490+
}
469491
continuation.finish()
470492
return
471-
case .messageStart, .contentBlockStart, .contentBlockStop, .messageDelta, .ping, .ignored:
493+
case .contentBlockStart, .contentBlockStop, .ping, .ignored:
472494
break
473495
}
474496
}
@@ -995,9 +1017,10 @@ private struct AnthropicMessageResponse: Codable, Sendable {
9951017
let content: [AnthropicContent]
9961018
let model: String
9971019
let stopReason: StopReason?
1020+
let usage: AnthropicUsage?
9981021

9991022
enum CodingKeys: String, CodingKey {
1000-
case id, type, role, content, model
1023+
case id, type, role, content, model, usage
10011024
case stopReason = "stop_reason"
10021025
}
10031026

@@ -1012,6 +1035,20 @@ private struct AnthropicMessageResponse: Codable, Sendable {
10121035
}
10131036
}
10141037

1038+
/// Wire-format token counters as reported by the Anthropic Messages API.
/// All fields are optional because streaming events and non-streaming
/// responses each report only a subset of counters.
private struct AnthropicUsage: Codable, Sendable {
    /// Tokens in the request prompt (`input_tokens`).
    let inputTokens: Int?
    /// Tokens produced in the response (`output_tokens`).
    let outputTokens: Int?
    /// Tokens written into the prompt cache (`cache_creation_input_tokens`).
    let cacheCreationInputTokens: Int?
    /// Tokens read back from the prompt cache (`cache_read_input_tokens`).
    let cacheReadInputTokens: Int?

    // Explicit keys: this type decodes snake_case JSON without relying on a
    // decoder-wide key strategy.
    enum CodingKeys: String, CodingKey {
        case inputTokens = "input_tokens"
        case outputTokens = "output_tokens"
        case cacheCreationInputTokens = "cache_creation_input_tokens"
        case cacheReadInputTokens = "cache_read_input_tokens"
    }
}
1051+
10151052
private struct AnthropicErrorResponse: Codable { let error: AnthropicErrorDetail }
10161053
private struct AnthropicErrorDetail: Codable {
10171054
let type: String
@@ -1157,6 +1194,7 @@ private enum AnthropicStreamEvent: Codable, Sendable {
11571194
struct MessageDeltaEvent: Codable, Sendable {
11581195
let type: String
11591196
let delta: Delta
1197+
let usage: AnthropicUsage?
11601198

11611199
struct Delta: Codable, Sendable {
11621200
let stopReason: String?
@@ -1169,3 +1207,14 @@ private enum AnthropicStreamEvent: Codable, Sendable {
11691207
}
11701208
}
11711209
}
1210+
1211+
private extension AnthropicUsage {
1212+
var languageModelUsage: LanguageModelUsage? {
1213+
LanguageModelUsage(
1214+
inputTokens: inputTokens,
1215+
outputTokens: outputTokens,
1216+
cachedInputTokens: cacheReadInputTokens,
1217+
cacheCreationInputTokens: cacheCreationInputTokens
1218+
).normalized
1219+
}
1220+
}

0 commit comments

Comments
 (0)