Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions flaxdiff/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def get_dataset_grain(
augmenter = dataset["augmenter"](image_scale, method)

local_batch_size = batch_size // jax.process_count()
model, tokenizer = defaultTextEncodeModel()
# model, tokenizer = defaultTextEncodeModel()

null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
null_labels = np.array(null_labels[0], dtype=np.float16)
null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
# null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
# null_labels = np.array(null_labels[0], dtype=np.float16)
# null_labels_full = np.array(null_labels_full[0], dtype=np.float16)

sampler = pygrain.IndexSampler(
num_records=len(data_source) if count is None else count,
Expand Down Expand Up @@ -80,13 +80,13 @@ def get_trainset():
"train_len": len(data_source),
"local_batch_size": local_batch_size,
"global_batch_size": batch_size,
"null_labels": null_labels,
"null_labels_full": null_labels_full,
"model": model,
"tokenizer": tokenizer,
# "null_labels": null_labels,
# "null_labels_full": null_labels_full,
# "model": model,
# "tokenizer": tokenizer,
}

def generate_collate_fn(tokenizer):
def generate_collate_fn():
auto_tokenize = AutoTextTokenizer(tensor_type="np")
def default_collate(batch):
try:
Expand Down Expand Up @@ -121,11 +121,11 @@ def get_dataset_online(
):
local_batch_size = batch_size // jax.process_count()

model, tokenizer = defaultTextEncodeModel()
# model, tokenizer = defaultTextEncodeModel()

null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
null_labels = np.array(null_labels[0], dtype=np.float16)
null_labels_full = np.array(null_labels_full[0], dtype=np.float16)
# null_labels, null_labels_full = encodePrompts([""], model, tokenizer)
# null_labels = np.array(null_labels[0], dtype=np.float16)
# null_labels_full = np.array(null_labels_full[0], dtype=np.float16)

sources = onlineDatasetMap[data_name]["source"]
dataloader = OnlineStreamingDataLoader(
Expand All @@ -137,7 +137,7 @@ def get_dataset_online(
global_process_count=jax.process_count(),
global_process_index=jax.process_index(),
prefetch=worker_buffer_size,
collate_fn=generate_collate_fn(tokenizer),
collate_fn=generate_collate_fn(),
default_split="train",
)

Expand Down Expand Up @@ -173,8 +173,8 @@ def __next__(self):
"train_len": len(dataloader) * jax.process_count(),
"local_batch_size": local_batch_size,
"global_batch_size": batch_size,
"null_labels": null_labels,
"null_labels_full": null_labels_full,
"model": model,
"tokenizer": tokenizer,
# "null_labels": null_labels,
# "null_labels_full": null_labels_full,
# "model": model,
# "tokenizer": tokenizer,
}
2 changes: 1 addition & 1 deletion flaxdiff/trainer/autoencoder_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from flaxdiff.utils import RandomMarkovState

from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics

from .diffusion_trainer import TrainState
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder

class AutoEncoderTrainer(SimpleTrainer):
Expand Down
17 changes: 11 additions & 6 deletions flaxdiff/trainer/diffusion_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
from flax.training import dynamic_scale as dynamic_scale_lib
from flaxdiff.utils import TextEncoder, ConditioningEncoder

class TrainState(SimpleTrainState):
rngs: jax.random.PRNGKey
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self,
name: str = "Diffusion",
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
autoencoder: AutoEncoder = None,
encoder: ConditioningEncoder = None,
**kwargs
):
super().__init__(
Expand All @@ -64,6 +66,7 @@ def __init__(self,
self.unconditional_prob = unconditional_prob

self.autoencoder = autoencoder
self.encoder = encoder

def generate_states(
self,
Expand Down Expand Up @@ -106,7 +109,7 @@ def generate_states(

return state, best_state

def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
def _define_train_step(self, batch_size):
noise_schedule: NoiseScheduler = self.noise_schedule
model = self.model
model_output_transform = self.model_output_transform
Expand All @@ -115,6 +118,11 @@ def _define_train_step(self, batch_size, null_labels_seq, text_embedder):

# Determine the number of unconditional samples
num_unconditional = int(batch_size * unconditional_prob)

_, null_labels_full = self.encoder([""])
null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

conditioning_encoder = self.encoder.model

nS, nC = null_labels_seq.shape
null_labels_seq = jnp.broadcast_to(
Expand Down Expand Up @@ -146,7 +154,7 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
local_rng_state, rngs = local_rng_state.get_random_key()
images = autoencoder.encode(images, rngs)

output = text_embedder(
output = conditioning_encoder(
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
label_seq = output.last_hidden_state

Expand Down Expand Up @@ -231,8 +239,5 @@ def compute_metrics(state: TrainState, expected, pred):
return compute_metrics

def fit(self, data, steps_per_epoch, epochs):
null_labels_full = data['null_labels_full']
local_batch_size = data['local_batch_size']
text_embedder = data['model']
super().fit(data, steps_per_epoch, epochs, {
"batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
super().fit(data, steps_per_epoch, epochs, {"batch_size": local_batch_size})
38 changes: 28 additions & 10 deletions flaxdiff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import partial
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from abc import ABC, abstractmethod

class MarkovState(struct.PyTreeNode):
pass
Expand Down Expand Up @@ -115,21 +116,38 @@ def _normalize(
mul *= scale
y = mul * x
return jnp.asarray(y, dtype)



@dataclass
class TextEncoder:
class ConditioningEncoder(ABC):
model: nn.Module
tokenizer: Callable

def __call__(self, data):
tokens = self.tokenize(data)
outputs = self.encode_from_tokens(tokens)
return outputs

def encode_from_tokens(self, tokens):
outputs = self.model(input_ids=tokens['input_ids'],
attention_mask=tokens['attention_mask'])
last_hidden_state = outputs.last_hidden_state
return last_hidden_state

def __call__(self, prompts):
# inputs = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="np")
inputs = self.tokenizer(prompts, padding="max_length",
def tokenize(self, data):
tokens = self.tokenizer(data, padding="max_length",
max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
outputs = self.model(input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'])
# outputs = infer(inputs['input_ids'], inputs['attention_mask'])

return tokens

@dataclass
class TextEncoder(ConditioningEncoder):
def __call__(self, data):
tokens = self.tokenize(data)
outputs = self.encode_from_tokens(tokens)
return outputs

def encode_from_tokens(self, tokens):
outputs = self.model(input_ids=tokens['input_ids'],
attention_mask=tokens['attention_mask'])
last_hidden_state = outputs.last_hidden_state
pooler_output = outputs.pooler_output # pooled (EOS token) states
embed_pooled = pooler_output # .astype(jnp.float16)
Expand Down
Loading