diff --git a/apps/pLM/__init__.py b/apps/pLM/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/pLM/configs/debug.yaml b/apps/pLM/configs/debug.yaml new file mode 100644 index 00000000..2f352e3b --- /dev/null +++ b/apps/pLM/configs/debug.yaml @@ -0,0 +1,52 @@ + +name: "debug-plm" +steps: 1000 +probe_freq: 100 +seed: 777 +optim: + lr: 3e-4 + warmup: 2000 + lr_min_ratio: 0.000001 + clip: 10.0 + +distributed: + fsdp_type: full_shard + compile: true + model_dtype: bf16 + matmul_allow_tf32: false + selective_activation_checkpointing: false + tp_size: 1 + +model: + dim: 1024 + n_layers: 8 + n_heads: 8 + vocab_size: 24 + +data: + root_dir: /lus/eagle/projects/FoundEpidem/hippekp/genslm-foundation/data/ncbi/refseq.parsed/faa-jsonl + sources: + refseq: 1.0 + batch_size: 32 + prefetch_size: 64 + seq_len: 2048 + n_views: 2 + load_async: true + add_bos: true + add_eos: true + tokenizer: + name: aa + +profiling: + run: true + +checkpoint: + dump: + every: 100 + keep: 1 + eval: + every: 100 + keep: 1 + +logging: + freq: 10 diff --git a/apps/pLM/eval.py b/apps/pLM/eval.py new file mode 100644 index 00000000..59794ddf --- /dev/null +++ b/apps/pLM/eval.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
@dataclass
class LMHarnessArgs:
    """Keyword arguments forwarded to lm-eval's ``simple_evaluate``.

    ``launch_eval`` expands this dataclass with ``asdict`` and passes every
    field as a keyword argument (``simple_evaluate(wrap, **asdict(cfg.harness))``),
    so each field name has to be a parameter accepted by ``simple_evaluate``.
    The meaning of the individual options is defined by lm-eval, not by this
    file; see the lm-eval documentation before changing a default.
    """

    tasks: Optional[List[Any]] = None
    num_fewshot: Optional[int] = None
    device: Optional[str] = None
    use_cache: Optional[str] = None
    cache_requests: bool = False
    rewrite_requests_cache: bool = False
    delete_requests_cache: bool = False
    limit: Optional[Union[int, float]] = None
    bootstrap_iters: int = 100000
    check_integrity: bool = False
    write_out: bool = False
    log_samples: bool = True
    system_instruction: Optional[str] = None
    apply_chat_template: Union[bool, str] = False
    fewshot_as_multiturn: bool = False
    gen_kwargs: Optional[str] = None
    verbosity: str = "INFO"
    predict_only: bool = False
    # Seeds below are handed to lm-eval unchanged; presumably one per RNG
    # source as the names suggest -- TODO confirm against lm-eval.
    random_seed: int = 0
    numpy_random_seed: int = 1234
    torch_random_seed: int = 1234
    fewshot_random_seed: int = 1234
def all_dicts_same(dict_list):
    """Return True when every entry of ``dict_list`` equals the first entry.

    An empty sequence is considered uniform and yields True.
    """
    if dict_list:
        reference = dict_list[0]
        for candidate in dict_list:
            if not (candidate == reference):
                return False
    return True


class MockAccelerator:
    """Small object exposing ``gather`` and ``wait_for_everyone`` on top of
    ``torch.distributed`` collectives."""

    def gather(self, tensor):
        """All-gather ``tensor`` from every rank and stack the results along
        a new leading dimension."""
        buffers = []
        for _ in range(get_world_size()):
            buffers.append(torch.zeros_like(tensor))
        torch.distributed.all_gather(buffers, tensor)
        return torch.stack(buffers)

    def wait_for_everyone(self):
        """Block until all ranks reach this point."""
        torch.distributed.barrier()
