@@ -29,25 +29,18 @@ def setup_seed(seed: int) -> None:
2929 torch .backends .cudnn .deterministic = True
3030
3131
32- class StreamGenerationConfig (GenerationConfig ):
33- def __init__ (self , ** kwargs ):
34- super ().__init__ (** kwargs )
35- self .do_stream = kwargs .pop ("do_stream" , False )
36-
37-
3832class NewGenerationMixin (GenerationMixin ):
3933 @torch .inference_mode ()
4034 def generate ( # noqa: PLR0911
4135 self ,
4236 inputs : torch .Tensor | None = None ,
43- generation_config : StreamGenerationConfig | None = None ,
37+ generation_config : GenerationConfig | None = None ,
4438 logits_processor : LogitsProcessorList | None = None ,
4539 stopping_criteria : StoppingCriteriaList | None = None ,
4640 prefix_allowed_tokens_fn : Callable [[int , torch .Tensor ], list [int ]] | None = None ,
4741 synced_gpus : bool | None = False ,
4842 assistant_model : PreTrainedModel | None = None ,
4943 streamer : "BaseStreamer | None" = None ,
50- use_model_defaults : bool | None = None ,
5144 custom_generate : str | Callable | None = None ,
5245 seed : int = 0 ,
5346 ** kwargs ,
@@ -102,11 +95,6 @@ def generate( # noqa: PLR0911
10295 same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
10396 is much faster than running generation with the model you're calling generate from. As such, the
10497 assistant model should be much smaller.
105- use_model_defaults (`bool`, *optional*):
106- When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
107- generation configuration (`model.generation_config`), as opposed to the global defaults
108- (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
109- `True`.
11098 kwargs:
11199 Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
112100 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -137,10 +125,16 @@ def generate( # noqa: PLR0911
137125 generation_mode_kwargs = self ._extract_generation_mode_kwargs (
138126 custom_generate , kwargs , synced_gpus , assistant_model , streamer
139127 )
140-
141- generation_config , model_kwargs = self ._prepare_generation_config (
142- generation_config , use_model_defaults , ** kwargs
128+ # Record whether max/min length were left unset by the caller before the config is updated with defaults.
129+ # These flags are used later to define the final min/max length (step 6).
130+ has_default_max_length = kwargs .get ("max_length" ) is None and (
131+ generation_config is None or generation_config .max_length is None
143132 )
133+ has_default_min_length = kwargs .get ("min_length" ) is None and (
134+ generation_config is None or generation_config .min_length is None
135+ )
136+ generation_config , model_kwargs = self ._prepare_generation_config (generation_config , ** kwargs )
137+
144138 generation_mode = generation_config .get_generation_mode (assistant_model )
145139 self ._validate_model_kwargs (model_kwargs .copy ())
146140 self ._validate_generation_mode (generation_mode , generation_config , generation_mode_kwargs )
@@ -212,8 +206,6 @@ def generate( # noqa: PLR0911
212206
213207 # 6. Prepare `max_length` depending on other stopping criteria.
214208 input_ids_length = input_ids .shape [- 1 ]
215- has_default_max_length = kwargs .get ("max_length" ) is None and generation_config .max_length is not None
216- has_default_min_length = kwargs .get ("min_length" ) is None and generation_config .min_length is not None
217209 generation_config = self ._prepare_generated_length (
218210 generation_config = generation_config ,
219211 has_default_max_length = has_default_max_length ,
@@ -500,7 +492,6 @@ def init_stream_support():
500492 repetition_penalty = 1.2 ,
501493 early_stopping = True ,
502494 seed = 0 ,
503- do_stream = True ,
504495 )
505496 stream_result = ""
506497 for x in generator :
0 commit comments