Skip to content

Commit 7b311b1

Browse files
authored
Add UnavailableReason for MLXLanguageModel (#138)
* Add UnavailableReason for MLXLanguageModel
* Distinguish between not loaded and load failure
* Refactor model state / context caching
* Use removeFromCache to explicitly unload models when checking availability
1 parent 8532473 commit 7b311b1

2 files changed

Lines changed: 122 additions & 18 deletions

File tree

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 88 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,28 @@ import Foundation
1818
import Tokenizers
1919
import Hub
2020

21-
/// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
22-
private final class CachedContext: NSObject, @unchecked Sendable {
23-
let context: ModelContext
24-
init(_ context: ModelContext) { self.context = context }
21+
/// Reference-type box that lets a model's load outcome be stored in an `NSCache`.
///
/// `NSCache` holds class instances, so the enum describing the outcome is wrapped
/// in this `NSObject` subclass. The wrapped value is immutable after initialization.
private final class CachedModelState: NSObject, @unchecked Sendable {
    /// The outcome recorded for a model key.
    enum Value {
        /// Loading succeeded; the payload is the ready-to-use context.
        case loaded(ModelContext)
        /// Loading failed; the payload is a textual description of the error.
        case failed(String)
    }

    /// The recorded outcome.
    let value: Value

    /// Wraps `value` so it can be placed in the cache.
    init(_ value: Value) {
        self.value = value
        super.init()
    }
}
2634

2735
/// Coordinates a bounded in-memory cache with structured, coalesced loading.
2836
private final class ModelContextCache {
29-
private let cache: NSCache<NSString, CachedContext>
30-
private let inFlight = Locked<[String: Task<CachedContext, Error>]>([:])
37+
private let cache: NSCache<NSString, CachedModelState>
38+
private let inFlight = Locked<[String: Task<CachedModelState, Error>]>([:])
3139

3240
/// Creates a cache with a count-based eviction limit.
3341
init(countLimit: Int) {
34-
let cache = NSCache<NSString, CachedContext>()
42+
let cache = NSCache<NSString, CachedModelState>()
3543
cache.countLimit = countLimit
3644
self.cache = cache
3745
}
@@ -42,23 +50,45 @@ import Foundation
4250
loader: @escaping @Sendable () async throws -> ModelContext
4351
) async throws -> ModelContext {
4452
let cacheKey = key as NSString
45-
if let cached = cache.object(forKey: cacheKey) {
46-
return cached.context
53+
if let cached = cache.object(forKey: cacheKey),
54+
case .loaded(let context) = cached.value
55+
{
56+
return context
4757
}
4858

4959
if let task = inFlightTask(for: key) {
50-
return try await task.value.context
60+
let cached = try await task.value
61+
if case .loaded(let context) = cached.value {
62+
return context
63+
}
64+
throw CancellationError()
5165
}
5266

53-
let task = Task { try await CachedContext(loader()) }
67+
let task = Task {
68+
let context = try await loader()
69+
return CachedModelState(.loaded(context))
70+
}
5471
setInFlight(task, for: key)
5572

5673
do {
5774
let cached = try await task.value
5875
cache.setObject(cached, forKey: cacheKey)
5976
clearInFlight(for: key)
60-
return cached.context
77+
if case .loaded(let context) = cached.value {
78+
return context
79+
}
80+
throw CancellationError()
6181
} catch {
82+
// Don't treat cancellations as load failures.
83+
if error is CancellationError || Task.isCancelled {
84+
cache.removeObject(forKey: cacheKey)
85+
clearInFlight(for: key)
86+
throw error
87+
}
88+
cache.setObject(
89+
CachedModelState(.failed(String(reflecting: error))),
90+
forKey: cacheKey
91+
)
6292
clearInFlight(for: key)
6393
throw error
6494
}
@@ -74,6 +104,28 @@ import Foundation
74104
cache.removeAllObjects()
75105
}
76106

107+
/// Reports whether a successfully loaded context is currently cached for the key.
///
/// A cached failure entry, or no entry at all, yields `false`.
func contains(_ key: String) -> Bool {
    switch cache.object(forKey: key as NSString)?.value {
    case .some(.loaded):
        return true
    case .some(.failed), .none:
        return false
    }
}
117+
118+
/// Looks up the failure text recorded for the key.
///
/// - Returns: The stored description when the cached entry is a failure;
///   `nil` when the entry holds a loaded context or when nothing is cached.
func failureDescription(for key: String) -> String? {
    guard case .some(.failed(let message)) = cache.object(forKey: key as NSString)?.value else {
        return nil
    }
    return message
}
128+
77129
/// Cancels in-flight work and removes cached data for the key.
78130
func removeAndCancel(for key: String) async {
79131
let task = removeInFlight(for: key)
@@ -88,27 +140,27 @@ import Foundation
88140
cache.removeAllObjects()
89141
}
90142

91-
private func inFlightTask(for key: String) -> Task<CachedContext, Error>? {
143+
/// Looks up the load task currently registered for the key, if one exists.
private func inFlightTask(for key: String) -> Task<CachedModelState, Error>? {
    inFlight.withLock { pending in pending[key] }
}
94146

95-
private func setInFlight(_ task: Task<CachedContext, Error>, for key: String) {
147+
/// Registers `task` as the in-flight load for the key, replacing any previous entry.
private func setInFlight(_ task: Task<CachedModelState, Error>, for key: String) {
    inFlight.withLock { pending in
        pending[key] = task
    }
}
98150

99151
/// Drops the in-flight entry recorded for the key, if any, without touching the task itself.
private func clearInFlight(for key: String) {
    inFlight.withLock { pending in
        pending[key] = nil
    }
}
102154

103-
private func removeInFlight(for key: String) -> Task<CachedContext, Error>? {
155+
/// Unregisters the in-flight load for the key and hands the removed task back to the caller.
///
/// - Returns: The task that was registered, or `nil` when there was none.
private func removeInFlight(for key: String) -> Task<CachedModelState, Error>? {
    inFlight.withLock { pending in
        pending.removeValue(forKey: key)
    }
}
110162

111-
private func removeAllInFlight() -> [Task<CachedContext, Error>] {
163+
private func removeAllInFlight() -> [Task<CachedModelState, Error>] {
112164
inFlight.withLock {
113165
let tasks = Array($0.values)
114166
$0.removeAll()
@@ -132,8 +184,12 @@ import Foundation
132184
/// ```
133185
public struct MLXLanguageModel: LanguageModel {
134186
/// The reason the model is unavailable.
135-
/// This model is always available.
136-
public typealias UnavailableReason = Never
187+
public enum UnavailableReason: Sendable, Equatable, Hashable {
    /// The model is not currently loaded into memory.
    case notLoaded
    /// A load attempt failed; the associated string carries the underlying error details.
    case failedToLoad(String)
}
137193

138194
/// The model identifier.
139195
public let modelId: String
@@ -156,6 +212,20 @@ import Foundation
156212
self.directory = directory
157213
}
158214

215+
/// Reports whether this model is currently usable from the shared in-memory cache.
///
/// The cache key is the directory URL string when a directory was supplied,
/// otherwise the model identifier. The result is `.available` when a loaded
/// context is cached, `.unavailable(.failedToLoad(_))` when the cached entry
/// records a failed load, and `.unavailable(.notLoaded)` otherwise.
public var availability: Availability<UnavailableReason> {
    let cacheKey = directory?.absoluteString ?? modelId

    guard !modelCache.contains(cacheKey) else {
        return .available
    }
    guard let reason = modelCache.failureDescription(for: cacheKey) else {
        return .unavailable(.notLoaded)
    }
    return .unavailable(.failedToLoad(reason))
}
228+
159229
/// Removes this model from the shared cache and cancels any in-flight load.
160230
///
161231
/// Call this to free memory when the model is no longer needed.

Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ import Testing
3535
let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit")
3636
let visionModel = MLXLanguageModel(modelId: "mlx-community/Qwen2-VL-2B-Instruct-4bit")
3737

38+
/// Checks that availability moves from `.notLoaded` to `.available` once a response has been produced.
@Test func availabilityBecomesAvailableAfterSuccessfulLoad() async throws {
    // Evict first so a context cached by an earlier test cannot mask the transition.
    await model.removeFromCache()
    #expect(model.isAvailable == false)
    #expect(model.availability == .unavailable(.notLoaded))

    let chat = LanguageModelSession(model: model)
    let reply = try await chat.respond(to: "Say hello")
    #expect(!reply.content.isEmpty)

    #expect(model.isAvailable == true)
    #expect(model.availability == .available)
}
51+
3852
@Test func basicResponse() async throws {
3953
let session = LanguageModelSession(model: model)
4054

@@ -205,5 +219,25 @@ import Testing
205219
)
206220
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
207221
}
222+
223+
/// Checks that a failed request for an unknown model is reported as `.failedToLoad`.
@Test func unavailableForNonexistentModel() async {
    let missingModel = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
    await missingModel.removeFromCache()
    #expect(missingModel.isAvailable == false)
    #expect(missingModel.availability == .unavailable(.notLoaded))

    let chat = LanguageModelSession(model: missingModel)
    await #expect(throws: Error.self) {
        _ = try await chat.respond(to: "Hello")
    }

    // After the failed request, availability should carry a non-empty failure description.
    if case .unavailable(.failedToLoad(let description)) = missingModel.availability {
        #expect(!description.isEmpty)
    } else {
        Issue.record("Expected model availability to report failedToLoad after failed request")
    }
    #expect(missingModel.isAvailable == false)
}
208242
}
209243
#endif // MLX

0 commit comments

Comments
 (0)