77from braceexpand import braceexpand
88from datasets import Dataset , load_dataset
99
10- from .text import TextNormalizer
10+ from .model . text import TextNormalizer
1111
1212
1313@dataclass
@@ -28,6 +28,11 @@ class Dataset:
2828 seed_dataset : int = None
2929 shard_by_host : bool = False
3030 blank_caption_prob : float = 0.0
31+ clip_score_column : str = "clip_score"
32+ min_clip_score : float = None
33+ max_clip_score : float = None
34+ filter_column : str = None
35+ filter_value : str = None
3136 train_dataset : Dataset = field (init = False )
3237 eval_dataset : Dataset = field (init = False )
3338 rng_dataset : jnp .ndarray = field (init = False )
@@ -36,6 +41,7 @@ class Dataset:
3641 def __post_init__ (self ):
3742 self .multi_hosts = jax .process_count () > 1
3843 # feed blank captions only in streaming mode for now
44+ # otherwise dataset could be cached with same blanked captions
3945 if self .blank_caption_prob :
4046 assert (
4147 self .streaming is True
@@ -107,23 +113,30 @@ def preprocess(self, tokenizer, config):
107113 self .seed_dataset = np .random .get_state ()[1 ][0 ]
108114 self .rng_dataset = jax .random .PRNGKey (self .seed_dataset )
109115
110- # blank captions
111- if self .blank_caption_prob :
112- partial_blank_caption_function = partial (
113- blank_caption_function ,
114- text_column = self .text_column ,
115- blank_caption_prob = self .blank_caption_prob ,
116- )
117- if hasattr (self , "train_dataset" ):
118- self .train_dataset = (
119- self .train_dataset .map (partial_blank_caption_function )
120- if self .streaming
121- else self .train_dataset .map (
122- partial_blank_caption_function ,
123- num_proc = self .preprocessing_num_workers ,
124- load_from_cache_file = False ,
125- desc = "Blanking some captions" ,
126- )
116+ # filter data
117+ partial_filter_function = partial (
118+ filter_function ,
119+ filter_column = self .filter_column ,
120+ filter_value = self .filter_value ,
121+ clip_score_column = self .clip_score_column ,
122+ min_clip_score = self .min_clip_score ,
123+ max_clip_score = self .max_clip_score ,
124+ )
125+ for ds in ["train_dataset" , "eval_dataset" ]:
126+ if hasattr (self , ds ):
127+ setattr (
128+ self ,
129+ ds ,
130+ (
131+ getattr (self , ds ).filter (partial_filter_function )
132+ if self .streaming
133+ else getattr (self , ds ).filter (
134+ partial_filter_function ,
135+ num_proc = self .preprocessing_num_workers ,
136+ load_from_cache_file = not self .overwrite_cache ,
137+ desc = "Filtering datasets" ,
138+ )
139+ ),
127140 )
128141
129142 # normalize text
@@ -151,6 +164,25 @@ def preprocess(self, tokenizer, config):
151164 ),
152165 )
153166
167+ # blank captions
168+ if self .blank_caption_prob :
169+ partial_blank_caption_function = partial (
170+ blank_caption_function ,
171+ text_column = self .text_column ,
172+ blank_caption_prob = self .blank_caption_prob ,
173+ )
174+ if hasattr (self , "train_dataset" ):
175+ self .train_dataset = (
176+ self .train_dataset .map (partial_blank_caption_function )
177+ if self .streaming
178+ else self .train_dataset .map (
179+ partial_blank_caption_function ,
180+ num_proc = self .preprocessing_num_workers ,
181+ load_from_cache_file = False ,
182+ desc = "Blanking some captions" ,
183+ )
184+ )
185+
154186 # preprocess
155187 partial_preprocess_function = partial (
156188 preprocess_function ,
@@ -230,8 +262,8 @@ def _dataloader_datasets_streaming(
230262 dataset .set_epoch (epoch )
231263 epoch += 1
232264 for item in dataset :
233- for k , v in item . items () :
234- batch [k ].append (v )
265+ for k in keys :
266+ batch [k ].append (item [ k ] )
235267 if len (batch [keys [0 ]]) == batch_size :
236268 batch = {k : jnp .array (v ) for k , v in batch .items ()}
237269 yield batch
@@ -292,6 +324,23 @@ def normalize_function(example, text_column, text_normalizer):
292324 return example
293325
294326
327+ def filter_function (
328+ example ,
329+ min_clip_score ,
330+ max_clip_score ,
331+ clip_score_column ,
332+ filter_column ,
333+ filter_value ,
334+ ):
335+ if min_clip_score is not None and example [clip_score_column ] < min_clip_score :
336+ return False
337+ if max_clip_score is not None and example [clip_score_column ] > max_clip_score :
338+ return False
339+ if filter_column is not None and example [filter_column ] != filter_value :
340+ return False
341+ return True
342+
343+
295344def preprocess_function (
296345 examples ,
297346 tokenizer ,
0 commit comments