[feat] Support sparse copy for a large number of experts #56

@sohamparikh

🐞 Describe the Bug

Training fails with a Triton OutOfResources (shared memory) error when using 64 fine-grained experts with dropless MoE enabled, even though there is sufficient GPU memory.

🔄 Steps to Reproduce

Steps to reproduce the behavior:

  1. Fast-LLM Docker image tag: sha-8f06975
  2. Training config:
model:
  base_model:
    transformer:
      normalization:
        type: rms_norm
      num_layers: 16
      hidden_size: 2048
      num_attention_heads: 16
      head_groups: 1
      add_linear_biases: false
      ffn_hidden_size: 1024
      kv_channels: 128
      use_rotary_embeddings: true
      gated: true
      activation_type: silu
      num_experts: 64
      num_experts_per_token: 8
      init_method_std: 0.02
      init_method_std_qkv: 0.02
      init_method_std_attn_proj: 0.0035355339059327372
      init_method_std_mlp_1: 0.02
      init_method_std_mlp_2: 0.0035355339059327372
      expert_z_loss_coefficient: 0.001
      mlp_lr_scale:
      - null
    vocab_size: 131072
    use_position_embeddings: false
    tie_word_embeddings: false
    init_method_std_embed: 0.02
    cross_entropy_impl: fused
  multi_stage:
    zero_stage: 1
  distributed:
    world_size: 16
    rank: 0
    local_world_size: 8
    distributed_timeout: 3600.0
    seed: 984059
    training_dtype: bfloat16
pretrained:
  format: fast_llm
  path: null
  load_config: architecture
training:
  validation:
    interval: 1000
    iterations: 25
  logs:
    interval: 10
  export:
    format: fast_llm
    optimizer_state: true
  wandb:
    group_name: pilot
    project_name: olmoe-1b-7b
    entity_name: tscholak
  train_iters: 1000
  num_workers: 8
batch:
  micro_batch_size: 2
  depth_first_micro_batches: 1
  breadth_first_micro_batches: 1
  sequential_micro_batches: 1
  batch_size: 32
  sequence_length: 2048
  micro_sequence_length: 2048
data:
  tokenizer:
    path: null
  fim:
    rate: 0.0
  split:
  - 998.0
  - 2.0
  - 0.0
  format: file
  path:
  - /mnt/workspace/inputs/openwebtext-SmolLM2/fast_llm_dataset.json
profiling:
  cuda: false
optimizer:
  learning_rate:
    base: 0.0001
  weight_decay: 0.1
  beta_2: 0.95
  3. Error log:
  File "/app/fast_llm/layers/transformer/mixture_of_experts.py", line 114, in forward
    return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None  # noqa
  File "/app/fast_llm/layers/transformer/mixture_of_experts.py", line 118, in _forward_dropless
    sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape)
  File "/app/fast_llm/functional/triton/sparse_copy.py", line 301, in get_sparse_map
    sparse_map_kernel[(triton.cdiv(num_rows_dense, block_size),)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 180, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 412, in run
    kernel.run(grid_0, grid_1, grid_2, metadata.num_warps,
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 339, in __getattribute__
    self._init_handles()
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 332, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.

🎯 Expected Behavior

Training should run with this configuration without raising OutOfResources.

📝 Additional Context

  • Works with micro-batch size 1 and micro-sequence length 2048. Increasing the micro-batch size to 2 or the micro-sequence length to 4096 triggers the error.
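
The numbers in the error log line up with a simple size estimate, sketched below. This is an observation from the reported values, not a confirmed reading of the kernel: the 8-byte entry size and the helper name `estimated_shared_bytes` are assumptions made for illustration. Under that assumption, the required shared memory equals tokens per micro-batch × `num_experts_per_token` × 8 bytes, which matches the failing cases and stays under the limit for the working one.

```python
# Values copied from the config and error log in this issue.
HARDWARE_LIMIT = 232_448       # "Hardware limit" from the error message (bytes)
REPORTED_REQUIRED = 262_144    # "Required" from the error message (bytes)
NUM_EXPERTS_PER_TOKEN = 8      # num_experts_per_token in the config

# Assumption (not confirmed from the kernel source): one 8-byte value
# is held in shared memory per (token, selected expert) entry.
BYTES_PER_ENTRY = 8


def estimated_shared_bytes(micro_batch_size: int, micro_sequence_length: int) -> int:
    """Hypothetical estimate of shared memory needed for one micro-batch."""
    entries = micro_batch_size * micro_sequence_length * NUM_EXPERTS_PER_TOKEN
    return entries * BYTES_PER_ENTRY


# The three cases mentioned in this issue: one that works, two that fail.
for mbs, seq in [(1, 2048), (2, 2048), (1, 4096)]:
    needed = estimated_shared_bytes(mbs, seq)
    status = "fits" if needed <= HARDWARE_LIMIT else "exceeds limit"
    print(f"micro_batch_size={mbs}, micro_sequence_length={seq}: {needed} bytes ({status})")
```

If the estimate reflects what the kernel does, the footprint grows linearly with the number of tokens per micro-batch, so any per-launch cap on rows would need to keep tokens × experts per token below roughly 29,000 entries on this hardware.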

Labels: enhancement
