Skip to content

Commit 46187c2

Browse files
noorbhatia and mattt authored
Add additionalContext support to MLXLanguageModel (#145)
* Add additionalContext support to MLXLanguageModel
* Fix merge conflict resolution
* Incorporate feedback from review
* Add userInputProcessing property to MLX custom generation options
* Group KV cache generation options into struct
* Improve ergonomics of resize processor at call site
* Fix compiler errors due to JSONValue ambiguity
* Update expectation for MLX image processing test
* Incorporate feedback from review

---------

Co-authored-by: Mattt Zmuda <mattt@me.com>
1 parent e85aa33 commit 46187c2

5 files changed

Lines changed: 267 additions & 78 deletions

File tree

README.md

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,23 +488,43 @@ let response = try await session.respond {
488488
}
489489
```
490490
491-
You can tune MLX KV-cache behavior per request with model-specific options:
491+
You can tune MLX request behavior per call with model-specific options,
492+
including KV-cache settings and optional media preprocessing:
492493
493494
```swift
494495
var options = GenerationOptions(temperature: 0.7)
495-
options[custom: MLXLanguageModel.self] = .init(
496-
maxKVSize: 4096,
497-
kvBits: 4,
498-
kvGroupSize: 64,
499-
quantizedKVStart: 128
496+
var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default
497+
mlxOptions.kvCache = .init(
498+
maxSize: 4096,
499+
bits: 4,
500+
groupSize: 64,
501+
quantizedStart: 128
500502
)
503+
// Apply a deterministic preprocessing step for image inputs.
504+
mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512))
505+
// Inject extra template context consumed by model-specific chat templates.
506+
mlxOptions.additionalContext = [
507+
"user_name": .string("Alice"),
508+
"turn_count": .int(3),
509+
"verbose": .bool(true),
510+
]
511+
options[custom: MLXLanguageModel.self] = mlxOptions
501512
502513
let response = try await session.respond(
503514
to: "Summarize this transcript",
504515
options: options
505516
)
506517
```
507518
519+
You can specify `userInputProcessing` to enforce a consistent image
520+
preprocessing step
521+
(for example, fixed dimensions for predictable latency, memory usage, and vision behavior).
522+
By default, images are passed through without an explicit resize override
523+
(`resize: nil`), so MLX applies its default media processing behavior.
524+
525+
You can also set `additionalContext` to provide extra JSON template variables
526+
for model-specific chat templates.
527+
508528
GPU cache behavior can be configured when creating the model:
509529
510530
```swift

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 173 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -196,38 +196,122 @@ import Foundation
196196
/// Set these values through ``GenerationOptions`` using
197197
/// `GenerationOptions[custom: MLXLanguageModel.self]`.
198198
public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable {
199-
/// Limits how many tokens the KV cache retains.
200-
///
201-
/// Set this to `nil` to use the backend default.
202-
public var maxKVSize: Int?
203-
/// Sets the KV-cache quantization bit width.
204-
///
205-
/// Set this to `nil` to disable KV quantization.
206-
public var kvBits: Int?
207-
/// Sets the token group size used for KV quantization.
208-
public var kvGroupSize: Int
209-
/// Sets the token offset where quantized KV storage starts.
210-
public var quantizedKVStart: Int
199+
/// Configures KV-cache behavior for MLX generation.
200+
public struct KVCache: Codable, Equatable, Sendable {
201+
/// Limits how many tokens the KV cache retains.
202+
///
203+
/// Set this to `nil` to use the backend default.
204+
public var maxSize: Int?
205+
206+
/// Sets the KV-cache quantization bit width.
207+
///
208+
/// Set this to `nil` to disable KV quantization.
209+
public var bits: Int?
210+
211+
/// Sets the token group size used for KV quantization.
212+
public var groupSize: Int
213+
214+
/// Sets the token offset where quantized KV storage starts.
215+
public var quantizedStart: Int
216+
217+
/// Default KV-cache options used when none are provided at runtime.
218+
/// By default, the token group size is 64 and the quantized start is 0.
219+
public static var `default`: Self {
220+
.init(
221+
maxSize: nil,
222+
bits: nil,
223+
groupSize: 64,
224+
quantizedStart: 0
225+
)
226+
}
227+
228+
/// Creates KV-cache configuration for MLX generation.
229+
///
230+
/// - Parameters:
231+
/// - maxSize: The maximum number of tokens to retain in KV cache storage.
232+
/// Pass `nil` to use the backend default.
233+
/// - bits: The KV-cache quantization bit width.
234+
/// Pass `nil` to disable KV quantization.
235+
/// - groupSize: The token group size used for KV quantization.
236+
/// - quantizedStart: The token index where quantized KV storage begins.
237+
public init(
238+
maxSize: Int?,
239+
bits: Int?,
240+
groupSize: Int,
241+
quantizedStart: Int
242+
) {
243+
self.maxSize = maxSize
244+
self.bits = bits
245+
self.groupSize = groupSize
246+
self.quantizedStart = quantizedStart
247+
}
248+
}
249+
/// KV-cache configuration used for generation.
250+
public var kvCache: KVCache
251+
252+
/// Configures media preprocessing applied before model input.
253+
public struct UserInputProcessing: Codable, Equatable, Sendable {
254+
/// Optional resize target applied to media before tokenization.
255+
public var resize: CGSize?
256+
257+
/// Creates user-input processing configuration.
258+
///
259+
/// - Parameter resize: Optional target size for media resizing.
260+
init(resize: CGSize?) {
261+
self.resize = resize
262+
}
263+
264+
/// Creates processing that resizes media to a fixed size.
265+
///
266+
/// - Parameter size: Target size used for resizing media inputs.
267+
public static func resize(to size: CGSize) -> Self {
268+
.init(resize: size)
269+
}
270+
271+
var mlxValue: MLXLMCommon.UserInput.Processing {
272+
.init(resize: resize)
273+
}
274+
}
275+
/// Processing to apply to user media before input preparation.
276+
public var userInputProcessing: UserInputProcessing?
277+
278+
var processingForUserInput: MLXLMCommon.UserInput.Processing {
279+
userInputProcessing?.mlxValue
280+
?? .init(resize: nil)
281+
}
282+
283+
/// Additional key-value pairs injected into the chat template rendering context.
284+
public var additionalContext: [String: AnyLanguageModel.JSONValue]?
285+
286+
var additionalContextForUserInput: [String: any Sendable]? {
287+
additionalContext?.mapValues { $0.toSendable() }
288+
}
211289

212290
/// Creates MLX-specific generation options.
213291
///
214292
/// - Parameters:
215-
/// - maxKVSize: The maximum number of tokens to retain in KV cache storage.
216-
/// Pass `nil` to use the backend default.
217-
/// - kvBits: The KV-cache quantization bit width.
218-
/// Pass `nil` to disable KV quantization.
219-
/// - kvGroupSize: The token group size used for KV quantization.
220-
/// - quantizedKVStart: The token index where quantized KV storage begins.
293+
/// - kvCache: KV-cache configuration used for generation.
294+
/// - additionalContext: Additional key-value pairs injected into the chat
295+
/// template rendering context.
296+
/// - userInputProcessing: Processing to apply to user media before input preparation.
297+
/// Defaults to `nil`, which lets MLX use its default media handling.
221298
public init(
222-
maxKVSize: Int? = nil,
223-
kvBits: Int? = nil,
224-
kvGroupSize: Int = 64,
225-
quantizedKVStart: Int = 0
299+
kvCache: KVCache,
300+
userInputProcessing: UserInputProcessing?,
301+
additionalContext: [String: AnyLanguageModel.JSONValue]?
226302
) {
227-
self.maxKVSize = maxKVSize
228-
self.kvBits = kvBits
229-
self.kvGroupSize = kvGroupSize
230-
self.quantizedKVStart = quantizedKVStart
303+
self.kvCache = kvCache
304+
self.additionalContext = additionalContext
305+
self.userInputProcessing = userInputProcessing
306+
}
307+
308+
/// Default MLX generation options used when none are provided at runtime.
309+
public static var `default`: Self {
310+
.init(
311+
kvCache: .default,
312+
userInputProcessing: nil,
313+
additionalContext: nil
314+
)
231315
}
232316
}
233317

