@@ -196,38 +196,122 @@ import Foundation
196196 /// Set these values through ``GenerationOptions`` using
197197 /// `GenerationOptions[custom: MLXLanguageModel.self]`.
198198 public struct CustomGenerationOptions : AnyLanguageModel . CustomGenerationOptions , Codable {
199- /// Limits how many tokens the KV cache retains.
200- ///
201- /// Set this to `nil` to use the backend default.
202- public var maxKVSize : Int ?
203- /// Sets the KV-cache quantization bit width.
204- ///
205- /// Set this to `nil` to disable KV quantization.
206- public var kvBits : Int ?
207- /// Sets the token group size used for KV quantization.
208- public var kvGroupSize : Int
209- /// Sets the token offset where quantized KV storage starts.
210- public var quantizedKVStart : Int
199+ /// Configures KV-cache behavior for MLX generation.
200+ public struct KVCache : Codable , Equatable , Sendable {
201+ /// Limits how many tokens the KV cache retains.
202+ ///
203+ /// Set this to `nil` to use the backend default.
204+ public var maxSize : Int ?
205+
206+ /// Sets the KV-cache quantization bit width.
207+ ///
208+ /// Set this to `nil` to disable KV quantization.
209+ public var bits : Int ?
210+
211+ /// Sets the token group size used for KV quantization.
212+ public var groupSize : Int
213+
214+ /// Sets the token offset where quantized KV storage starts.
215+ public var quantizedStart : Int
216+
217+ /// Default KV-cache options used when none are provided at runtime.
218+ /// By default, the token group size is 64 and the quantized start is 0.
219+ public static var `default` : Self {
220+ . init(
221+ maxSize: nil ,
222+ bits: nil ,
223+ groupSize: 64 ,
224+ quantizedStart: 0
225+ )
226+ }
227+
228+ /// Creates KV-cache configuration for MLX generation.
229+ ///
230+ /// - Parameters:
231+ /// - maxSize: The maximum number of tokens to retain in KV cache storage.
232+ /// Pass `nil` to use the backend default.
233+ /// - bits: The KV-cache quantization bit width.
234+ /// Pass `nil` to disable KV quantization.
235+ /// - groupSize: The token group size used for KV quantization.
236+ /// - quantizedStart: The token index where quantized KV storage begins.
237+ public init (
238+ maxSize: Int ? = nil ,
239+ bits: Int ? = nil ,
240+ groupSize: Int = 64 ,
241+ quantizedStart: Int = 0
242+ ) {
243+ self . maxSize = maxSize
244+ self . bits = bits
245+ self . groupSize = groupSize
246+ self . quantizedStart = quantizedStart
247+ }
248+ }
249+ /// KV-cache configuration used for generation.
250+ public var kvCache : KVCache
251+
252+ /// Configures media preprocessing applied before model input.
253+ public struct UserInputProcessing : Codable , Equatable , Sendable {
254+ /// Optional resize target applied to media before tokenization.
255+ public var resize : CGSize ?
256+
257+ /// Creates user-input processing configuration.
258+ ///
259+ /// - Parameter resize: Optional target size for media resizing.
260+ init ( resize: CGSize ? ) {
261+ self . resize = resize
262+ }
263+
264+ /// Creates processing that resizes media to a fixed size.
265+ ///
266+ /// - Parameter size: Target size used for resizing media inputs.
267+ public static func resize( to size: CGSize ) -> Self {
268+ . init( resize: size)
269+ }
270+
271+ var mlxValue : MLXLMCommon . UserInput . Processing {
272+ . init( resize: resize)
273+ }
274+ }
275+ /// Processing to apply to user media before input preparation.
276+ public var userInputProcessing : UserInputProcessing ?
277+
278+ var processingForUserInput : MLXLMCommon . UserInput . Processing {
279+ userInputProcessing? . mlxValue
280+ ?? . init( resize: nil )
281+ }
282+
283+ /// Additional key-value pairs injected into the chat template rendering context.
284+ public var additionalContext : [ String : AnyLanguageModel . JSONValue ] ?
285+
286+ var additionalContextForUserInput : [ String : any Sendable ] ? {
287+ additionalContext? . mapValues { $0. toSendable ( ) }
288+ }
211289
212290 /// Creates MLX-specific generation options.
213291 ///
214292 /// - Parameters:
215- /// - maxKVSize: The maximum number of tokens to retain in KV cache storage.
216- /// Pass `nil` to use the backend default.
217- /// - kvBits: The KV-cache quantization bit width.
218- /// Pass `nil` to disable KV quantization.
219- /// - kvGroupSize: The token group size used for KV quantization.
220- /// - quantizedKVStart: The token index where quantized KV storage begins.
293+ /// - kvCache: KV-cache configuration used for generation.
294+ /// - userInputProcessing: Processing to apply to user media before input preparation.
295+ /// Defaults to `nil`, which lets MLX use its default media handling.
296+ /// - additionalContext: Additional key-value pairs injected into the chat
297+ /// template rendering context.
221298 public init (
222- maxKVSize: Int ? = nil ,
223- kvBits: Int ? = nil ,
224- kvGroupSize: Int = 64 ,
225- quantizedKVStart: Int = 0
299+ kvCache: KVCache = . default,
300+ userInputProcessing: UserInputProcessing ? = nil ,
301+ additionalContext: [ String : AnyLanguageModel . JSONValue ] ? = nil
226302 ) {
227- self . maxKVSize = maxKVSize
228- self . kvBits = kvBits
229- self . kvGroupSize = kvGroupSize
230- self . quantizedKVStart = quantizedKVStart
303+ self . kvCache = kvCache
304+ self . additionalContext = additionalContext
305+ self . userInputProcessing = userInputProcessing
306+ }
307+
308+ /// Default MLX generation options used when none are provided at runtime.
309+ public static var `default` : Self {
310+ . init(
311+ kvCache: . default,
312+ userInputProcessing: nil ,
313+ additionalContext: nil
314+ )
231315 }
232316 }
233317
@@ -761,15 +845,16 @@ import Foundation
761845 }
762846
763847 private func makeUserInput(
764- session: LanguageModelSession ,
765- fallbackPrompt: String ,
766- tools: [ ToolSpec ] ?
848+ chat: [ MLXLMCommon . Chat . Message ] ,
849+ tools: [ ToolSpec ] ? ,
850+ processing: MLXLMCommon . UserInput . Processing = . init( resize: nil ) ,
851+ additionalContext: [ String : any Sendable ] ? = nil
767852 ) -> MLXLMCommon . UserInput {
768- let chat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: fallbackPrompt)
769853 return MLXLMCommon . UserInput (
770854 chat: chat,
771- processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
772- tools: tools
855+ processing: processing,
856+ tools: tools,
857+ additionalContext: additionalContext
773858 )
774859 }
775860
@@ -813,6 +898,12 @@ import Foundation
813898 // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
814899 let generateParameters = toGenerateParameters ( options)
815900
901+ // Extract additional context from custom options
902+ let additionalContext = options [ custom: MLXLanguageModel . self] ? . additionalContextForUserInput
903+ let userInputProcessing =
904+ options [ custom: MLXLanguageModel . self] ? . processingForUserInput
905+ ?? . init( resize: nil )
906+
816907 // Build chat history from full transcript
817908 var chat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
818909
@@ -825,10 +916,11 @@ import Foundation
825916 // Loop until no more tool calls
826917 while true {
827918 // Build user input with current chat history and tools
828- let userInput = MLXLMCommon . UserInput (
919+ let userInput = makeUserInput (
829920 chat: chat,
830- processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
831- tools: toolSpecs
921+ tools: toolSpecs,
922+ processing: userInputProcessing,
923+ additionalContext: additionalContext
832924 )
833925 let lmInput = try await context. processor. prepare ( input: userInput)
834926 let resolved = resolveCache (
@@ -991,10 +1083,20 @@ import Foundation
9911083
9921084 // Build chat inside task to avoid Sendable issues
9931085 let generateParameters = toGenerateParameters ( options)
994- let userInput = makeUserInput (
1086+ let additionalContext = options [ custom: MLXLanguageModel . self] ? . additionalContextForUserInput
1087+ let userInputProcessing =
1088+ options [ custom: MLXLanguageModel . self] ? . processingForUserInput
1089+ ?? . init( resize: nil )
1090+ let chat = convertTranscriptToMLXChat (
9951091 session: session,
996- fallbackPrompt: prompt. description,
997- tools: nil
1092+ fallbackPrompt: prompt. description
1093+ )
1094+
1095+ let userInput = makeUserInput (
1096+ chat: chat,
1097+ tools: nil ,
1098+ processing: userInputProcessing,
1099+ additionalContext: additionalContext
9981100 )
9991101 let lmInput = try await context. processor. prepare ( input: userInput)
10001102 let resolved = resolveCache (
@@ -1092,7 +1194,7 @@ import Foundation
10921194 let newCache = context. model. newCache ( parameters: params)
10931195 let userInput = MLXLMCommon . UserInput (
10941196 chat: [ . init( role: . system, content: instructions) ] ,
1095- processing: . init( resize: . init ( width : 512 , height : 512 ) ) ,
1197+ processing: . init( resize: nil ) ,
10961198 tools: toolSpecs
10971199 )
10981200 let lmInput = try await context. processor. prepare ( input: userInput)
@@ -1116,10 +1218,10 @@ import Foundation
11161218 let custom = options [ custom: MLXLanguageModel . self]
11171219 return MLXLMCommon . GenerateParameters (
11181220 maxTokens: options. maximumResponseTokens,
1119- maxKVSize: custom? . maxKVSize ,
1120- kvBits: custom? . kvBits ,
1121- kvGroupSize: custom? . kvGroupSize ?? 64 ,
1122- quantizedKVStart: custom? . quantizedKVStart ?? 0 ,
1221+ maxKVSize: custom? . kvCache . maxSize ,
1222+ kvBits: custom? . kvCache . bits ,
1223+ kvGroupSize: custom? . kvCache . groupSize ?? 64 ,
1224+ quantizedKVStart: custom? . kvCache . quantizedStart ?? 0 ,
11231225 temperature: Float ( options. temperature ?? 0.6 ) ,
11241226 topP: 1.0 ,
11251227 repetitionPenalty: nil ,
@@ -1132,10 +1234,10 @@ import Foundation
11321234 let custom = options [ custom: MLXLanguageModel . self]
11331235 return MLXLMCommon . GenerateParameters (
11341236 maxTokens: options. maximumResponseTokens,
1135- maxKVSize: custom? . maxKVSize ,
1136- kvBits: custom? . kvBits ,
1137- kvGroupSize: custom? . kvGroupSize ?? 64 ,
1138- quantizedKVStart: custom? . quantizedKVStart ?? 0 ,
1237+ maxKVSize: custom? . kvCache . maxSize ,
1238+ kvBits: custom? . kvCache . bits ,
1239+ kvGroupSize: custom? . kvCache . groupSize ?? 64 ,
1240+ quantizedKVStart: custom? . kvCache . quantizedStart ?? 0 ,
11391241 temperature: Float ( options. temperature ?? 0.2 ) ,
11401242 topP: 0.95 ,
11411243 repetitionPenalty: 1.1 ,
@@ -1489,7 +1591,7 @@ import Foundation
14891591 {
14901592 header += " . Expected value: \( constString) "
14911593 } else if let enumValues = jsonSchema. enum, !enumValues. isEmpty,
1492- let data = try ? encoder. encode ( JSONValue . array ( enumValues) ) ,
1594+ let data = try ? encoder. encode ( enumValues) ,
14931595 let enumString = String ( data: data, encoding: . utf8)
14941596 {
14951597 header += " . Allowed values: \( enumString) "
@@ -1529,10 +1631,17 @@ import Foundation
15291631 let baseChat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
15301632 let schemaPrompt = includeSchemaInPrompt ? schemaPrompt ( for: schema) : nil
15311633 let chat = normalizeChatForStructuredGeneration ( baseChat, schemaPrompt: schemaPrompt)
1634+
1635+ let additionalContext = options [ custom: MLXLanguageModel . self] ? . additionalContextForUserInput
1636+ let userInputProcessing =
1637+ options [ custom: MLXLanguageModel . self] ? . processingForUserInput
1638+ ?? . init( resize: nil )
1639+
15321640 let userInput = MLXLMCommon . UserInput (
15331641 chat: chat,
1534- processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
1535- tools: nil
1642+ processing: userInputProcessing,
1643+ tools: nil ,
1644+ additionalContext: additionalContext
15361645 )
15371646 let lmInput = try await context. processor. prepare ( input: userInput)
15381647
@@ -1773,4 +1882,18 @@ import Foundation
17731882 return sampledToken. item ( Int . self)
17741883 }
17751884 }
1885+ extension AnyLanguageModel . JSONValue {
1886+ /// Recursively converts a `JSONValue` to its primitive Swift equivalent.
1887+ func toSendable( ) -> any Sendable {
1888+ switch self {
1889+ case . string( let s) : return s
1890+ case . int( let i) : return i
1891+ case . double( let d) : return d
1892+ case . bool( let b) : return b
1893+ case . null: return NSNull ( )
1894+ case . array( let arr) : return arr. map { $0. toSendable ( ) }
1895+ case . object( let obj) : return obj. mapValues { $0. toSendable ( ) }
1896+ }
1897+ }
1898+ }
17761899#endif // MLX
0 commit comments