Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 34c4b29

Browse files
authored
Add v_prediction and update docstrings (#165)
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
1 parent cd7d30a commit 34c4b29

3 files changed

Lines changed: 19 additions & 2 deletions

File tree

generative/networks/schedulers/ddim.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ class DDIMScheduler(nn.Module):
5555
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
5656
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
5757
stable diffusion.
58-
prediction_type: prediction type of the scheduler function, one of `epsilon` (predicting the noise of the
59-
diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
58+
prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
59+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
60+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
6061
https://imagen.research.google/video/paper.pdf)
6162
"""
6263

generative/networks/schedulers/ddpm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ class DDPMScheduler(nn.Module):
5151
variance_type: {``"fixed_small"``, ``"fixed_large"``, ``"learned"``, ``"learned_range"``}
5252
options to clip the variance used when adding noise to the denoised sample.
5353
clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
54+
prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
55+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
56+
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
57+
https://imagen.research.google/video/paper.pdf)
5458
"""
5559

5660
def __init__(

generative/networks/schedulers/pndm.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class PNDMScheduler(nn.Module):
5555
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
5656
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
5757
otherwise it uses the value of alpha at step 0.
58+
prediction_type: {``"epsilon"``, ``"v_prediction"``}
59+
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
60+
process) or `v_prediction` (see section 2.4
61+
https://imagen.research.google/video/paper.pdf)
5862
steps_offset:
5963
an offset added to the inference steps. You can use a combination of `offset=1` and
6064
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
@@ -69,6 +73,7 @@ def __init__(
6973
beta_schedule: str = "linear",
7074
skip_prk_steps: bool = False,
7175
set_alpha_to_one: bool = False,
76+
prediction_type: str = "epsilon",
7277
steps_offset: int = 0,
7378
) -> None:
7479
super().__init__()
@@ -83,6 +88,10 @@ def __init__(
8388
else:
8489
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
8590

91+
if prediction_type.lower() not in ["epsilon", "v_prediction"]:
92+
raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon` or `v_prediction`")
93+
94+
self.prediction_type = prediction_type
8695
self.num_train_timesteps = num_train_timesteps
8796
self.alphas = 1.0 - self.betas
8897
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -301,6 +310,9 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i
301310
beta_prod_t = 1 - alpha_prod_t
302311
beta_prod_t_prev = 1 - alpha_prod_t_prev
303312

313+
if self.prediction_type == "v_prediction":
314+
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
315+
304316
# corresponds to (α_(t−δ) - α_t) divided by
305317
# denominator of x_t in formula (9) and plus 1
306318
# Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =

0 commit comments

Comments
 (0)