This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 894f2ec

Fix the way reversed_step in ddim treats the first alpha in alphas_cumprod.
This value is now set to zero or taken from the last element of the array, depending on the value of an init argument. Also, names of local variables in reversed_step are fixed to reflect the actual logic. (#452) Co-authored-by: Matan Atad <matan.atad@tum.de>
1 parent a3762b9 commit 894f2ec

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

  • generative/networks/schedulers

generative/networks/schedulers/ddim.py

```diff
@@ -65,6 +65,7 @@ class DDIMScheduler(Scheduler):
         set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
             For the final step there is no previous alpha. When this option is `True` the previous alpha product is
             fixed to `1`, otherwise it uses the value of alpha at step 0.
+            A similar approach is used for reverse steps; setting this option to `True` will use zero as the first alpha.
         steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
```
```diff
@@ -96,6 +97,10 @@ def __init__(
         # whether we use the final alpha or the "non-previous" one.
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+        # For reverse steps, we require the next alphas_cumprod. Similarly to the above, the last step doesn't
+        # have a next value, so we can either set it to zero or use the last value.
+        self.first_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_one else self.alphas_cumprod[-1]
+
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
```
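The boundary handling above can be sketched in plain Python (a simplified stand-in for the scheduler's `__init__`; the `alphas_cumprod` values below are made up for illustration, and plain floats replace tensors):

```python
# Toy cumulative alpha products; the real scheduler derives these from its beta schedule.
alphas_cumprod = [0.99, 0.95, 0.88, 0.70]

def boundary_alphas(set_alpha_to_one: bool):
    # Forward (denoising) steps need a "previous" alpha at step 0:
    # either pinned to 1, or the first entry of the table.
    final_alpha_cumprod = 1.0 if set_alpha_to_one else alphas_cumprod[0]
    # Reverse steps need a "next" alpha past the end of the table:
    # either pinned to 0, or the last entry of the table.
    first_alpha_cumprod = 0.0 if set_alpha_to_one else alphas_cumprod[-1]
    return final_alpha_cumprod, first_alpha_cumprod

print(boundary_alphas(True))   # (1.0, 0.0)
print(boundary_alphas(False))  # (0.99, 0.7)
```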
```diff
@@ -234,7 +239,7 @@ def reversed_step(
             sample: current instance of sample being created by diffusion process.

         Returns:
-            pred_prev_sample: Predicted previous sample
+            pred_next_sample: Predicted next sample
             pred_original_sample: Predicted original sample
         """
         # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
```
```diff
@@ -245,14 +250,14 @@
         # - std_dev_t -> sigma_t
         # - eta -> η
         # - pred_sample_direction -> "direction pointing to x_t"
-        # - pred_post_sample -> "x_t+1"
+        # - pred_next_sample -> "x_t+1"

-        # 1. get previous step value (=t+1)
-        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+        # 1. get next step value (=t+1)
+        next_timestep = timestep + self.num_train_timesteps // self.num_inference_steps

         # 2. compute alphas, betas at timestep t+1
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod

         beta_prod_t = 1 - alpha_prod_t
```
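The corrected index arithmetic can be illustrated with a small stand-alone sketch. The constants and the linear alpha table here are illustrative assumptions, not the repository's defaults; the point is the out-of-range guard, which now checks the upper end of the table rather than comparing against 0 as the old backwards-stepping code did:

```python
num_train_timesteps = 1000
num_inference_steps = 50
# Illustrative, monotonically decreasing cumulative alpha table.
alphas_cumprod = [1.0 - (t + 1) / 1001 for t in range(num_train_timesteps)]
first_alpha_cumprod = 0.0  # value used once we step past the last alpha

def alpha_prod_next(timestep: int) -> float:
    # Reverse steps move forward in time by the stride between inference steps.
    next_timestep = timestep + num_train_timesteps // num_inference_steps
    # Guard against indexing past the end of the table.
    if next_timestep < len(alphas_cumprod):
        return alphas_cumprod[next_timestep]
    return first_alpha_cumprod

print(alpha_prod_next(0))    # alphas_cumprod[20]
print(alpha_prod_next(980))  # 0.0, since 980 + 20 == 1000 is past the end
```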
```diff
@@ -274,9 +279,9 @@
             pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+        pred_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon

         # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+        pred_next_sample = alpha_prod_t_next ** (0.5) * pred_original_sample + pred_sample_direction

-        return pred_post_sample, pred_original_sample
+        return pred_next_sample, pred_original_sample
```
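Putting the pieces together, the deterministic (eta = 0) reversed update of formula (12) reduces to a few lines. The sketch below is a simplified stand-in for the method, not the MONAI implementation: scalar floats replace tensors, and the clipping and variance handling are omitted:

```python
def reversed_step(sample: float, model_eps: float,
                  alpha_prod_t: float, alpha_prod_t_next: float):
    beta_prod_t = 1 - alpha_prod_t
    # Predict x_0 from the current sample and the predicted noise epsilon.
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_eps) / alpha_prod_t ** 0.5
    # "Direction pointing to x_t" term, evaluated at the *next* alpha.
    pred_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_eps
    # x_{t+1} without the random-noise term.
    pred_next_sample = alpha_prod_t_next ** 0.5 * pred_original_sample + pred_sample_direction
    return pred_next_sample, pred_original_sample
```

With `alpha_prod_t = 1` the predicted original sample equals the input sample, and with `alpha_prod_t_next = 0` (the pinned boundary value) the next sample collapses to the predicted noise, which is the expected endpoint of the noising direction.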
