|
15 | 15 | """ DalleBart model. """ |
16 | 16 |
|
17 | 17 | import math |
| 18 | +import os |
18 | 19 | from functools import partial |
19 | | -from typing import Optional, Tuple |
| 20 | +from pickle import UnpicklingError |
| 21 | +from typing import Optional, Tuple, Union |
20 | 22 |
|
21 | 23 | import flax.linen as nn |
22 | 24 | import jax |
23 | 25 | import jax.numpy as jnp |
| 26 | +import msgpack.exceptions |
24 | 27 | from flax.core.frozen_dict import unfreeze |
25 | 28 | from flax.linen import make_causal_mask |
26 | | -from flax.traverse_util import flatten_dict |
| 29 | +from flax.serialization import from_bytes |
| 30 | +from flax.traverse_util import flatten_dict, unflatten_dict |
| 31 | +from jax import lax |
27 | 32 | from jax.random import PRNGKey |
| 33 | +from transformers.configuration_utils import PretrainedConfig |
| 34 | +from transformers.file_utils import ( |
| 35 | + FLAX_WEIGHTS_NAME, |
| 36 | + WEIGHTS_NAME, |
| 37 | + cached_path, |
| 38 | + hf_bucket_url, |
| 39 | + is_offline_mode, |
| 40 | + is_remote_url, |
| 41 | +) |
28 | 42 | from transformers.modeling_flax_outputs import ( |
29 | 43 | FlaxCausalLMOutputWithCrossAttentions, |
30 | 44 | FlaxSeq2SeqLMOutput, |
@@ -300,7 +314,8 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel): |
300 | 314 | - added num_params property |
301 | 315 | - config_class replaced to DalleBartConfig |
302 | 316 | - __init__ accepts abstract_init which uses parameter shapes to initialize the model |
303 | | - - init weights on CPU |
| 317 | + - init weights on CPU with `load_on_cpu` |
| 318 | + - restore weights on CPU with custom `from_pretrained` |
304 | 319 | """ |
305 | 320 |
|
306 | 321 | config_class = DalleBartConfig |
@@ -359,6 +374,243 @@ def num_params(self): |
359 | 374 | ).values() |
360 | 375 | return sum(list(num_params)) |
361 | 376 |
|
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    dtype: jnp.dtype = jnp.float32,
    *model_args,
    **kwargs,
):
    """Build a model and restore its parameters from a serialized Flax checkpoint.

    Steps performed:

    1. Pop the loading options from ``kwargs`` (``config``, ``cache_dir``,
       ``from_pt``, ``ignore_mismatched_sizes``, ``force_download``,
       ``resume_download``, ``proxies``, ``local_files_only``,
       ``use_auth_token``, ``revision``, ``_from_pipeline``, ``_from_auto``).
    2. Load the configuration through ``cls.config_class.from_pretrained``
       unless ``config`` already is a ``PretrainedConfig``.
    3. Resolve the weights file (local directory, local file / remote URL, or
       model identifier) and pass it through ``cached_path``.
    4. Instantiate ``cls(config, *model_args, **model_kwargs)`` and deserialize
       the checkpoint with ``flax.serialization.from_bytes``.
    5. Reconcile the checkpoint with the model: keys missing from the
       checkpoint take the freshly initialized values, unexpected keys are
       dropped, shape mismatches raise unless ``ignore_mismatched_sizes``.

    The deserialized arrays are assigned to ``model.params`` as returned by
    ``from_bytes``; no further array conversion happens in this method.

    Args:
        pretrained_model_name_or_path: Directory containing the weights file,
            path or remote URL of a weights file, or a model identifier
            resolved with ``hf_bucket_url``.
        dtype: Stored in the model kwargs under ``"dtype"`` and forwarded to
            the model constructor.
        *model_args: Extra positional arguments for the model constructor.
        **kwargs: Loading options listed above; whatever remains is forwarded
            to the config loader and/or the model constructor.

    Returns:
        The model instance with ``params`` set from the checkpoint.

    Raises:
        TypeError: If ``pretrained_model_name_or_path`` is None.
        EnvironmentError: If no weights file is found, it cannot be
            fetched, or it cannot be deserialized.
        OSError: If the file looks like a git-lfs pointer.
        ValueError: If a checkpoint weight has a different shape than the
            model weight and ``ignore_mismatched_sizes`` is False.
    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)

    user_agent = {
        "file_type": "model",
        "framework": "flax",
        "from_auto_class": from_auto_class,
    }
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = (
            config if config is not None else pretrained_model_name_or_path
        )
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Add the dtype to model_kwargs
    model_kwargs["dtype"] = dtype

    # Without a path there is no file to open further down; fail with an
    # explicit message before the model is built.
    if pretrained_model_name_or_path is None:
        raise TypeError(
            "`pretrained_model_name_or_path` must be a model identifier, a directory or a "
            "weights file, not None."
        )

    # Locate the weights file.
    # NOTE(review): `from_pt` only influences which file name is selected here;
    # the file is always parsed with `from_bytes` below - confirm whether
    # PyTorch checkpoints are meant to be supported by this loader.
    if os.path.isdir(pretrained_model_name_or_path):
        pt_weights = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
        flax_weights = os.path.join(
            pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
        )
        if from_pt and os.path.isfile(pt_weights):
            # Load from a PyTorch checkpoint
            archive_file = pt_weights
        elif os.path.isfile(flax_weights):
            # Load from a Flax checkpoint
            archive_file = flax_weights
        else:
            raise EnvironmentError(
                f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                f"{pretrained_model_name_or_path} or `from_pt` set to False"
            )
    elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
        pretrained_model_name_or_path
    ):
        archive_file = pretrained_model_name_or_path
    else:
        archive_file = hf_bucket_url(
            pretrained_model_name_or_path,
            filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
            revision=revision,
        )

    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(
            archive_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            user_agent=user_agent,
        )
    except EnvironmentError as err:
        logger.error(err)
        msg = (
            f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
            f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
            f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
            f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
        )
        # Keep the underlying failure attached as the explicit cause.
        raise EnvironmentError(msg) from err

    if resolved_archive_file == archive_file:
        logger.info(f"loading weights file {archive_file}")
    else:
        logger.info(
            f"loading weights file {archive_file} from cache at {resolved_archive_file}"
        )

    # init random models
    model = cls(config, *model_args, **model_kwargs)

    with open(resolved_archive_file, "rb") as state_f:
        try:
            state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            try:
                # A text file starting with "version" is treated as a git-lfs
                # pointer rather than real weights. The encoding is pinned so
                # the probe does not depend on the locale.
                with open(resolved_archive_file, encoding="utf-8") as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please install "
                            "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                            "you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(
                    f"Unable to convert {archive_file} to Flax deserializable object. "
                )

    model_has_prefix = cls.base_model_prefix in dict(model.params)
    state_has_prefix = cls.base_model_prefix in state

    # if model is base model only use model_prefix key
    if not model_has_prefix and state_has_prefix:
        state = state[cls.base_model_prefix]

    # if model is head model and we are loading weights from base model
    # we initialize new params dict with base_model_prefix
    if model_has_prefix and not state_has_prefix:
        state = {cls.base_model_prefix: state}

    # flatten dicts
    state = flatten_dict(state)
    random_state = flatten_dict(unfreeze(model.params))

    checkpoint_keys = set(state.keys())
    missing_keys = model.required_params - checkpoint_keys
    unexpected_keys = checkpoint_keys - model.required_params

    # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
    # matching the weights in the model.
    mismatched_keys = []
    for key in state:
        if key in random_state and state[key].shape != random_state[key].shape:
            if ignore_mismatched_sizes:
                mismatched_keys.append(
                    (key, state[key].shape, random_state[key].shape)
                )
                # Overwrites an existing key only, so the dict size is stable
                # while iterating.
                state[key] = random_state[key]
            else:
                raise ValueError(
                    f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                    f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                    "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                    "model."
                )

    # add missing keys as random parameters
    for missing_key in missing_keys:
        state[missing_key] = random_state[missing_key]

    # remove unexpected keys to not be saved again
    for unexpected_key in unexpected_keys:
        del state[unexpected_key]

    if unexpected_keys:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )

    if missing_keys:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif not mismatched_keys:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )
    if mismatched_keys:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )

    # set correct parameters
    model.params = unflatten_dict(state)

    return model
| 613 | + |
362 | 614 |
|
363 | 615 | class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule): |
364 | 616 | """ |
|
0 commit comments