@@ -761,15 +845,16 @@ import Foundation
761845
}
762846

763847
private func makeUserInput(
764-
session: LanguageModelSession,
765-
fallbackPrompt: String,
766-
tools: [ToolSpec]?
848+
chat: [MLXLMCommon.Chat.Message],
849+
tools: [ToolSpec]?,
850+
processing: MLXLMCommon.UserInput.Processing = .init(resize: nil),
851+
additionalContext: [String: any Sendable]? = nil
767852
) -> MLXLMCommon.UserInput {
768-
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: fallbackPrompt)
769853
return MLXLMCommon.UserInput(
770854
chat: chat,
771-
processing: .init(resize: .init(width: 512, height: 512)),
772-
tools: tools
855+
processing: processing,
856+
tools: tools,
857+
additionalContext: additionalContext,
773858
)
774859
}
775860

@@ -813,6 +898,12 @@ import Foundation
813898
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
814899
let generateParameters = toGenerateParameters(options)
815900

901+
// Extract additional context from custom options
902+
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput
903+
let userInputProcessing =
904+
options[custom: MLXLanguageModel.self]?.processingForUserInput
905+
?? .init(resize: nil)
906+
816907
// Build chat history from full transcript
817908
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
818909

@@ -825,10 +916,11 @@ import Foundation
825916
// Loop until no more tool calls
826917
while true {
827918
// Build user input with current chat history and tools
828-
let userInput = MLXLMCommon.UserInput(
919+
let userInput = makeUserInput(
829920
chat: chat,
830-
processing: .init(resize: .init(width: 512, height: 512)),
831-
tools: toolSpecs
921+
tools: toolSpecs,
922+
processing: userInputProcessing,
923+
additionalContext: additionalContext
832924
)
833925
let lmInput = try await context.processor.prepare(input: userInput)
834926
let resolved = resolveCache(
@@ -991,10 +1083,20 @@ import Foundation
9911083

9921084
// Build chat inside task to avoid Sendable issues
9931085
let generateParameters = toGenerateParameters(options)
994-
let userInput = makeUserInput(
1086+
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput
1087+
let userInputProcessing =
1088+
options[custom: MLXLanguageModel.self]?.processingForUserInput
1089+
?? .init(resize: nil)
1090+
let chat = convertTranscriptToMLXChat(
9951091
session: session,
996-
fallbackPrompt: prompt.description,
997-
tools: nil
1092+
fallbackPrompt: prompt.description
1093+
)
1094+
1095+
let userInput = makeUserInput(
1096+
chat: chat,
1097+
tools: nil,
1098+
processing: userInputProcessing,
1099+
additionalContext: additionalContext
9981100
)
9991101
let lmInput = try await context.processor.prepare(input: userInput)
10001102
let resolved = resolveCache(
@@ -1092,7 +1194,7 @@ import Foundation
10921194
let newCache = context.model.newCache(parameters: params)
10931195
let userInput = MLXLMCommon.UserInput(
10941196
chat: [.init(role: .system, content: instructions)],
1095-
processing: .init(resize: .init(width: 512, height: 512)),
1197+
processing: .init(resize: nil),
10961198
tools: toolSpecs
10971199
)
10981200
let lmInput = try await context.processor.prepare(input: userInput)
@@ -1116,10 +1218,10 @@ import Foundation
11161218
let custom = options[custom: MLXLanguageModel.self]
11171219
return MLXLMCommon.GenerateParameters(
11181220
maxTokens: options.maximumResponseTokens,
1119-
maxKVSize: custom?.maxKVSize,
1120-
kvBits: custom?.kvBits,
1121-
kvGroupSize: custom?.kvGroupSize ?? 64,
1122-
quantizedKVStart: custom?.quantizedKVStart ?? 0,
1221+
maxKVSize: custom?.kvCache.maxSize,
1222+
kvBits: custom?.kvCache.bits,
1223+
kvGroupSize: custom?.kvCache.groupSize ?? 64,
1224+
quantizedKVStart: custom?.kvCache.quantizedStart ?? 0,
11231225
temperature: Float(options.temperature ?? 0.6),
11241226
topP: 1.0,
11251227
repetitionPenalty: nil,
@@ -1132,10 +1234,10 @@ import Foundation
11321234
let custom = options[custom: MLXLanguageModel.self]
11331235
return MLXLMCommon.GenerateParameters(
11341236
maxTokens: options.maximumResponseTokens,
1135-
maxKVSize: custom?.maxKVSize,
1136-
kvBits: custom?.kvBits,
1137-
kvGroupSize: custom?.kvGroupSize ?? 64,
1138-
quantizedKVStart: custom?.quantizedKVStart ?? 0,
1237+
maxKVSize: custom?.kvCache.maxSize,
1238+
kvBits: custom?.kvCache.bits,
1239+
kvGroupSize: custom?.kvCache.groupSize ?? 64,
1240+
quantizedKVStart: custom?.kvCache.quantizedStart ?? 0,
11391241
temperature: Float(options.temperature ?? 0.2),
11401242
topP: 0.95,
11411243
repetitionPenalty: 1.1,
@@ -1489,7 +1591,7 @@ import Foundation
14891591
{
14901592
header += ". Expected value: \(constString)"
14911593
} else if let enumValues = jsonSchema.enum, !enumValues.isEmpty,
1492-
let data = try? encoder.encode(JSONValue.array(enumValues)),
1594+
let data = try? encoder.encode(enumValues),
14931595
let enumString = String(data: data, encoding: .utf8)
14941596
{
14951597
header += ". Allowed values: \(enumString)"
@@ -1529,10 +1631,17 @@ import Foundation
15291631
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
15301632
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
15311633
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)
1634+
1635+
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput
1636+
let userInputProcessing =
1637+
options[custom: MLXLanguageModel.self]?.processingForUserInput
1638+
?? .init(resize: nil)
1639+
15321640
let userInput = MLXLMCommon.UserInput(
15331641
chat: chat,
1534-
processing: .init(resize: .init(width: 512, height: 512)),
1535-
tools: nil
1642+
processing: userInputProcessing,
1643+
tools: nil,
1644+
additionalContext: additionalContext,
15361645
)
15371646
let lmInput = try await context.processor.prepare(input: userInput)
15381647

@@ -1773,4 +1882,18 @@ import Foundation
17731882
return sampledToken.item(Int.self)
17741883
}
17751884
}
1885+
extension AnyLanguageModel.JSONValue {
1886+
/// Recursively converts a `JSONValue` to its primitive Swift equivalent.
1887+
func toSendable() -> any Sendable {
1888+
switch self {
1889+
case .string(let s): return s
1890+
case .int(let i): return i
1891+
case .double(let d): return d
1892+
case .bool(let b): return b
1893+
case .null: return NSNull()
1894+
case .array(let arr): return arr.map { $0.toSendable() }
1895+
case .object(let obj): return obj.mapValues { $0.toSendable() }
1896+
}
1897+
}
1898+
}
17761899
#endif // MLX
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import enum JSONSchema.JSONValue
2+
3+
/// A type-safe representation of JSON values used by AnyLanguageModel APIs.
4+
public typealias JSONValue = JSONSchema.JSONValue

0 commit comments

Comments
 (0)