def launch_eval(cfg: EvalArgs):
    """Run the lm-eval harness on a checkpoint and write the results to disk.

    Steps, in order:
      1. Initialise ``torch.distributed`` if it is not initialised yet.
      2. Resolve the consolidated checkpoint directory.  ``cfg.ckpt_dir`` is
         used directly when it already contains ``params.json`` and at least
         one ``*.pth`` file; otherwise the ``CONSOLIDATE_FOLDER`` sub-directory
         is used, and rank 0 creates it with ``consolidate_checkpoints`` when
         it does not exist.
      3. Dump the evaluation config to ``<dump_dir>/config.yaml``.
      4. Load the model and tokenizer, wrap them in a generator and run
         ``simple_evaluate`` with the options from ``cfg.harness``.
      5. On rank 0, write ``results.json`` and, when ``cfg.metric_log_dir`` is
         set, append one JSON line to ``metrics.eval.jsonl``.

    Args:
        cfg: evaluation settings; ``cfg.dump_dir`` must be set because it is
            used unconditionally as the output directory.
    """
    if not torch.distributed.is_initialized():
        setup_torch_distributed(DistributedArgs())
    if (
        Path(cfg.ckpt_dir).exists()
        and (Path(cfg.ckpt_dir) / "params.json").exists()
        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
    ):
        consolidate_path = Path(cfg.ckpt_dir)
    else:
        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
        # Only rank 0 consolidates; the other ranks wait at the barrier below.
        if not consolidate_path.exists() and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)

    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)

    consolidate_path = str(consolidate_path)
    torch.distributed.barrier()
    logger.info("Loading model")
    model, tokenizer = load_consolidated_model_and_tokenizer(
        consolidate_path,
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)

    wrap = EvalHarnessLM(generator)
    results = simple_evaluate(wrap, **asdict(cfg.harness))
    if get_global_rank() == 0:
        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
            f.write(json.dumps(results))
        logger.info(f"All evaluation results: {results['results']}")
    if cfg.metric_log_dir and get_global_rank() == 0:
        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"

        logger.info(f"Writing metric logs to {metric_log_path}")
        timestamp = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if cfg.global_step is not None:
            timestamp["global_step"] = cfg.global_step
        # Use a context manager so the handle is flushed and closed
        # deterministically instead of being left to the garbage collector.
        with open(metric_log_path, mode="a") as metric_file:
            print(
                json.dumps(timestamp | results["results"]),
                file=metric_file,
                flush=True,
            )
    del generator
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample one token id per row with nucleus (top-p) filtering.

    Args:
        probs: probabilities, sorted and sampled along the last dimension.
        p: cumulative-probability threshold.

    Returns:
        Tensor of sampled indices into the last dimension of ``probs``, with a
        trailing dimension of size 1.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # A candidate is dropped when the mass of the candidates ranked before it
    # already exceeds p, so the highest-probability candidate is always kept.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # torch.multinomial accepts unnormalised weights, no rescaling is needed.
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map positions in the sorted order back to the original indices.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def sample_top_k(probs, k):
    """Sample one token id per row among the ``k`` most likely candidates.

    Unlike the previous implementation this does not modify ``probs``: the
    filtered, renormalised distribution is built in a new tensor.

    Args:
        probs: 2-D tensor of probabilities (one row per sequence).
        k: number of candidates to keep.  Coerced with ``int`` because the
            generator arguments declare ``top_k`` as a float while
            ``torch.topk`` requires an integer.

    Returns:
        Tensor with one sampled index per row (trailing dimension of size 1).
    """
    k = int(k)
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    # Entries tied with the k-th value are kept, as before.
    filtered = probs.masked_fill(probs < min_value_top_k, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(filtered, num_samples=1)
    return next_token
def pack_prompts(prompts: List[List[int]]):
    """Concatenate tokenized prompts into one flat tensor.

    Args:
        prompts: one list of token ids per prompt.

    Returns:
        ``(packed, lengths)``: ``packed`` is a 1-D ``torch.long`` tensor with
        all prompts back to back, ``lengths`` is a 1-D ``torch.long`` tensor
        with the number of tokens of each prompt, in the same order.
    """
    tensors = [torch.tensor(p, dtype=torch.long) for p in prompts]
    lengths = torch.tensor([t.size(0) for t in tensors], dtype=torch.long)
    return torch.cat(tensors), lengths


def batch_prompts(prompts, max_elements, lengths=None):
    """Greedily group consecutive prompts into batches of bounded size.

    Args:
        prompts: sequence of prompts (order is preserved).
        max_elements: maximum summed size of the prompts in one batch.
        lengths: optional per-prompt sizes to use instead of ``len(prompt)``.

    Returns:
        List of batches, each a list of prompts.  A prompt whose own size
        exceeds ``max_elements`` is not dropped; it ends up alone in a batch.
    """
    batches = []
    current_batch = []
    current_count = 0

    for i, prt in enumerate(prompts):
        prompt_size = len(prt) if lengths is None else lengths[i]
        if current_count + prompt_size <= max_elements:
            current_batch.append(prt)
            current_count += prompt_size
        else:
            if current_batch:  # Add the current batch to batches
                batches.append(current_batch)
            # Start a new batch with the current prompt
            current_batch = [prt]
            current_count = prompt_size

    # Add the last batch if it contains any prompts
    if current_batch:
        batches.append(current_batch)

    return batches
@dataclass
class PackedCausalTransformerGeneratorArgs:
    """Settings consumed by ``PackedCausalTransformerGenerator``."""

    # Sampling temperature; sample_tokens falls back to argmax when it is not > 0.
    temperature: float = 0.0
    # Nucleus threshold; when set it takes precedence over top_k in sample_tokens.
    top_p: Optional[float] = None
    # Only used when top_p is None.  NOTE(review): the value reaches
    # torch.topk, which expects an integer k, yet it is declared float here
    # -- confirm the intended type.
    top_k: Optional[float] = None
    max_gen_len: int = 512  # Maximum number of tokens to generate
    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
    # Prompts are truncated to their last max_prompt_len tokens; when unset the
    # generator derives a limit from max_tokens / model.max_seqlen and max_gen_len.
    max_prompt_len: Optional[int] = None
    # Stop strings: a sequence is marked done once its decoded tail contains one.
    until: List[str] = field(default_factory=list)
    # Enables torch.compile on the generator's prefill method.
    compile_prefilling: bool = False
    # Enables torch.compile(mode="reduce-overhead") on generate_next_token.
    reduce_generation_overhead: bool = False
    # Wraps the batch loop of generate() in tqdm.
    show_progress: bool = False
    # Looked up in a {"fp32", "bf16"} table by the generator; used as KV cache dtype.
    dtype: Optional[str] = "bf16"
    # Device for the KV cache buffers.  generate() itself moves inputs with .cuda().
    device: Optional[str] = "cuda"
def clear_cache(self, offset):
    """Make sure every ``Attention`` module has a KV cache and set its offset.

    A cache is allocated lazily (batch size 1, ``self.max_tokens`` positions)
    the first time an attention module is seen without one.  Existing cache
    contents are left untouched; only ``offset`` is overwritten.
    """
    for layer in self.model.modules():
        if not isinstance(layer, Attention):
            continue
        if not hasattr(layer, "kv_cache"):
            layer.kv_cache = KVCache(
                bsz=1,
                seqlen=self.max_tokens,
                n_heads=layer.n_heads,
                head_dim=layer.head_dim,
                dtype=self.dtype,
                device=self.device,
            )
        layer.kv_cache.offset = offset
+ + # We will generate max_gen_len for each document + padded_lengths = lengths + self.max_gen_len + max_tokens = self.max_tokens or padded_lengths.sum().item() + # The last document might have more padding to fill up to max_tokens + padded_lengths[-1] += max_tokens - padded_lengths.sum() + + # This is the start index in the cache for each document + self.padded_doc_start = lengths_to_start_ids(padded_lengths) + # For example with ab--123--cdef-- + # this would be 0, 4, 9 if max_gen_len is 2 + + # We repeat interleave to align with tokens for prefilling + # Ex: ab--123--cdef-- + # 000044444999999 + prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths) + # This offset will make sure the tokens are written to the + # correct positions in the cache during prefilling + + # We either init the cache or clear it by resetting the offset to prefill_offset + self.clear_cache(prefill_offset) + + # The prefilling mask looks like the following for + # the two packed sequences ab and 123 : ab123 + # Where spaces are empty cache positions + # keys + # ab---123--- + # queries a 10000000000 + # b 11000000000 + # 1 00000100000 + # 2 00000110000 + # 3 00000111000 + # We make sure to skip the empty cache positions + # and only attend to positions within the same sequence + doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths) + self.prefill_mask = create_block_mask( + doc_mask_mod, 1, None, lengths.sum(), max_tokens + ) + + # This creates the prefilling token ids which look like + # the following for the packed sequence abcdefg1234 + # abcdefg1234 + # 01234560123 + # The token id gives us the position within each sequence + # This is used to compute ROPE and to update the cache + # At each forward pass the current tokens are written to + # offset + tok_id + self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths) + + # This creates the padded token and document ids + # which look like the following for the packed sequence ab123 + 
# ab---123--- ab---123--- + # padded_doc_id 00000111111 padded_tok_id 01234012345 + # This will later be useful for the attention mask at generation + self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths) + + @torch.compiler.disable + def setup_generation(self, lengths): + # KV Cache offset is set to the start of the padded documents + for module in self.model.modules(): + if isinstance(module, Attention): + module.kv_cache.offset = self.padded_doc_start + # The token ids during generations correspond to the lengths of each doc + # current_tok_id will be incremented during generation + self.current_tok_id = lengths.clone() + # Since we're generating one token per document + # the document id is just an arange + self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device) + + # From here on some methods for generation + def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor): + # Prefilling is done by taking multiple packed sequences and + # doing block diagonal attention on them so they remain independent + self.setup_prefilling(lengths=lengths) + prefill_out = self.model.forward( + tokens, + tok_idx=self.prefill_tok_id, + mask=self.prefill_mask, + attn_impl="flex_attention", + ) + self.setup_generation(lengths=lengths) + return prefill_out + + def generate_next_token(self, current_token): + # Since we're doing generation with multiple sequences at once + # we need to ignore tokens and cache entries from other sequences + # or in the future. 
+ # Example mask : + # keys + # abc--1234-- + # queries c 11100000000 + # 4 00000111100 + + # mask shape : (n_seqs, cache_size) + doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0) + caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0) + mask = doc_mask & caus_mask + out = self.model.forward( + current_token, + tok_idx=self.current_tok_id, # n_seqs + mask=mask, + attn_impl="sdpa", + ) + self.current_tok_id += 1 + return out + + @torch.inference_mode() + def generate(self, prompts): + # Tokenize + prompts = [ + self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts + ] + # Truncate + max_seqlen = ( + self.max_tokens + if not hasattr(self.model, "max_seqlen") + else self.model.max_seqlen + ) + max_prompt_len = self.max_prompt_len or min( + max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len + ) + prompts = [p[-max_prompt_len:] for p in prompts] + # Account for the generation in lengths + padded_lengths = [len(p) + self.max_gen_len for p in prompts] + generation = [] + loglikelihood = [] + greedy = [] + it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths) + if self.show_progress: + it = tqdm(it) + for batch in it: + n_seqs = len(batch) + generated_tokens = [[] for _ in range(n_seqs)] + is_done = [False for _ in range(n_seqs)] + packed_batch, lengths = pack_prompts(batch) + packed_batch, lengths = packed_batch.cuda(), lengths.cuda() + n_seqs = lengths.size(0) + + # Prefilling cache + prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths) + # Selecting last token in each prompt + all_tokens = sample_tokens( + prompt_logits, self.temperature, self.top_p, self.top_k + ) + start_token = all_tokens[:, lengths.cumsum(0) - 1] + + for seq_id, tok in enumerate(start_token.squeeze(0).tolist()): + generated_tokens[seq_id].append(tok) + + current_token = start_token + for i in range(1, self.max_gen_len): + + next_logits = self.generate_next_token(current_token) + 
def load_consolidated_model_and_tokenizer(
    consolidated_path,
    model_cls=LMTransformer,
    model_args_cls=LMTransformerArgs,
):
    """Build a model and tokenizer from a consolidated checkpoint directory.

    The directory must contain ``params.json`` (the training config) and the
    consolidated weights file ``CONSOLIDATE_NAME``.  The model is moved to
    CUDA, put in eval mode, and its parameters are cast to the dtype named by
    ``distributed.model_dtype`` in the config.

    Returns:
        ``(model, tokenizer)``
    """
    root = Path(consolidated_path)
    config = OmegaConf.load(root / "params.json")

    dtype_by_name = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    param_dtype = dtype_by_name[config.distributed.model_dtype]

    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
    tokenizer_cfg = config.data.tokenizer
    tokenizer = build_tokenizer(tokenizer_cfg.name, tokenizer_cfg.path)

    model = model_cls(model_args)
    checkpoint = torch.load(root / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(checkpoint["model"])
    model = model.cuda().eval()
    # Cast after loading so the stored weights end up in the training dtype.
    for weight in model.parameters():
        weight.data = weight.data.to(dtype=param_dtype)
    return model, tokenizer
load_consolidated_model_and_tokenizer(cfg.ckpt) + + generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer) + + # Allow multiple prompts + prompts = [] + while True: + prompt = input("Enter a prompt (or press enter to finish): ") + if not prompt: + break + prompts.append(prompt) + + # Start generation + start_time = time.time() + generation, loglikelihood, greedy = generator.generate(prompts) + end_time = time.time() + + # Calculate tokens per second + total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation) + tokens_per_second = total_tokens / (end_time - start_time) + + # Display the results + for i, gen in enumerate(generation): + print(f"\nPrompt {i+1}: {prompts[i]}") + print(f"Generated Text: {gen}") + + print(f"\nTokens per second: {tokens_per_second:.2f}") + + +if __name__ == "__main__": + main() diff --git a/apps/pLM/train.py b/apps/pLM/train.py new file mode 100644 index 00000000..b6720bd9 --- /dev/null +++ b/apps/pLM/train.py @@ -0,0 +1,654 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +from copy import deepcopy +import gc +import logging +import os +import sys +import time +from contextlib import ExitStack +from dataclasses import asdict, dataclass, field +from pathlib import Path +from timeit import default_timer as timer +from typing import Any, Dict, List, Optional + +import numpy as np +from omegaconf import OmegaConf +import torch +import torch.distributed +import torch.nn.functional as F +import xformers.profiler +from torch.optim import lr_scheduler +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed._tensor import DTensor + +from lingua.args import dataclass_from_dict, dump_config, flatten_dict +from lingua.checkpoint import CheckpointArgs, CheckpointManager +from lingua.data import ( + DataArgs, + PackTokensState, + build_dataloader_from_args, + init_dataloader_state_from_args, +) +from lingua.distributed import ( + DistributedArgs, + EnvironmentArgs, + init_signal_handler, + dist_mean_dict, + get_device_mesh, + get_is_master, + get_world_size, + parallelize_model, + setup_env, + setup_torch_distributed, + clean_env, + requeue_slurm_job, + check_model_value_range, +) +from lingua.logger import init_logger +from lingua.metrics import ( + GPUMemoryMonitor, + LoggingArgs, + MetricLogger, + get_num_params, +) +from lingua.optim import OptimArgs, build_optimizer +from lingua.profiling import ProfilerArgs, maybe_run_profiler +from lingua.tokenizer import build_tokenizer +from apps.pLM.transformer import ( + LMTransformerArgs, + LMTransformer, + get_num_flop_per_token, + build_fsdp_grouping_plan, + tp_parallelize, + get_no_recompute_ops, +) +from lingua.probe import AutoProbeD +from lingua.stool import StoolArgs, launch_job + +import wandb + +logger = logging.getLogger() + + +@dataclass +class TrainArgs: + name: str = "lingua" + dump_dir: str = "" + + seed: int = 42 + + # Number of gradient accumulation steps + # Total batch size is batch_size*grad_acc_steps + grad_acc_steps: int = 1 + + gc_collect_freq: int = 
def validate_train_args(args: TrainArgs, output_size: int):
    """Check the training configuration and fill in derived values in place.

    Side effects on ``args``:
      * ``model.vocab_size`` is set to ``output_size`` when it is negative.
      * ``checkpoint.path`` defaults to ``<dump_dir>/checkpoints``.
      * ``distributed.dp_replicate`` is recomputed when the configured
        parallel sizes do not multiply to the world size.
      * ``model.max_seqlen`` is set to ``data.seq_len``.
      * ``logging.wandb.name`` is set to ``args.name`` when wandb is configured.

    Args:
        args: training configuration, modified in place.
        output_size: vocabulary size reported by the tokenizer.

    Raises:
        AssertionError: when a check fails (vocab size mismatch, missing dump
            dir or data source, inconsistent parallel sizes, probing combined
            with unsupported options).
    """
    if args.model.vocab_size < 0:
        # Log the value that is actually being assigned, not the old placeholder.
        logger.info(f"Setting model output size to {output_size}")
        args.model.vocab_size = output_size
    assert (
        args.model.vocab_size == output_size
    ), "Vocab size should be the same as output size"

    assert args.dump_dir, "Dump dir not set"

    if args.checkpoint.path is None:
        # Assign first so the log line shows the new path instead of None.
        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
        logger.info(f"Setting checkpoint path to {args.checkpoint.path}")

    for source in args.data.sources:
        data_path = os.path.join(args.data.root_dir, source)
        assert os.path.exists(data_path), f"{data_path} doesn't exist"

    if (
        args.distributed.dp_replicate
        * args.distributed.dp_shard
        * args.distributed.tp_size
        != get_world_size()
    ):
        # Derive dp_replicate from the world size, keeping dp_shard and tp_size.
        assert get_world_size() % args.distributed.dp_shard == 0
        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
        args.distributed.dp_replicate = (
            args.distributed.dp_replicate // args.distributed.tp_size
        )

        logger.warning(
            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
        )
        assert (
            args.distributed.dp_replicate
            * args.distributed.dp_shard
            * args.distributed.tp_size
            == get_world_size()
        )

        if args.distributed.fsdp_type == "no_shard":
            assert (
                args.distributed.dp_shard == 1
                and args.distributed.dp_replicate == get_world_size()
            )

    args.model.max_seqlen = args.data.seq_len

    # The warning concerns runs that actually use tensor parallelism; the
    # previous `== 1` test emitted it only when tensor parallelism was off.
    if args.distributed.tp_size != 1:
        logger.warning(
            "Tensor parallelism has not been tested for a while, use at your own risk"
        )

    assert (
        args.probe_freq != args.profiling.mem_steps
    ), "Don't profile during probe step"
    assert (
        args.probe_freq != args.profiling.profile_steps
    ), "Don't profile during probe step"
    if args.logging.wandb is not None:
        args.logging.wandb.name = args.name

    if args.probe_freq is not None:
        assert (
            args.distributed.tp_size == 1
        ), "Probing not supported with tensor parallelism"
        assert (
            args.distributed.selective_activation_checkpointing is False
        ), "Probing not supported with selective activation checkpointing"
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
    """Tell whether the current optimizer step is a multiple of ``freq``.

    The answer can be narrowed by the gradient-accumulation position:
    ``acc_step`` requires ``train_state.acc_step`` to equal it, otherwise
    ``acc_freq`` requires ``train_state.acc_step`` to be a multiple of it.
    ``acc_step`` wins when both are given.
    """
    on_step = train_state.step % freq == 0
    if acc_step is not None:
        return on_step and (train_state.acc_step == acc_step)
    if acc_freq is not None:
        return on_step and ((train_state.acc_step % acc_freq) == 0)
    return on_step
actually initialize the model + # First we create empty tensors of the correct shapes + model = model.to_empty(device="cuda") + # Then we init the model. Please make sure this function initializes *ALL* parameters + # and buffers, otherwise you will have random values in the unitialized tensors + # which will silently fail (give nan gradients for example) + + if args.checkpoint.init_ckpt_path: + st_dict = torch.load(args.checkpoint.init_ckpt_path) + model.load_state_dict(st_dict) + else: + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + torch.manual_seed(args.model.seed) + model.init_weights() + check_model_value_range(model, range=10.0, std=1.0) + + # log model size + + logger.info(f"Model size: {model_param_count:,} total parameters") + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info( + f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) " + f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory" + ) + logger.info(f"GPU memory usage: {gpu_memory_monitor}") + + # build optimizer after apply parallelisms to the model + optimizer, scheduler = build_optimizer(model, args.optim, args.steps) + data_loader_state = init_dataloader_state_from_args( + args.data, dp_rank, dp_degree + ) + + train_state = TrainState( + step=0, + acc_step=0, + data_loader_state=data_loader_state, + scheduler=scheduler, + ) + + checkpoint = CheckpointManager(args.checkpoint) + checkpoint.load(model, optimizer, train_state, world_mesh) + # Either load from latest checkpoint or start from scratch + if args.probe_freq is not None: + if get_is_master(): + os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True) + torch.distributed.barrier() + probe = AutoProbeD( + model, + ( + Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl" + if (dp_rank % 128 == 0) + else None + ), + ) + probe_mod = model._orig_mod if args.distributed.compile else model + + gc.disable() + + # train loop + model.train() + metric_logger = 
context_stack.enter_context( + MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args) + ) + data_loader = context_stack.enter_context( + build_dataloader_from_args( + args.data, + state=train_state.data_loader_state, + ) + ) + torch_profiler = context_stack.enter_context( + maybe_run_profiler(args.dump_dir, model, args.profiling) + ) + + nwords_since_last_log = 0 + time_last_log = timer() + gc.collect() + while train_state.step < args.steps: + # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1 + train_state.acc_step += 1 + train_state.acc_step = train_state.acc_step % args.grad_acc_steps + + # get batch + curr_lr = float(optimizer.param_groups[0]["lr"]) + data_load_start = timer() + batch, train_state.data_loader_state = next(data_loader) + batch = torch.tensor( + batch, + dtype=torch.long, + ) + + if every_n_steps(train_state, args.gc_collect_freq, acc_step=0): + logger.info("garbage collection") + # we do garbage collection manually otherwise different processes + # run the GC at different times so they slow down the whole pipeline + gc.collect() + + input_ids = batch[:, :, 0].cuda() + labels = batch[:, :, 1].cuda() + data_load_time = round(timer() - data_load_start, 4) + nwords_since_last_log += input_ids.numel() + + bsz, seqlen = labels.shape + + # forward + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + start_timer.record() + + # This is an automatic probe that will compute statistics + # of all linears' inputs, weights and outputs + # along with attention logits and entropy + # both in forward and backward pass + if (args.probe_freq is not None) and every_n_steps( + train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps + ): + # Here we do a fake forward and backward pass on a smaller + # batch size to avoid OOM + # This assumes the model has no stateful layers (batch norm..) 
+ assert ( + next(probe_mod.parameters()).grad is None + ), "Can't probe model if grads are not reset" + + with probe: + probe.metadata = { + "it": train_state.step, + "global_step": train_state.step, + "loop": "lingua", + } + # Non compiled model uses roughly 2x memory in our exps + # So we divide bsz by 2 or seqlen by 2 + probe_bsz = max(1, bsz // 2) + probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2) + probe_loss = probe_mod( + input_ids[:probe_bsz, :probe_seq], + labels[:probe_bsz, :probe_seq], + ) + probe_loss.backward() + # We zero grads to cancel this fake step + optimizer.zero_grad() + + assert ( + next(probe_mod.parameters()).grad is None + ), "Probe model shouldn't have grads at this point" + + loss = model(input_ids, labels) + + # We scale loss with grad_acc_steps so the gradient is the same + # regardless of grad_acc_steps + loss = loss / args.grad_acc_steps + # backward on scaled loss to create scaled gradients + loss.backward() + # For logging we undo that scaling + loss = loss.detach() * args.grad_acc_steps + + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=args.optim.clip, foreach=True + ) + + grad_norm = ( + grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm + ).item() + + # optimizer step + if train_state.acc_step == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + train_state.step += 1 + + # updates the scale for next iteration + # training iteration complete + end_timer.record() + + torch.cuda.synchronize() + + curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4) + + # if profiler is active + if torch_profiler: + xformers.profiler.step() + + # log metrics + if every_n_steps( + train_state, + args.logging.freq, + acc_step=None if args.logging.acc_freq else 0, + acc_freq=args.logging.acc_freq, + ): + time_delta = timer() - time_last_log + wps = nwords_since_last_log / (time_delta * args.distributed.tp_size) + + gpu_mem_stats = 
gpu_memory_monitor.get_peak_stats() + + total_acc_steps = ( + args.grad_acc_steps * train_state.step + train_state.acc_step + ) + tokens_per_gpu = ( + total_acc_steps * args.data.batch_size * args.data.seq_len + ) + total_tokens = dp_degree * tokens_per_gpu + # This is an estimate and the correct values may change + # if you change the architecture + # Use xformer's analyze profile trace to get actual measurement + FLOPS = ( + get_num_flop_per_token( + model_param_count - args.model.vocab_size * args.model.dim, + args.model.n_layers, + args.model.dim, + args.data.seq_len, + ) + * wps + ) + metrics = flatten_dict( + { + "global_step": train_state.step, + "acc_step": train_state.acc_step, + "speed": { + "wps": wps, + "FLOPS": FLOPS, + "curr_iter_time": curr_iter_time, + "data_load_time": data_load_time, + }, + "optim": { + "grad_norm": grad_norm, + "lr": curr_lr, + "total_tokens": total_tokens, + }, + "memory": gpu_mem_stats._asdict(), + }, + sep="/", + ) + + to_sync = {} + to_sync["loss/out"] = loss.item() + metrics.update(dist_mean_dict(to_sync)) + + if get_is_master(): + metric_logger.log(metrics) + + gpu_memory_monitor.reset_peak_stats() + nwords_since_last_log = 0 + time_last_log = timer() + if get_is_master(): + logger.info( + f"step: {train_state.step}" + f" acc: {train_state.acc_step}" + f" loss: {round(loss.item(),4):>7}" + f" grad: {grad_norm:.2e}" + f" flops: {FLOPS:.2e}" + f" wps: {wps:.2e}" + f" iter: {curr_iter_time:>7}" + f" data: {data_load_time:>5}" + f" lr: {curr_lr:.2e}" + f" mem: {gpu_mem_stats.max_active_pct:.0f}%" + f" pow: {gpu_mem_stats.power_draw/1000} W" + ) + + saved = False + if every_n_steps( + train_state, args.checkpoint.dump.every, acc_step=0 + ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): + saved = checkpoint.save( + model, + optimizer, + train_state, + args, + device_mesh=world_mesh, + ) + + if args.eval is not None and every_n_steps( + train_state, args.checkpoint.eval.every, acc_step=0 + ): + from 
def main():
    """
    Command-line entry point: build a TrainArgs object and start training.

    Arguments are parsed with OmegaConf
    (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments)
    and are given as a dot list. With dataclasses such as

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgs

    @dataclass
    class LMTransformerArgs:
        dim: int

    passing model.dim=32 changes the nested value and name=tictac changes a
    top-level attribute.

    The final configuration is layered, later layers overriding earlier ones:
    1. the default values of TrainArgs,
    2. the values found in the file given as config=<path>,
    3. any remaining command-line arguments.

    For example, with a config file containing

    model:
        dim: 128
        n_layers: 4

    running train.py with model.dim=64 yields

    model:
        dim: 64
        n_layers: 4

    plus every default declared in the TrainArgs dataclass.
    """
    overrides = OmegaConf.from_cli()
    from_file = OmegaConf.load(overrides.config)
    # 'config' only points at the file; TrainArgs has no such field, so it
    # must be dropped before merging.
    del overrides.config

    defaults = OmegaConf.structured(TrainArgs())
    merged = OmegaConf.merge(defaults, from_file, overrides)

    train(OmegaConf.to_object(merged))
