Skip to content

Commit 7939874

Browse files
authored
feat(data): super conditioning (borisdayma#141)
* feat(data): online filtering * feat(generate): super conditioning * feat: add processor
1 parent 803ccbf commit 7939874

File tree

9 files changed

+515
-68
lines changed

9 files changed

+515
-68
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ To generate sample predictions and understand the inference pipeline step by ste
3535
Join the community on the [DALLE-Pytorch Discord](https://discord.gg/xBPBXfcFHd).
3636
Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model with cool prompts!
3737

38-
3938
## Development
4039

4140
### Dependencies Installation
@@ -95,6 +94,7 @@ Many thanks to the people who helped make it better:
9594

9695
- the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
9796
- [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
97+
- [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
9898

9999
## Citing DALL·E mini
100100

src/dalle_mini/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
__version__ = "0.0.2"
1+
__version__ = "0.0.3"
2+
3+
from .model import DalleBart, DalleBartProcessor

src/dalle_mini/data.py

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from braceexpand import braceexpand
88
from datasets import Dataset, load_dataset
99

10-
from .text import TextNormalizer
10+
from .model.text import TextNormalizer
1111

1212

1313
@dataclass
@@ -28,6 +28,11 @@ class Dataset:
2828
seed_dataset: int = None
2929
shard_by_host: bool = False
3030
blank_caption_prob: float = 0.0
31+
clip_score_column: str = "clip_score"
32+
min_clip_score: float = None
33+
max_clip_score: float = None
34+
filter_column: str = None
35+
filter_value: str = None
3136
train_dataset: Dataset = field(init=False)
3237
eval_dataset: Dataset = field(init=False)
3338
rng_dataset: jnp.ndarray = field(init=False)
@@ -36,6 +41,7 @@ class Dataset:
3641
def __post_init__(self):
3742
self.multi_hosts = jax.process_count() > 1
3843
# feed blank captions only in streaming mode for now
44+
# otherwise dataset could be cached with same blanked captions
3945
if self.blank_caption_prob:
4046
assert (
4147
self.streaming is True
@@ -107,23 +113,30 @@ def preprocess(self, tokenizer, config):
107113
self.seed_dataset = np.random.get_state()[1][0]
108114
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
109115

110-
# blank captions
111-
if self.blank_caption_prob:
112-
partial_blank_caption_function = partial(
113-
blank_caption_function,
114-
text_column=self.text_column,
115-
blank_caption_prob=self.blank_caption_prob,
116-
)
117-
if hasattr(self, "train_dataset"):
118-
self.train_dataset = (
119-
self.train_dataset.map(partial_blank_caption_function)
120-
if self.streaming
121-
else self.train_dataset.map(
122-
partial_blank_caption_function,
123-
num_proc=self.preprocessing_num_workers,
124-
load_from_cache_file=False,
125-
desc="Blanking some captions",
126-
)
116+
# filter data
117+
partial_filter_function = partial(
118+
filter_function,
119+
filter_column=self.filter_column,
120+
filter_value=self.filter_value,
121+
clip_score_column=self.clip_score_column,
122+
min_clip_score=self.min_clip_score,
123+
max_clip_score=self.max_clip_score,
124+
)
125+
for ds in ["train_dataset", "eval_dataset"]:
126+
if hasattr(self, ds):
127+
setattr(
128+
self,
129+
ds,
130+
(
131+
getattr(self, ds).filter(partial_filter_function)
132+
if self.streaming
133+
else getattr(self, ds).filter(
134+
partial_filter_function,
135+
num_proc=self.preprocessing_num_workers,
136+
load_from_cache_file=not self.overwrite_cache,
137+
desc="Filtering datasets",
138+
)
139+
),
127140
)
128141

129142
# normalize text
@@ -151,6 +164,25 @@ def preprocess(self, tokenizer, config):
151164
),
152165
)
153166

167+
# blank captions
168+
if self.blank_caption_prob:
169+
partial_blank_caption_function = partial(
170+
blank_caption_function,
171+
text_column=self.text_column,
172+
blank_caption_prob=self.blank_caption_prob,
173+
)
174+
if hasattr(self, "train_dataset"):
175+
self.train_dataset = (
176+
self.train_dataset.map(partial_blank_caption_function)
177+
if self.streaming
178+
else self.train_dataset.map(
179+
partial_blank_caption_function,
180+
num_proc=self.preprocessing_num_workers,
181+
load_from_cache_file=False,
182+
desc="Blanking some captions",
183+
)
184+
)
185+
154186
# preprocess
155187
partial_preprocess_function = partial(
156188
preprocess_function,
@@ -230,8 +262,8 @@ def _dataloader_datasets_streaming(
230262
dataset.set_epoch(epoch)
231263
epoch += 1
232264
for item in dataset:
233-
for k, v in item.items():
234-
batch[k].append(v)
265+
for k in keys:
266+
batch[k].append(item[k])
235267
if len(batch[keys[0]]) == batch_size:
236268
batch = {k: jnp.array(v) for k, v in batch.items()}
237269
yield batch
@@ -292,6 +324,23 @@ def normalize_function(example, text_column, text_normalizer):
292324
return example
293325

294326

327+
def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    """Decide whether a dataset example should be kept.

    An example is kept unless it violates any enabled criterion:
    a lower bound on its clip score, an upper bound on its clip score,
    or an exact-match requirement on an arbitrary column. A criterion
    is disabled when its corresponding argument is None, in which case
    the related column is never read.
    """
    # Clip-score window: only consult the score column when a bound is set.
    score_ok = True
    if min_clip_score is not None:
        score_ok = example[clip_score_column] >= min_clip_score
    if score_ok and max_clip_score is not None:
        score_ok = example[clip_score_column] <= max_clip_score
    # Optional exact-match filter on an arbitrary column.
    column_ok = filter_column is None or example[filter_column] == filter_value
    return score_ok and column_ok
342+
343+
295344
def preprocess_function(
296345
examples,
297346
tokenizer,

src/dalle_mini/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .configuration import DalleBartConfig
22
from .modeling import DalleBart
33
from .partitions import set_partitions
4+
from .processor import DalleBartProcessor
45
from .tokenizer import DalleBartTokenizer

0 commit comments

Comments
 (0)