@@ -55,6 +55,10 @@ class PNDMScheduler(nn.Module):
5555 each diffusion step uses the value of alphas product at that step and at the previous one. For the final
5656 step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
5757 otherwise it uses the value of alpha at step 0.
58+ prediction_type: {``"epsilon"``, ``"v_prediction"``}
59+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
60+ process) or `v_prediction` (see section 2.4
61+ https://imagen.research.google/video/paper.pdf)
5862 steps_offset:
5963 an offset added to the inference steps. You can use a combination of `offset=1` and
6064 `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
@@ -69,6 +73,7 @@ def __init__(
6973 beta_schedule : str = "linear" ,
7074 skip_prk_steps : bool = False ,
7175 set_alpha_to_one : bool = False ,
76+ prediction_type : str = "epsilon" ,
7277 steps_offset : int = 0 ,
7378 ) -> None :
7479 super ().__init__ ()
@@ -83,6 +88,10 @@ def __init__(
8388 else :
8489 raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
8590
91+ if prediction_type .lower () not in ["epsilon" , "v_prediction" ]:
92+ raise ValueError (f"prediction_type given as { prediction_type } must be one of `epsilon` or `v_prediction`" )
93+
94+ self .prediction_type = prediction_type
8695 self .num_train_timesteps = num_train_timesteps
8796 self .alphas = 1.0 - self .betas
8897 self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
@@ -301,6 +310,9 @@ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: i
301310 beta_prod_t = 1 - alpha_prod_t
302311 beta_prod_t_prev = 1 - alpha_prod_t_prev
303312
313+ if self .prediction_type == "v_prediction" :
314+ model_output = (alpha_prod_t ** 0.5 ) * model_output + (beta_prod_t ** 0.5 ) * sample
315+
304316 # corresponds to (α_(t−δ) - α_t) divided by
305317 # denominator of x_t in formula (9) and plus 1
306318 # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
0 commit comments