Skip to content

Commit a8b5198

Browse files
authored
upload files first time
1 parent d348816 commit a8b5198

19 files changed

+3893
-1
lines changed

README.md

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,83 @@
1-
# EEGFoundation
1+
# EEGFoundation: Tokenized Spatiotemporal Foundation Model for EEG Signals
2+
3+
## Overview
4+
5+
EEGFoundation is a novel foundation model that treats neural dynamics as a discrete semantic language, overcoming the limitations of vision-based EEG analysis paradigms. By implementing amplitude-aware tokenization and channel-independent pretraining on a 27,000+ hour EEG corpus, the model learns universal neural oscillation patterns that generalize across diverse EEG analysis tasks.
6+
7+
8+
9+
![Main_fig1](./photos/Main_fig1.png)
10+
11+
**Fig.1 The EEGFoundation framework for spatiotemporal sequence modeling**
12+
13+
## Model Architecture
14+
15+
EEGFoundation follows a three-stage hierarchical approach:
16+
17+
1. **Amplitude-Aware Tokenization**: Continuous EEG signals are normalized and quantized into discrete symbolic tokens that preserve micro-voltage fluctuations while filtering high-frequency noise.
18+
19+
2. **Temporal Pretraining**: Using a RoFormer encoder with Rotary Position Embeddings, the model learns universal temporal dynamics from channel-independent EEG streams.
20+
21+
3. **Spatiotemporal Fusion**: Cross-channel attention dynamically aggregates local representations into a coherent global context for robust downstream task performance.
22+
23+
## Quick Start
24+
25+
### Environment Setup
26+
27+
```bash
28+
# Clone the repository
29+
git clone https://github.com/yourusername/EEGFoundation_github.git
30+
cd EEGFoundation_github
31+
32+
# Install dependencies (Python 3.10+ required)
33+
pip install torch>=2.0.0 transformers>=4.30.0 numpy>=1.24.0 scipy>=1.10.0
34+
pip install mne>=1.4.0 einops>=0.6.0 matplotlib>=3.7.0
35+
```
36+
37+
### Basic Usage
38+
39+
```python
40+
from src.models.downstream_EEGFoundation import load_downstream_model
41+
import torch
42+
import numpy as np
43+
44+
# Load pre-trained model for motor imagery classification
45+
model = load_downstream_model(
46+
model_path="models/BCIC-2a_model.pth",
47+
config_path="configs/BCIC_IV_2a_config.json"
48+
)
49+
50+
# Prepare input data (example)
51+
batch_size = 2
52+
num_channels = 20
53+
seq_length = 2000
54+
55+
eeg_signal = torch.randn(batch_size, num_channels, seq_length).float()
56+
embedding = torch.randn(batch_size, 512).float()
57+
58+
# Forward pass
59+
with torch.no_grad():
60+
outputs = model(input_ids=eeg_signal, embedding_data=embedding)
61+
predictions = torch.softmax(outputs['logits'], dim=-1)
62+
63+
print(f"Predictions shape: {predictions.shape}")
64+
```
65+
66+
### Demo Data
67+
68+
```python
69+
import numpy as np
70+
71+
# Load example data
72+
demo_data = np.load("demo_data/eeg_data.npy")
73+
print(f"Demo data shape: {demo_data.shape}")
74+
75+
# The demo_data directory contains:
76+
# - eeg_data.npy: Sample EEG recordings
77+
# - downstream_eeg_data.npz: Processed data for downstream tasks
78+
# - make_data.py: Script to generate synthetic EEG data
79+
```
80+
81+
## License
82+
83+
This project is licensed under the Apache License 2.0. See the LICENSE file for details.

configs/BCIC_IV_2a_config.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"seq_len": 800,
3+
"patch_size": 100,
4+
"stride": 25,
5+
"d_model": 512,
6+
"num_classes": 4,
7+
"num_channel": 22,
8+
"rms_norm": false,
9+
"embedding_dim": 512,
10+
"projection_embedding_dim": 512,
11+
"classification_dropout": 0.5,
12+
"classification_hidden_dim": 512,
13+
"learning_rate": 1e-4,
14+
"weight_decay": 0.1,
15+
"model_type": "eeg-downstream-classifier"
16+
}

