I followed every step in the README without modifying the code. My command was:
python train_mamba.py --model /mnt/afs/xylu/model/state-spaces-mamba-2.8b --tokenizer EleutherAI/gpt-neox-20b --learning_rate 5e-5 --batch_size 4 --data_path ./data/ultrachat_small.jsonl --num_epochs 3
My environment (installed packages) is:
accelerate 0.25.0
bitsandbytes 0.41.3
causal-conv1d 1.0.0
certifi 2024.7.4
charset-normalizer 3.3.2
einops 0.8.0
filelock 3.15.4
fsspec 2024.6.1
huggingface-hub 0.17.3
idna 3.7
Jinja2 3.1.4
mamba-ssm 1.0.1
MarkupSafe 2.1.5
mpmath 1.3.0
networkx 3.2.1
ninja 1.11.1.1
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.6.20
nvidia-nvtx-cu12 12.1.105
packaging 24.1
pip 24.2
psutil 6.0.0
PyYAML 6.0.2
regex 2024.7.24
requests 2.32.3
safetensors 0.4.4
scipy 1.11.4
setuptools 72.1.0
sympy 1.13.2
tokenizers 0.14.1
torch 2.1.0
tqdm 4.66.5
transformers 4.35.0
triton 2.1.0
typing_extensions 4.12.2
urllib3 2.2.2
wheel 0.43.0
My error message is:
Traceback (most recent call last):
File "/data/home/xylu/mamba-chat/train_mamba.py", line 60, in <module>
run(args)
File "/data/home/xylu/mamba-chat/train_mamba.py", line 45, in run
trainer.train()
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/transformers/trainer.py", line 2725, in training_step
loss = self.compute_loss(model, inputs)
File "/data/home/xylu/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
lm_logits = model(input_ids).logits
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
output.reraise()
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/_utils.py", line 694, in reraise
raise exception
TypeError: Caught TypeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
output = module(*input, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 75, in _bench
full_nargs = {**self.nargs, **current}
TypeError: 'NoneType' object is not a mapping
I followed every step in the README without modifying the code. My command was:
python train_mamba.py --model /mnt/afs/xylu/model/state-spaces-mamba-2.8b --tokenizer EleutherAI/gpt-neox-20b --learning_rate 5e-5 --batch_size 4 --data_path ./data/ultrachat_small.jsonl --num_epochs 3
My environment (installed packages) is:
accelerate 0.25.0
bitsandbytes 0.41.3
causal-conv1d 1.0.0
certifi 2024.7.4
charset-normalizer 3.3.2
einops 0.8.0
filelock 3.15.4
fsspec 2024.6.1
huggingface-hub 0.17.3
idna 3.7
Jinja2 3.1.4
mamba-ssm 1.0.1
MarkupSafe 2.1.5
mpmath 1.3.0
networkx 3.2.1
ninja 1.11.1.1
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.6.20
nvidia-nvtx-cu12 12.1.105
packaging 24.1
pip 24.2
psutil 6.0.0
PyYAML 6.0.2
regex 2024.7.24
requests 2.32.3
safetensors 0.4.4
scipy 1.11.4
setuptools 72.1.0
sympy 1.13.2
tokenizers 0.14.1
torch 2.1.0
tqdm 4.66.5
transformers 4.35.0
triton 2.1.0
typing_extensions 4.12.2
urllib3 2.2.2
wheel 0.43.0
My error message is:
Traceback (most recent call last):
File "/data/home/xylu/mamba-chat/train_mamba.py", line 60, in <module>
run(args)
File "/data/home/xylu/mamba-chat/train_mamba.py", line 45, in run
trainer.train()
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/transformers/trainer.py", line 1860, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/transformers/trainer.py", line 2725, in training_step
loss = self.compute_loss(model, inputs)
File "/data/home/xylu/mamba-chat/trainer/mamba_trainer.py", line 9, in compute_loss
lm_logits = model(input_ids).logits
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
outputs = self.parallel_apply(replicas, inputs, module_kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 110, in parallel_apply
output.reraise()
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/_utils.py", line 694, in reraise
raise exception
TypeError: Caught TypeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
output = module(*input, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/data/home/xylu/miniconda3/envs/mamba_chat/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 75, in _bench
full_nargs = {**self.nargs, **current}
TypeError: 'NoneType' object is not a mapping