Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions paddlenlp/trainer/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import dataclasses
import json
import os
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
from copy import copy
Expand All @@ -36,6 +37,10 @@
get_type_hints,
)

from omegaconf import DictConfig, OmegaConf

from ..utils.log import logger

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)

Expand Down Expand Up @@ -303,6 +308,61 @@
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
types.
"""

def to_regular_dict(obj):
    """Recursively convert ``obj`` into plain Python containers.

    A ``DictConfig`` is first turned into a builtin container through
    ``OmegaConf.to_container`` with ``resolve=True``. Dicts and lists are then
    rebuilt element by element; any other value is returned unchanged.
    """
    plain = OmegaConf.to_container(obj, resolve=True) if isinstance(obj, DictConfig) else obj
    if isinstance(plain, list):
        return list(map(to_regular_dict, plain))
    if not isinstance(plain, dict):
        return plain
    converted = {}
    for key, value in plain.items():
        converted[key] = to_regular_dict(value)
    return converted

Check warning on line 319 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L312-L319

Added lines #L312 - L319 were not covered by tests

def get_resume_checkpoint_path(args):
    """Resolve the checkpoint path to resume from, honoring the ``PDC_INIT_STEP`` env variable.

    Args:
        args: mapping of arguments; ``resume_from_checkpoint`` and ``output_dir``
            are read from it through ``args.get``.

    Returns:
        The user-defined ``resume_from_checkpoint`` (possibly ``None``) when
        ``PDC_INIT_STEP`` is unset or equal to ``"0"``; otherwise
        ``<output_dir>/checkpoint-<PDC_INIT_STEP>``.

    Raises:
        ValueError: if ``PDC_INIT_STEP`` holds a non-zero step but ``output_dir``
            is missing, so the checkpoint directory cannot be built.
    """
    pdc_init_step = os.getenv("PDC_INIT_STEP")
    # user defined resume_from_checkpoint
    user_defined_resume_from_checkpoint = args.get("resume_from_checkpoint", None)

    if pdc_init_step is None:
        # No PDC step in the environment: the user setting is used as is.
        logger.info(f"user has defined resume_from_checkpoint: {user_defined_resume_from_checkpoint}")
        return user_defined_resume_from_checkpoint

    if pdc_init_step == "0":
        # from_scratch train process launched by pdc longjob
        if user_defined_resume_from_checkpoint is None:
            logger.info("resume training process from scratch (step 0)")
            return None
        # Launching the sft_base training process using an initial checkpoint with the starting step set to 0.
        # For instance, resume training from the checkpoint located at './output/eb/checkpoint-init'.
        logger.info(
            f"init_step == 0 and user has defined resume_from_checkpoint: {user_defined_resume_from_checkpoint}"
        )
        return user_defined_resume_from_checkpoint

    # Any other value of PDC_INIT_STEP is treated as a resume step (> 0).
    logger.info(f"resume training process by pdc longjob with resume step: {pdc_init_step}")
    output_dir = args.get("output_dir", None)
    if output_dir is None:
        # os.path.join(None, ...) would fail with an opaque TypeError; report the real cause instead.
        raise ValueError(
            f"PDC_INIT_STEP={pdc_init_step} requires `output_dir` to be set in order to locate the checkpoint to resume from."
        )
    resume_checkpoint = os.path.join(output_dir, f"checkpoint-{pdc_init_step}")
    if user_defined_resume_from_checkpoint is not None:
        # The PDC step takes precedence over the user-defined checkpoint.
        logger.warning(
            f"pdc_init_step:{pdc_init_step} and resume_ckpt:{user_defined_resume_from_checkpoint} exist together, use resume_checkpoint:{resume_checkpoint}"
        )
    return resume_checkpoint

Check warning on line 352 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L352

Added line #L352 was not covered by tests

args["resume_from_checkpoint"] = get_resume_checkpoint_path(args)
args_for_json = to_regular_dict(args)

Check warning on line 355 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L354-L355

Added lines #L354 - L355 were not covered by tests

json_filename = args_for_json.get("args_output_to_local")
if json_filename:
try:
with open(json_filename, "w") as json_file:
json.dump(args_for_json, json_file, indent=4)
except Exception as e:
logger.error(f"Failed to write args output JSON file: {e}")

Check warning on line 363 in paddlenlp/trainer/argparser.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/argparser.py#L357-L363

Added lines #L357 - L363 were not covered by tests
# Optionally handle the error or log it, then continue

outputs = []
for dtype in self.dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
Expand Down
21 changes: 10 additions & 11 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
SAFE_WEIGHTS_INDEX_NAME,
VERA_WEIGHTS_NAME,
)
from ..utils.fault_tolerance import LOSS_INF_ERROR, LOSS_NAN_ERROR
from ..utils.import_utils import is_datasets_available, is_paddle_cuda_available
from ..utils.log import MetricsDumper, logger
from ..utils.tools import get_env_device
Expand Down Expand Up @@ -137,6 +138,7 @@
ShardingOption,
TrainerMemoryTracker,
TrainOutput,
download_recovery_ckpt_from_pdc,
find_batch_size,
get_last_checkpoint,
get_scheduler,
Expand Down Expand Up @@ -201,17 +203,6 @@
return False


try:
from paddle.framework.recall_error import LOSS_NAN_ERROR
except ImportError:
LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"

try:
from paddle.framework.recall_error import LOSS_INF_ERROR
except ImportError:
LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"


__all__ = ["Trainer"]


Expand Down Expand Up @@ -756,6 +747,14 @@
os.makedirs(resume_from_checkpoint, exist_ok=True)
logger.info(f"Reset resume_from_checkpoint to temp directory : {resume_from_checkpoint}")

if resume_from_checkpoint is not None and self.args.pdc_download_ckpt:
if self.is_local_process_zero():
download_recovery_ckpt_from_pdc(resume_from_checkpoint, self.args.pdc_download_timeout)
if self.args.world_size > 1:
logger.info("Wait all processes finish downloading...")
paddle.distributed.barrier()
logger.info("All processes finished downloading from pdc")

Check warning on line 756 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L751-L756

Added lines #L751 - L756 were not covered by tests

train_dataloader = self.get_train_dataloader()

total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.dataset_world_size
Expand Down
43 changes: 43 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.io import IterableDataset
Expand All @@ -44,8 +45,10 @@

from ..trainer.argparser import strtobool
from ..transformers.tokenizer_utils_base import BatchEncoding
from ..utils.fault_tolerance import PDC_DOWNLOAD_ERROR
from ..utils.import_utils import is_paddle_cuda_available, is_psutil_available
from ..utils.log import logger
from ..utils.pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
from .utils.helper import distributed_file

__all__ = [
Expand Down Expand Up @@ -1202,3 +1205,43 @@
else:
parallel_config = set(parallel_config.split(" "))
return parallel_config


def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
    """Download checkpoint from PDC for resuming training after failover. Longjob environment is necessary.

    Args:
        recovery_checkpoint_path (`str`):
            local path to load checkpoint for training recovery
        timeout (`int`):
            max wait time for download

    Raises:
        RuntimeError: if the checkpoint step cannot be parsed from
            ``recovery_checkpoint_path``, or if the download returns a code other
            than ``PDCErrorCode.Success`` / ``PDCErrorCode.LocalPathExist``.
    """

    try:
        base_dir, download_dir = os.path.split(os.path.normpath(recovery_checkpoint_path))
        if not os.path.exists(base_dir) and base_dir != "":
            os.makedirs(base_dir, exist_ok=True)
        download_step = int(_re_checkpoint.search(download_dir).groups()[0])
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}") from e
    start_time = time.time()
    # TODO(@gexiao): temporary workaround for environment variable conflicts.
    original_env = {name: os.getenv(name) for name in ("PADDLE_TRAINER_ID", "PADDLE_TRAINERS_NUM")}
    cards_per_node = int(os.getenv("PADDLE_LOCAL_SIZE", "8"))
    node_rank = str(dist.get_rank() // cards_per_node)
    node_count = str(dist.get_world_size() // cards_per_node)
    try:
        os.environ["PADDLE_TRAINER_ID"] = node_rank
        os.environ["PADDLE_TRAINERS_NUM"] = node_count
        result = pdc_tool.pdc_download_checkpoint(download_step, timeout)
    finally:
        # Always restore the environment, even when the download raises. A variable
        # that was unset before must be removed: assigning None to os.environ raises TypeError.
        for name, value in original_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
    end_time = time.time()
    if result == PDCErrorCode.Success:
        logger.info(f"Successfully downloaded checkpoint from PDC, total time cost: {end_time - start_time} seconds.")
    elif result == PDCErrorCode.LocalPathExist:
        logger.warning(
            f"Skipping download checkpoint since file exists at local, total time cost: {end_time - start_time} seconds."
        )
    else:
        raise RuntimeError(
            f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download checkpoint from PDC, recovery_checkpoint_path: {recovery_checkpoint_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
        )
20 changes: 20 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import paddle.distributed as dist
from paddle.distributed import fleet

from ..utils.fault_tolerance import is_ft_env
from ..utils.log import logger
from .trainer_utils import (
IntervalStrategy,
Expand Down Expand Up @@ -953,6 +954,17 @@
default=False,
metadata={"help": "Offload optimizer after optimizer.step()"},
)
save_sharding_stage1_model_include_freeze_params: Optional[bool] = field(
default=False, metadata={"help": "Save Sharding Stage1 Model Exclude Freeze Params"}
)
pdc_download_ckpt: Optional[bool] = field(
default=False,
metadata={"help": "Download checkpoint in paddlecloud longjob environment"},
)
pdc_download_timeout: Optional[int] = field(
default=300,
metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
)

def __post_init__(self):
if in_auto_parallel_align_mode():
Expand Down Expand Up @@ -1818,6 +1830,14 @@
refined_recompute_dict = dict()
self.refined_recompute = refined_recompute_dict

# process fault tolerance settings
if not is_ft_env():
if self.pdc_download_ckpt:
logger.warning(

Check warning on line 1836 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1836

Added line #L1836 was not covered by tests
"pdc_download_ckpt can only be set as true inside FT environment. Automatically disable it now."
)
self.pdc_download_ckpt = False

Check warning on line 1839 in paddlenlp/trainer/training_args.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/training_args.py#L1839

Added line #L1839 was not covered by tests

def __str__(self):
self_as_dict = asdict(self)
self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
Expand Down
36 changes: 36 additions & 0 deletions paddlenlp/utils/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from tqdm.auto import tqdm

from .env import DOWNLOAD_SERVER, FAILED_STATUS, SUCCESS_STATUS
from .fault_tolerance import PDC_DOWNLOAD_ERROR
from .log import logger
from .pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool

__all__ = ["get_weights_path_from_url"]

Expand Down Expand Up @@ -469,3 +471,37 @@
return True
except EntryNotFoundError:
return False


def download_from_pdc(remote_path, local_path, timeout):
    """Download from remote_path and place to a local_path through PaddleCloud. remote_path has to be uploaded through PaddleCloud as well.

    Args:
        remote_path (`str`):
            remote path url for download
        local_path (`str`):
            local path to place downloaded object
        timeout (`int`):
            max wait time for download

    Raises:
        RuntimeError: if the parent directory of ``local_path`` cannot be prepared,
            or if the download returns a code other than ``PDCErrorCode.Success`` /
            ``PDCErrorCode.LocalPathExist``.
    """

    try:
        # Prepare the parent directory of the local destination. Deriving it from
        # remote_path would create directories that mirror the remote location.
        base_dir, _ = os.path.split(os.path.normpath(local_path))
        if not os.path.exists(base_dir) and base_dir != "":
            os.makedirs(base_dir, exist_ok=True)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}") from e
    start_time = time.time()
    result = pdc_tool.pdc_download(remote_path, local_path, timeout)
    end_time = time.time()
    if result == PDCErrorCode.Success:
        logger.info(f"Successfully downloaded object from PDC, total time cost: {end_time - start_time} seconds.")
    elif result == PDCErrorCode.LocalPathExist:
        logger.warning(
            f"Skipping download object since file exists at local, total time cost: {end_time - start_time} seconds."
        )
    else:
        raise RuntimeError(
            f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download object from PDC, remote_path: {remote_path}, local_path: {local_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
        )
34 changes: 34 additions & 0 deletions paddlenlp/utils/fault_tolerance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

try:
from paddle.framework.recall_error import LOSS_NAN_ERROR
except ImportError:
LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"

Check warning on line 20 in paddlenlp/utils/fault_tolerance.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/utils/fault_tolerance.py#L19-L20

Added lines #L19 - L20 were not covered by tests

try:
from paddle.framework.recall_error import LOSS_INF_ERROR
except ImportError:
LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"

Check warning on line 25 in paddlenlp/utils/fault_tolerance.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/utils/fault_tolerance.py#L24-L25

Added lines #L24 - L25 were not covered by tests

PDC_DOWNLOAD_ERROR = "PaddleRecall error(105): PDCDownloadError"


def is_ft_env():
    """Return True when the ``PDC_LONGJOB_ID`` environment variable is set.

    Only the presence of the variable is checked; an empty value still counts.
    """
    longjob_id = os.environ.get("PDC_LONGJOB_ID")
    return longjob_id is not None
Loading