diff --git a/flaxdiff/data/datasets.py b/flaxdiff/data/datasets.py
index f17448f..f054a8d 100644
--- a/flaxdiff/data/datasets.py
+++ b/flaxdiff/data/datasets.py
@@ -39,11 +39,11 @@ def get_dataset_grain(
     augmenter = dataset["augmenter"](image_scale, method)
     local_batch_size = batch_size // jax.process_count()

-    model, tokenizer = defaultTextEncodeModel()
+    # model, tokenizer = defaultTextEncodeModel()

-    null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
-    null_labels = np.array(null_labels[0], dtype=np.float16)
-    null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
+    # null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
+    # null_labels = np.array(null_labels[0], dtype=np.float16)
+    # null_labels_full = np.array(null_labels_full[0], dtype=np.float16)

     sampler = pygrain.IndexSampler(
         num_records=len(data_source) if count is None else count,
@@ -80,13 +80,13 @@ def get_trainset():
         "train_len": len(data_source),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
-        "null_labels": null_labels,
-        "null_labels_full": null_labels_full,
-        "model": model,
-        "tokenizer": tokenizer,
+        # "null_labels": null_labels,
+        # "null_labels_full": null_labels_full,
+        # "model": model,
+        # "tokenizer": tokenizer,
     }

-def generate_collate_fn(tokenizer):
+def generate_collate_fn():
     auto_tokenize = AutoTextTokenizer(tensor_type="np")
     def default_collate(batch):
         try:
@@ -121,11 +121,11 @@ def get_dataset_online(
 ):
     local_batch_size = batch_size // jax.process_count()

-    model, tokenizer = defaultTextEncodeModel()
+    # model, tokenizer = defaultTextEncodeModel()

-    null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
-    null_labels = np.array(null_labels[0], dtype=np.float16)
-    null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
+    # null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
+    # null_labels = np.array(null_labels[0], dtype=np.float16)
+    # null_labels_full = np.array(null_labels_full[0], dtype=np.float16)

     sources = onlineDatasetMap[data_name]["source"]
     dataloader = OnlineStreamingDataLoader(
@@ -137,7 +137,7 @@
         global_process_count=jax.process_count(),
         global_process_index=jax.process_index(),
         prefetch=worker_buffer_size,
-        collate_fn=generate_collate_fn(tokenizer),
+        collate_fn=generate_collate_fn(),
         default_split="train",
     )

@@ -173,8 +173,8 @@ def __next__(self):
         "train_len": len(dataloader) * jax.process_count(),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
-        "null_labels": null_labels,
-        "null_labels_full": null_labels_full,
-        "model": model,
-        "tokenizer": tokenizer,
+        # "null_labels": null_labels,
+        # "null_labels_full": null_labels_full,
+        # "model": model,
+        # "tokenizer": tokenizer,
     }
\ No newline at end of file
diff --git a/flaxdiff/trainer/autoencoder_trainer.py b/flaxdiff/trainer/autoencoder_trainer.py
index b52eab4..f57dccb 100644
--- a/flaxdiff/trainer/autoencoder_trainer.py
+++ b/flaxdiff/trainer/autoencoder_trainer.py
@@ -14,7 +14,7 @@
 from flaxdiff.utils import RandomMarkovState

 from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
+from .diffusion_trainer import TrainState
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder

 class AutoEncoderTrainer(SimpleTrainer):
diff --git a/flaxdiff/trainer/diffusion_trainer.py b/flaxdiff/trainer/diffusion_trainer.py
index 32e87bb..c218364 100644
--- a/flaxdiff/trainer/diffusion_trainer.py
+++ b/flaxdiff/trainer/diffusion_trainer.py
@@ -19,6 +19,7 @@
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 from flax.training import dynamic_scale as dynamic_scale_lib
+from flaxdiff.utils import TextEncoder, ConditioningEncoder


 class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
@@ -49,6 +50,7 @@ def __init__(self,
                  name: str = "Diffusion",
                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                  autoencoder: AutoEncoder = None,
+                 encoder: ConditioningEncoder = None,
                  **kwargs
                  ):
         super().__init__(
@@ -64,6 +66,7 @@
         self.unconditional_prob = unconditional_prob

         self.autoencoder = autoencoder
+        self.encoder = encoder

     def generate_states(
         self,
@@ -106,7 +109,7 @@ def generate_states(

         return state, best_state

-    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+    def _define_train_step(self, batch_size):
         noise_schedule: NoiseScheduler = self.noise_schedule
         model = self.model
         model_output_transform = self.model_output_transform
@@ -115,6 +118,11 @@ def _define_train_step(self, batch_size, null_labels_seq, text_embedder):

         # Determine the number of unconditional samples
         num_unconditional = int(batch_size * unconditional_prob)
+
+        _, null_labels_full = self.encoder([""])
+        null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)
+
+        conditioning_encoder = self.encoder.model

         nS, nC = null_labels_seq.shape
         null_labels_seq = jnp.broadcast_to(
@@ -146,7 +154,7 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
                     local_rng_state, rngs = local_rng_state.get_random_key()
                     images = autoencoder.encode(images, rngs)

-                output = text_embedder(
+                output = conditioning_encoder(
                     input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
                 label_seq = output.last_hidden_state

@@ -231,8 +239,5 @@ def compute_metrics(state: TrainState, expected, pred):
         return compute_metrics

     def fit(self, data, steps_per_epoch, epochs):
-        null_labels_full = data['null_labels_full']
         local_batch_size = data['local_batch_size']
-        text_embedder = data['model']
-        super().fit(data, steps_per_epoch, epochs, {
-            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+        super().fit(data, steps_per_epoch, epochs, {"batch_size": local_batch_size})
diff --git a/flaxdiff/utils.py b/flaxdiff/utils.py
index a6b2114..cca26cc 100644
--- a/flaxdiff/utils.py
+++ b/flaxdiff/utils.py
@@ -7,6 +7,7 @@
 from functools import partial
 import numpy as np
 from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod

 class MarkovState(struct.PyTreeNode):
     pass
@@ -115,21 +116,38 @@ def _normalize(
     mul *= scale
     y = mul * x
     return jnp.asarray(y, dtype)
-
-
+
 @dataclass
-class TextEncoder:
+class ConditioningEncoder(ABC):
     model: nn.Module
     tokenizer: Callable
+
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
+        last_hidden_state = outputs.last_hidden_state
+        return last_hidden_state

-    def __call__(self, prompts):
-        # inputs = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="np")
-        inputs = self.tokenizer(prompts, padding="max_length",
+    def tokenize(self, data):
+        tokens = self.tokenizer(data, padding="max_length",
                                 max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
-        outputs = self.model(input_ids=inputs['input_ids'],
-                             attention_mask=inputs['attention_mask'])
-        # outputs = infer(inputs['input_ids'], inputs['attention_mask'])
-
+        return tokens
+
+@dataclass
+class TextEncoder(ConditioningEncoder):
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
         last_hidden_state = outputs.last_hidden_state
         pooler_output = outputs.pooler_output  # pooled (EOS token) states
         embed_pooled = pooler_output  # .astype(jnp.float16)
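
For context on the API change: below is a minimal sketch of how the new ConditioningEncoder/TextEncoder abstraction might be wired into the trainer, replacing the null_labels/model/tokenizer entries the dataset metadata dict used to carry. The CLIP checkpoint name and the (pooled, sequence) return convention are assumptions inferred from how _define_train_step unpacks self.encoder([""]); this patch does not confirm them.

    # Sketch only; assumes a HuggingFace CLIP text model/tokenizer pair.
    # flaxdiff's own defaultTextEncodeModel() may use a different checkpoint.
    from transformers import AutoTokenizer, FlaxCLIPTextModel

    from flaxdiff.utils import TextEncoder

    clip = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    encoder = TextEncoder(model=clip, tokenizer=tokenizer)

    # Assumed return convention: (pooled embedding, per-token sequence embedding),
    # matching `_, null_labels_full = self.encoder([""])` in the trainer.
    embed_pooled, embed_seq = encoder(["a photo of a cat"])

    # The trainer now derives its own null (unconditional) embeddings from the
    # empty prompt, so it only needs the encoder itself (other args elided):
    # trainer = DiffusionTrainer(..., encoder=encoder)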
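
The trainer's classifier-free-guidance setup (num_unconditional, the nS/nC shapes, and jnp.broadcast_to) can be illustrated standalone; the concatenation step at the end is an assumption about how the broadcast null block replaces the first labels, since the patch truncates the broadcast_to call mid-expression.

    # Illustrative shapes only; 77/768 are assumed CLIP text dimensions.
    import jax.numpy as jnp

    batch_size = 8
    unconditional_prob = 0.25
    num_unconditional = int(batch_size * unconditional_prob)  # as in the patch

    nS, nC = 77, 768
    null_labels_seq = jnp.zeros((nS, nC))        # stand-in for the encoded "" prompt
    label_seq = jnp.ones((batch_size, nS, nC))   # stand-in for encoded captions

    # Broadcast the null sequence over the first num_unconditional examples...
    null_block = jnp.broadcast_to(null_labels_seq, (num_unconditional, nS, nC))
    # ...and (assumed) splice it in front of the remaining conditional labels.
    label_seq = jnp.concatenate([null_block, label_seq[num_unconditional:]], axis=0)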