Commit 5f954fc

feat: restore weights on CPU
1 parent 4cb21dd commit 5f954fc

2 files changed: +265 additions, -11 deletions

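At a glance, the commit works because Flax checkpoints are msgpack files that deserialize into plain numpy arrays in host memory. Below is a minimal sketch of that mechanism, not the literal commit code (the commit uses flax.serialization.from_bytes, which wraps msgpack_restore); the file path is hypothetical.

import numpy as np
from flax.serialization import msgpack_restore
from flax.traverse_util import flatten_dict, unflatten_dict

# Deserializing a Flax msgpack checkpoint yields nested dicts of np.ndarray,
# i.e. the restored weights live in host (CPU) memory, not on an accelerator.
with open("flax_model.msgpack", "rb") as f:  # hypothetical local checkpoint
    state = msgpack_restore(f.read())

flat = flatten_dict(state)  # tuple keys, easy to compare against the model's params
assert all(isinstance(v, np.ndarray) for v in flat.values())
state = unflatten_dict(flat)  # would then be assigned to model.params, as in the diff below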
src/dalle_mini/model/modeling.py

Lines changed: 255 additions & 3 deletions
@@ -15,16 +15,30 @@
 """ DalleBart model. """

 import math
+import os
 from functools import partial
-from typing import Optional, Tuple
+from pickle import UnpicklingError
+from typing import Optional, Tuple, Union

 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import msgpack.exceptions
 from flax.core.frozen_dict import unfreeze
 from flax.linen import make_causal_mask
-from flax.traverse_util import flatten_dict
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
 from jax.random import PRNGKey
+from transformers.configuration_utils import PretrainedConfig
+from transformers.file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from transformers.modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
@@ -300,7 +314,8 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
     - config_class replaced to DalleBartConfig
     - __init__ accepts abstract_init which does uses parameter shape to initialize the model
-    - init weights on CPU
+    - init weights on CPU with `load_on_cpu`
+    - restore weights on CPU with custom `from_pretrained`
     """

     config_class = DalleBartConfig
@@ -359,6 +374,243 @@ def num_params(self):
         ).values()
         return sum(list(num_params))

+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        dtype: jnp.dtype = jnp.float32,
+        *model_args,
+        **kwargs,
+    ):
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        from_pt = kwargs.pop("from_pt", False)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {
+            "file_type": "model",
+            "framework": "flax",
+            "from_auto_class": from_auto_class,
+        }
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = (
+                config if config is not None else pretrained_model_name_or_path
+            )
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs
+
+        # Add the dtype to model_kwargs
+        model_kwargs["dtype"] = dtype
+
+        # Load model
+        if pretrained_model_name_or_path is not None:
+            if os.path.isdir(pretrained_model_name_or_path):
+                if from_pt and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                ):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, WEIGHTS_NAME
+                    )
+                elif os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+                ):
+                    # Load from a Flax checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
+                    )
+                else:
+                    raise EnvironmentError(
+                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
+                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
+                pretrained_model_name_or_path
+            ):
+                archive_file = pretrained_model_name_or_path
+            else:
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
+                    revision=revision,
+                )
+
+            # redirect to the cache, if necessary
+            try:
+                resolved_archive_file = cached_path(
+                    archive_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
+                    f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+                )
+                raise EnvironmentError(msg)
+
+            if resolved_archive_file == archive_file:
+                logger.info(f"loading weights file {archive_file}")
+            else:
+                logger.info(
+                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
+                )
+        else:
+            resolved_archive_file = None
+
+        # init random models
+        model = cls(config, *model_args, **model_kwargs)
+
+        with open(resolved_archive_file, "rb") as state_f:
+            try:
+                state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                try:
+                    with open(resolved_archive_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please install "
+                                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                                "you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(
+                        f"Unable to convert {archive_file} to Flax deserializable object. "
+                    )
+
+        # if model is base model only use model_prefix key
+        if (
+            cls.base_model_prefix not in dict(model.params)
+            and cls.base_model_prefix in state
+        ):
+            state = state[cls.base_model_prefix]
+
+        # if model is head model and we are loading weights from base model
+        # we initialize new params dict with base_model_prefix
+        if (
+            cls.base_model_prefix in dict(model.params)
+            and cls.base_model_prefix not in state
+        ):
+            state = {cls.base_model_prefix: state}
+
+        # flatten dicts
+        state = flatten_dict(state)
+
+        random_state = flatten_dict(unfreeze(model.params))
+
+        missing_keys = model.required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - model.required_params
+
+        # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append(
+                        (key, state[key].shape, random_state[key].shape)
+                    )
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+                        "model."
+                    )

+        # add missing keys as random parameters
+        for missing_key in missing_keys:
+            state[missing_key] = random_state[missing_key]
+
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(
+                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
+            )
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+
+        # set correct parameters
+        model.params = unflatten_dict(state)
+
+        return model
+

 class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     """

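For context, a hedged usage sketch of the new from_pretrained override: it resolves the checkpoint (local path, URL, or hub id), deserializes it on the host, reconciles missing, unexpected, and mismatched keys against the randomly initialized params, and assigns the result to model.params, so the restored weights stay in CPU memory until they are explicitly placed on devices. The model class name, import path, and checkpoint identifier below are placeholders, not taken from this commit; abstract_init and load_on_cpu are the kwargs described in the class docstring and used by the training script.

import jax
import jax.numpy as jnp

from dalle_mini.model import DalleBart  # assumed import path for the model class

model = DalleBart.from_pretrained(
    "user/dalle-mini-checkpoint",  # hypothetical model id or local directory
    dtype=jnp.float32,
    abstract_init=True,            # init from parameter shapes only (see class docstring)
    load_on_cpu=True,              # keep the randomly initialized params on CPU
)

# Weights are now host-side arrays; place or shard them explicitly when needed, e.g.:
# replicated_params = jax.device_put_replicated(model.params, jax.local_devices())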
tools/train/train.py

Lines changed: 10 additions & 8 deletions
@@ -249,6 +249,9 @@ class TrainingArguments:
             "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
+    gradient_checkpointing: bool = field(
+        default=False, metadata={"help": "Use gradient checkpointing."}
+    )

     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
@@ -515,25 +518,24 @@ def main():
         load_on_cpu=True,
     )

-    # Load tokenizer
-    tokenizer = DalleBartTokenizer.from_pretrained(
-        model_args.tokenizer_name, use_fast=True
-    )
+    # update model config per training args
+    model.config.gradient_checkpointing = training_args.gradient_checkpointing

     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)

     # convert params to frozen dict
     model._params = freeze(model.params)

+    # Load tokenizer
+    tokenizer = DalleBartTokenizer.from_pretrained(
+        model_args.tokenizer_name, use_fast=True
+    )
+
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
-
     dataset.preprocess(tokenizer=tokenizer, config=model.config)

-    # no dropout (hardcoded)
-    model.config.dropout = 0.0
-
     # Initialize our training
     dropout_rng = jax.random.PRNGKey(training_args.seed_model)

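The training script only flips config.gradient_checkpointing; how the flag is consumed lives in the model code and is not part of this diff. For illustration, a hedged sketch of the usual Flax pattern (module names are illustrative, not from this repository): wrap a sub-module in nn.remat so its activations are recomputed during the backward pass, trading compute for memory.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.dim)(nn.relu(nn.Dense(self.dim)(x)))


class Encoder(nn.Module):
    dim: int = 64
    gradient_checkpointing: bool = False  # would mirror config.gradient_checkpointing

    @nn.compact
    def __call__(self, x):
        # nn.remat rematerializes Block's activations in the backward pass
        block_cls = nn.remat(Block) if self.gradient_checkpointing else Block
        for _ in range(4):
            x = block_cls(self.dim)(x)
        return x


params = Encoder(gradient_checkpointing=True).init(
    jax.random.PRNGKey(0), jnp.ones((2, 16, 64))
)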