class LMTransformer(BaseTransformer):
    """Language-model head on top of BaseTransformer.

    Adds a token embedding table, a final RMSNorm and a bias-free output
    projection to the vocabulary. When called with targets it returns the
    cross-entropy loss instead of the logits.
    """

    def __init__(self, args: LMTransformerArgs):
        """Build the embedding, final norm and output projection.

        Args:
            args: model hyper-parameters; `vocab_size` must be positive.
                `weight_tying` makes the output projection share its weight
                with the token embedding.
        """
        super().__init__(args)
        self.weight_tying = args.weight_tying
        self.sliding_window = args.sliding_window

        assert args.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.output = nn.Linear(
            args.dim,
            args.vocab_size,
            bias=False,
        )

        if args.weight_tying:
            # Share the embedding matrix with the output projection. Both are
            # (vocab_size, dim), and the embedding lives directly on this
            # module as `tok_embeddings`.
            self.output.weight = self.tok_embeddings.weight

        # NOTE(review): init_weights is not defined in this class; presumably
        # provided by BaseTransformer - confirm.
        self.init_weights()

    def forward(
        self,
        token_values: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        attn_impl: str = "sdpa",
    ):
        """Run the model on a batch of token ids.

        Args:
            token_values: 2-D tensor of token ids (batch, sequence).
            target: optional targets; when given the loss is returned.
            tok_idx: forwarded unchanged to BaseTransformer.forward.
            mask: attention mask; when None a causal mask matching
                `attn_impl` and the configured sliding window is created.
            attn_impl: attention implementation name, forwarded to the base
                class and used to pick the default mask type.

        Returns:
            `cross_entropy(logits, target)` when `target` is given, otherwise
            the logits produced by the output projection.
        """
        _, seqlen = token_values.shape

        h = self.tok_embeddings(token_values)

        if mask is None:
            mask = create_causal_mask(seqlen, attn_impl, self.sliding_window)

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)

        logits = self.output(self.norm(h))
        if target is not None:
            return cross_entropy(logits, target)
        return logits

    def reset_parameters(self, init_std=None):
        """Re-initialise the parameters owned by this class.

        Args:
            init_std: standard deviation of the truncated normal used for the
                embedding and output weights; defaults to `dim ** -0.5`.
        """
        super().reset_parameters()
        # Either use the fixed std that was passed in or 1/sqrt(model dim).
        # NOTE(review): self.dim is presumably set by BaseTransformer - confirm.
        init_std = init_std or (self.dim ** (-0.5))
        self.norm.reset_parameters()
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        # With tied weights the output matrix is the embedding matrix, which
        # was just initialised above.
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )
@lru_cache()
def is_pbs_job() -> bool:
    """Return True when the process runs under PBS and was not started by torchrun.

    PBS is detected through the `PBS_JOBID` environment variable. The result
    is memoised by `lru_cache`, so the environment is inspected only on the
    first call.
    """
    if "PBS_JOBID" not in os.environ:
        return False
    # A torchrun launch takes precedence even inside a PBS allocation.
    return not get_is_torch_run()