configs/FACED_config.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"seq_len": 2000,
3+
"patch_size": 200,
4+
"stride": 100,
5+
"d_model": 512,
6+
"num_classes": 9,
7+
"num_channel": 32,
8+
"rms_norm": false,
9+
"embedding_dim": 512,
10+
"projection_embedding_dim": 512,
11+
"classification_dropout": 0.5,
12+
"classification_hidden_dim": 512,
13+
"learning_rate": 1e-4,
14+
"weight_decay": 0.1,
15+
"model_type": "eeg-downstream-classifier"
16+
}

configs/Stress_config.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"seq_len": 1000,
3+
"patch_size": 150,
4+
"stride": 100,
5+
"d_model": 512,
6+
"num_classes": 2,
7+
"num_channel": 20,
8+
"rms_norm": false,
9+
"embedding_dim": 512,
10+
"projection_embedding_dim": 512,
11+
"classification_dropout": 0.5,
12+
"classification_hidden_dim": 512,
13+
"learning_rate": 1e-4,
14+
"weight_decay": 0.1,
15+
"model_type": "eeg-downstream-classifier"
16+
}

configs/TUAB_config.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"seq_len": 4000,
3+
"patch_size": 150,
4+
"stride": 100,
5+
"d_model": 512,
6+
"num_classes": 2,
7+
"num_channel": 20,
8+
"rms_norm": false,
9+
"embedding_dim": 512,
10+
"projection_embedding_dim": 512,
11+
"classification_dropout": 0.5,
12+
"classification_hidden_dim": 512,
13+
"learning_rate": 1e-4,
14+
"weight_decay": 0.1,
15+
"model_type": "eeg-downstream-classifier"
16+
}

configs/pretrain_config.json

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"activation": "silu",
3+
"attention_probs_dropout_prob": 0.1,
4+
"classifier_dropout": null,
5+
"cls_token": "[CLS]",
6+
"cls_token_id": 2003,
7+
"hidden_act": "gelu",
8+
"hidden_dropout_prob": 0.1,
9+
"hidden_size": 768,
10+
"initializer_range": 0.02,
11+
"intermediate_size": 3072,
12+
"layer_norm_eps": 1e-12,
13+
"mask_token": "[MASK]",
14+
"mask_token_id": 2001,
15+
"max_position_embeddings": 2000,
16+
"model_type": "bert",
17+
"num_attention_heads": 12,
18+
"num_hidden_layers": 12,
19+
"pad_token": "[PAD]",
20+
"pad_token_id": 2002,
21+
"position_embedding_type": "rotary",
22+
"rope_theta": 10000,
23+
"rotary_dim": 64,
24+
"sep_token": "[SEP]",
25+
"sep_token_id": 2004,
26+
"transformers_version": "4.48.0",
27+
"type_vocab_size": 2,
28+
"unk_token": "[UNK]",
29+
"unk_token_id": 2005,
30+
"use_cache": true,
31+
"use_flash_attention": true,
32+
"vocab_size": 2006
33+
}
34+

demo_data/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
EEG Dataset Demo
2+
Overview
3+
This dataset demo contains synthetic EEG (electroencephalogram) data designed for testing and development purposes. It includes two main data files with different formats commonly used in EEG signal processing and machine learning applications.
4+
5+
Data Structure
6+
1. eeg_data.npy
7+
Shape: (50, 2000)
8+
Description: Contains 50 samples of time-series data, each with 2000 time points
9+
10+
11+
2. dictionary_data.npz
12+
signal: Time-series data with shape (1, 2000)
13+
embedding: Embedding vector with shape (1, 512)
14+
15+

demo_data/downstream_eeg_data.npz

10.3 KB
Binary file not shown.

demo_data/eeg_data.npy

391 KB
Binary file not shown.

photos/Main_fig1.png

2.46 MB
Loading

0 commit comments

Comments
 (0)