@@ -22,6 +22,7 @@ public abstract class DiffuserBase : IDiffuser
2222 protected readonly UNetConditionModel _unet ;
2323 protected readonly AutoEncoderModel _vaeDecoder ;
2424 protected readonly AutoEncoderModel _vaeEncoder ;
25+ protected readonly MemoryModeType _memoryMode ;
2526
2627 /// <summary>
2728 /// Initializes a new instance of the <see cref="DiffuserBase"/> class.
@@ -31,12 +32,13 @@ public abstract class DiffuserBase : IDiffuser
3132 /// <param name="vaeDecoder">The vae decoder.</param>
3233 /// <param name="vaeEncoder">The vae encoder.</param>
3334 /// <param name="logger">The logger.</param>
34- public DiffuserBase ( UNetConditionModel unet , AutoEncoderModel vaeDecoder , AutoEncoderModel vaeEncoder , ILogger logger = default )
35+ public DiffuserBase ( UNetConditionModel unet , AutoEncoderModel vaeDecoder , AutoEncoderModel vaeEncoder , MemoryModeType memoryMode , ILogger logger = default )
3536 {
3637 _logger = logger ;
3738 _unet = unet ;
3839 _vaeDecoder = vaeDecoder ;
3940 _vaeEncoder = vaeEncoder ;
41+ _memoryMode = memoryMode ;
4042 }
4143
4244 /// <summary>
@@ -137,10 +139,15 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOption
137139 var results = await _vaeDecoder . RunInferenceAsync ( inferenceParameters ) ;
138140 using ( var imageResult = results . First ( ) )
139141 {
142+ // Unload if required
143+ if ( _memoryMode != MemoryModeType . Maximum )
144+ await _vaeDecoder . UnloadAsync ( ) ;
145+
140146 _logger ? . LogEnd ( "Latents decoded" , timestamp ) ;
141147 return imageResult . ToDenseTensor ( ) ;
142148 }
143149 }
150+
144151 }
145152
146153
0 commit comments