@@ -27,13 +27,19 @@ class Dataset:
2727 do_eval : bool = True
2828 seed_dataset : int = None
2929 shard_by_host : bool = False
30+ blank_caption_prob : float = 0.0
3031 train_dataset : Dataset = field (init = False )
3132 eval_dataset : Dataset = field (init = False )
3233 rng_dataset : jnp .ndarray = field (init = False )
3334 multi_hosts : bool = field (init = False )
3435
3536 def __post_init__ (self ):
3637 self .multi_hosts = jax .process_count () > 1
38+ # feed blank captions only in streaming mode for now
39+ if self .blank_caption_prob :
40+ assert (
41+ self .streaming is True
42+ ), "blank_caption_prob can only be used in streaming mode"
3743 # define data_files
3844 if self .train_file is not None or self .validation_file is not None :
3945 # accept braceexpand notation
@@ -101,6 +107,25 @@ def preprocess(self, tokenizer, config):
101107 self .seed_dataset = np .random .get_state ()[1 ][0 ]
102108 self .rng_dataset = jax .random .PRNGKey (self .seed_dataset )
103109
110+ # blank captions
111+ if self .blank_caption_prob :
112+ partial_blank_caption_function = partial (
113+ blank_caption_function ,
114+ text_column = self .text_column ,
115+ blank_caption_prob = self .blank_caption_prob ,
116+ )
117+ if hasattr (self , "train_dataset" ):
118+ self .train_dataset = (
119+ self .train_dataset .map (partial_blank_caption_function )
120+ if self .streaming
121+ else self .train_dataset .map (
122+ partial_blank_caption_function ,
123+ num_proc = self .preprocessing_num_workers ,
124+ load_from_cache_file = False ,
125+ desc = "Blanking some captions" ,
126+ )
127+ )
128+
104129 # normalize text
105130 if normalize_text :
106131 text_normalizer = TextNormalizer ()
@@ -144,6 +169,10 @@ def preprocess(self, tokenizer, config):
144169 getattr (self , ds ).map (
145170 partial_preprocess_function ,
146171 batched = True ,
172+ remove_columns = [
173+ self .text_column ,
174+ self .encoding_column ,
175+ ],
147176 )
148177 if self .streaming
149178 else getattr (self , ds ).map (
@@ -193,8 +222,8 @@ def _dataloader_datasets_streaming(
193222 while (self .multi_hosts and split == "train" ) or first_loop :
194223 # in multi-host, we run forever (no epoch) as hosts need to stop
195224 # at the same time and training data may not be split equally
196- # For validation data we put the entire set on each host as we could lose
197- # too many samples on pods
225+ # For validation data we put the entire batch on each host and then
226+ # keep only the one specific to each host (could be improved but not necessary)
198227 if epoch is not None :
199228 assert split == "train"
200229 # reshuffle training data at each epoch
@@ -252,6 +281,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
252281 return shifted_input_ids
253282
254283
def blank_caption_function(example, text_column, blank_caption_prob):
    """Randomly blank out a caption in a dataset example.

    With probability ``blank_caption_prob`` the value stored under
    ``text_column`` is replaced by the empty string; otherwise the
    example is left untouched. The example dict is mutated in place
    and returned (the shape `datasets.Dataset.map` expects).
    """
    should_blank = bool(blank_caption_prob) and np.random.rand() < blank_caption_prob
    if should_blank:
        example[text_column] = ""
    return example
288+
289+
def normalize_function(example, text_column, text_normalizer):
    """Normalize a caption in place and return the example.

    Runs ``text_normalizer`` (a callable ``str -> str``) over the value
    stored under ``text_column`` and writes the result back into the
    example dict, matching the `datasets.Dataset.map` callback contract.
    """
    raw_text = example[text_column]
    example[text_column] = text_normalizer(raw_text)
    return example
0 commit comments