@@ -18,20 +18,28 @@ import Foundation
1818 import Tokenizers
1919 import Hub
2020
21- /// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
22- private final class CachedContext : NSObject , @unchecked Sendable {
23- let context : ModelContext
24- init ( _ context: ModelContext ) { self . context = context }
21+ /// Wrapper to store model availability state in NSCache.
22+ private final class CachedModelState : NSObject , @unchecked Sendable {
23+ enum Value {
24+ case loaded( ModelContext )
25+ case failed( String )
26+ }
27+
28+ let value : Value
29+
30+ init ( _ value: Value ) {
31+ self . value = value
32+ }
2533 }
2634
2735 /// Coordinates a bounded in-memory cache with structured, coalesced loading.
2836 private final class ModelContextCache {
29- private let cache : NSCache < NSString , CachedContext >
30- private let inFlight = Locked < [ String : Task < CachedContext , Error > ] > ( [ : ] )
37+ private let cache : NSCache < NSString , CachedModelState >
38+ private let inFlight = Locked < [ String : Task < CachedModelState , Error > ] > ( [ : ] )
3139
3240 /// Creates a cache with a count-based eviction limit.
3341 init ( countLimit: Int ) {
34- let cache = NSCache < NSString , CachedContext > ( )
42+ let cache = NSCache < NSString , CachedModelState > ( )
3543 cache. countLimit = countLimit
3644 self . cache = cache
3745 }
@@ -42,23 +50,45 @@ import Foundation
4250 loader: @escaping @Sendable ( ) async throws -> ModelContext
4351 ) async throws -> ModelContext {
4452 let cacheKey = key as NSString
45- if let cached = cache. object ( forKey: cacheKey) {
46- return cached. context
53+ if let cached = cache. object ( forKey: cacheKey) ,
54+ case . loaded( let context) = cached. value
55+ {
56+ return context
4757 }
4858
4959 if let task = inFlightTask ( for: key) {
50- return try await task. value. context
60+ let cached = try await task. value
61+ if case . loaded( let context) = cached. value {
62+ return context
63+ }
64+ throw CancellationError ( )
5165 }
5266
53- let task = Task { try await CachedContext ( loader ( ) ) }
67+ let task = Task {
68+ let context = try await loader ( )
69+ return CachedModelState ( . loaded( context) )
70+ }
5471 setInFlight ( task, for: key)
5572
5673 do {
5774 let cached = try await task. value
5875 cache. setObject ( cached, forKey: cacheKey)
5976 clearInFlight ( for: key)
60- return cached. context
77+ if case . loaded( let context) = cached. value {
78+ return context
79+ }
80+ throw CancellationError ( )
6181 } catch {
82+ // Don't treat cancellations as load failures.
83+ if error is CancellationError || Task . isCancelled {
84+ cache. removeObject ( forKey: cacheKey)
85+ clearInFlight ( for: key)
86+ throw error
87+ }
88+ cache. setObject (
89+ CachedModelState ( . failed( String ( reflecting: error) ) ) ,
90+ forKey: cacheKey
91+ )
6292 clearInFlight ( for: key)
6393 throw error
6494 }
@@ -74,6 +104,28 @@ import Foundation
74104 cache. removeAllObjects ( )
75105 }
76106
107+ /// Returns whether a cached context exists for the key.
108+ func contains( _ key: String ) -> Bool {
109+ guard let cached = cache. object ( forKey: key as NSString ) else {
110+ return false
111+ }
112+ if case . loaded = cached. value {
113+ return true
114+ }
115+ return false
116+ }
117+
118+ /// Returns a description of the most recent load failure for the key.
119+ func failureDescription( for key: String ) -> String ? {
120+ guard let cached = cache. object ( forKey: key as NSString ) else {
121+ return nil
122+ }
123+ if case . failed( let description) = cached. value {
124+ return description
125+ }
126+ return nil
127+ }
128+
77129 /// Cancels in-flight work and removes cached data for the key.
78130 func removeAndCancel( for key: String ) async {
79131 let task = removeInFlight ( for: key)
@@ -88,27 +140,27 @@ import Foundation
88140 cache. removeAllObjects ( )
89141 }
90142
91- private func inFlightTask( for key: String ) -> Task < CachedContext , Error > ? {
143+ private func inFlightTask( for key: String ) -> Task < CachedModelState , Error > ? {
92144 inFlight. withLock { $0 [ key] }
93145 }
94146
95- private func setInFlight( _ task: Task < CachedContext , Error > , for key: String ) {
147+ private func setInFlight( _ task: Task < CachedModelState , Error > , for key: String ) {
96148 inFlight. withLock { $0 [ key] = task }
97149 }
98150
99151 private func clearInFlight( for key: String ) {
100152 inFlight. withLock { $0 [ key] = nil }
101153 }
102154
103- private func removeInFlight( for key: String ) -> Task < CachedContext , Error > ? {
155+ private func removeInFlight( for key: String ) -> Task < CachedModelState , Error > ? {
104156 inFlight. withLock {
105157 let task = $0 [ key]
106158 $0 [ key] = nil
107159 return task
108160 }
109161 }
110162
111- private func removeAllInFlight( ) -> [ Task < CachedContext , Error > ] {
163+ private func removeAllInFlight( ) -> [ Task < CachedModelState , Error > ] {
112164 inFlight. withLock {
113165 let tasks = Array ( $0. values)
114166 $0. removeAll ( )
@@ -132,8 +184,12 @@ import Foundation
132184 /// ```
133185 public struct MLXLanguageModel : LanguageModel {
134186 /// The reason the model is unavailable.
135- /// This model is always available.
136- public typealias UnavailableReason = Never
187+ public enum UnavailableReason : Sendable , Equatable , Hashable {
188+ /// The model has not been loaded into memory yet.
189+ case notLoaded
190+ /// The model failed to load and includes the underlying error details.
191+ case failedToLoad( String )
192+ }
137193
138194 /// The model identifier.
139195 public let modelId : String
@@ -156,6 +212,20 @@ import Foundation
156212 self . directory = directory
157213 }
158214
215+ /// The current availability of this model in memory.
216+ public var availability : Availability < UnavailableReason > {
217+ let key = directory? . absoluteString ?? modelId
218+ if modelCache. contains ( key) {
219+ return . available
220+ }
221+
222+ if let failureDescription = modelCache. failureDescription ( for: key) {
223+ return . unavailable( . failedToLoad( failureDescription) )
224+ }
225+
226+ return . unavailable( . notLoaded)
227+ }
228+
159229 /// Removes this model from the shared cache and cancels any in-flight load.
160230 ///
161231 /// Call this to free memory when the model is no longer needed.
0 commit comments