@@ -65,6 +65,7 @@ class DDIMScheduler(Scheduler):
6565 set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
6666 For the final step there is no previous alpha. When this option is `True` the previous alpha product is
6767 fixed to `1`, otherwise it uses the value of alpha at step 0.
68+ A similar approach is used for reverse steps; setting this option to `True` will use zero as the first alpha.
6869 steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
6970 `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
7071 stable diffusion.
@@ -96,6 +97,10 @@ def __init__(
9697 # whether we use the final alpha of the "non-previous" one.
9798 self .final_alpha_cumprod = torch .tensor (1.0 ) if set_alpha_to_one else self .alphas_cumprod [0 ]
9899
100+ # For reverse steps, we require the next alphas_cumprod. Similarly to the above, the final step doesn't
101+ # have a next value, so we can either set it to zero or use the last value.
102+ self .first_alpha_cumprod = torch .tensor (0.0 ) if set_alpha_to_one else self .alphas_cumprod [- 1 ]
103+
99104 # standard deviation of the initial noise distribution
100105 self .init_noise_sigma = 1.0
101106
@@ -234,7 +239,7 @@ def reversed_step(
234239 sample: current instance of sample being created by diffusion process.
235240
236241 Returns:
237- pred_prev_sample : Predicted previous sample
242+ pred_next_sample : Predicted next sample
238243 pred_original_sample: Predicted original sample
239244 """
240245 # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
@@ -245,14 +250,14 @@ def reversed_step(
245250 # - std_dev_t -> sigma_t
246251 # - eta -> η
247252 # - pred_sample_direction -> "direction pointing to x_t"
248- # - pred_post_sample -> "x_t+1"
253+ # - pred_next_sample -> "x_t+1"
249254
250- # 1. get previous step value (=t+1)
251- prev_timestep = timestep + self .num_train_timesteps // self .num_inference_steps
255+ # 1. get next step value (=t+1)
256+ next_timestep = timestep + self .num_train_timesteps // self .num_inference_steps
252257
253258 # 2. compute alphas, betas at timestep t+1
254259 alpha_prod_t = self .alphas_cumprod [timestep ]
255- alpha_prod_t_prev = self .alphas_cumprod [prev_timestep ] if prev_timestep >= 0 else self .final_alpha_cumprod
260+ alpha_prod_t_next = self .alphas_cumprod [next_timestep ] if next_timestep < len ( self . alphas_cumprod ) else self .first_alpha_cumprod
256261
257262 beta_prod_t = 1 - alpha_prod_t
258263
@@ -274,9 +279,9 @@ def reversed_step(
274279 pred_original_sample = torch .clamp (pred_original_sample , - 1 , 1 )
275280
276281 # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
277- pred_sample_direction = (1 - alpha_prod_t_prev ) ** (0.5 ) * pred_epsilon
282+ pred_sample_direction = (1 - alpha_prod_t_next ) ** (0.5 ) * pred_epsilon
278283
279284 # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
280- pred_post_sample = alpha_prod_t_prev ** (0.5 ) * pred_original_sample + pred_sample_direction
285+ pred_next_sample = alpha_prod_t_next ** (0.5 ) * pred_original_sample + pred_sample_direction
281286
282- return pred_post_sample , pred_original_sample
287+ return pred_next_sample , pred_original_sample
0 commit comments