class AminoAcidTokenizer(Tokenizer):
    """Character-level tokenizer for protein sequences.

    Each of the 20 standard amino-acid one-letter codes maps to one id.
    Four special tokens occupy ids 0-3, so the vocabulary has 24 entries.
    Any other character is encoded as the unknown token.
    """

    def __init__(self) -> None:
        # Standard amino acids, by single-letter code
        AMINO_ACIDS = [
            'A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
            'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V',
        ]

        # Special tokens. The keys must be distinct: identical keys would
        # collapse into a single dict entry, making bos/eos/unk/pad share one
        # id and shrinking the vocabulary.
        SPECIAL_TOKENS = {
            '<bos>': 0,  # Beginning of sequence
            '<eos>': 1,  # End of sequence
            '<unk>': 2,  # Unknown token
            '<pad>': 3,  # Padding
        }

        # Amino-acid ids start right after the special-token ids
        first_id = len(SPECIAL_TOKENS)
        TOKEN_TO_ID = {aa: idx for idx, aa in enumerate(AMINO_ACIDS, start=first_id)}
        TOKEN_TO_ID.update(SPECIAL_TOKENS)
        ID_TO_TOKEN = {idx: token for token, idx in TOKEN_TO_ID.items()}

        self.SPECIAL_TOKENS = SPECIAL_TOKENS
        self.TOKEN_TO_ID = TOKEN_TO_ID
        self.ID_TO_TOKEN = ID_TO_TOKEN

        # Required to satisfy API
        self.bos_id = SPECIAL_TOKENS['<bos>']
        self.eos_id = SPECIAL_TOKENS['<eos>']
        self.unk_id = SPECIAL_TOKENS['<unk>']
        self.pad_id = SPECIAL_TOKENS['<pad>']
        self.n_words = len(TOKEN_TO_ID)

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]:
        """Map every character of `text` to an id.

        Characters that are not one of the 20 amino-acid codes become the
        unknown id. `add_bos` / `add_eos` prepend / append the matching
        special id.
        """
        tokens = [self.TOKEN_TO_ID.get(char, self.unk_id) for char in text]
        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)
        return tokens

    def decode(self, tokens: List[int]) -> str:
        """Turn ids back into a sequence string.

        Special tokens, and ids outside the vocabulary (treated as unknown),
        are dropped from the output.
        """
        chars = []
        for token in tokens:
            token_str = self.ID_TO_TOKEN.get(token, '<unk>')
            if token_str in self.SPECIAL_TOKENS:
                continue  # Skip special tokens
            chars.append(token_str)
        return ''.join(chars)

    def get_token_offsets(
        self, text: str, tokens: Optional[List[int]] = None
    ) -> Tuple[List[str], List[int]]:
        """Return the amino-acid tokens of `text` and their character offsets.

        Args:
            text: the sequence that was (or will be) encoded.
            tokens: ids for `text`; when missing or empty, `text` is encoded
                without bos/eos.

        Returns:
            `(token_texts, offsets)` where `offsets[i]` is the index in
            `text` of `token_texts[i]`. Special tokens are not reported.
        """
        tokens = tokens or self.encode(text)
        token_texts = []
        offsets = []
        idx = 0
        for token in tokens:
            token_str = self.ID_TO_TOKEN.get(token, '<unk>')
            if token_str in self.SPECIAL_TOKENS:
                # An unknown token still stands for one character of `text`,
                # so the cursor must move past it; bos/eos/pad consume none.
                if token_str == '<unk>':
                    idx += 1
                continue
            token_texts.append(token_str)
            offsets.append(idx)
            idx += 1
        return token_texts, offsets