Skip to content

Commit 8a018b6

Browse files
committed
Minor Updates
1 parent 747fe20 commit 8a018b6

File tree

4 files changed

+338
-6
lines changed

4 files changed

+338
-6
lines changed

environment.yaml

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
name: sae-exp
2+
channels:
3+
- pytorch
4+
- nvidia
5+
- anaconda
6+
- defaults
7+
dependencies:
8+
- pip
9+
- _libgcc_mutex=0.1=main
10+
- _openmp_mutex=5.1=1_gnu
11+
- argon2-cffi-bindings=21.2.0=py310h7f8727e_0
12+
- async-lru=2.0.4=py310h06a4308_0
13+
- backcall=0.2.0=pyhd3eb1b0_0
14+
- beautifulsoup4=4.12.3=py310h06a4308_0
15+
- blas=1.0=mkl
16+
- blosc=1.21.3=h6a678d5_0
17+
- brotli-python=1.0.9=py310h6a678d5_8
18+
- bzip2=1.0.8=h5eee18b_6
19+
- c-ares=1.19.1=h5eee18b_0
20+
- c-blosc2=2.12.0=h80c7b02_0
21+
- ca-certificates=2024.7.2=h06a4308_0
22+
- charset-normalizer=3.3.2=pyhd3eb1b0_0
23+
- cuda-cudart=11.8.89=0
24+
- decorator=5.1.1=pyhd3eb1b0_0
25+
- defusedxml=0.7.1=pyhd3eb1b0_0
26+
- faiss-gpu=1.8.0=py3.10_hedc54c9_0_cuda11.4.4
27+
- h11=0.14.0=py310h06a4308_0
28+
- hdf5=1.12.1=h2b7332f_3
29+
- intel-openmp=2023.1.0=hdb19cb5_46306
30+
- ipykernel=6.28.0=py310h06a4308_0
31+
- jinja2=3.1.4=py310h06a4308_0
32+
- jupyter_client=8.6.0=py310h06a4308_0
33+
- jupyter_core=5.7.2=py310h06a4308_0
34+
- jupyter_events=0.10.0=py310h06a4308_0
35+
- jupyter_server=2.14.1=py310h06a4308_0
36+
- jupyter_server_terminals=0.4.4=py310h06a4308_1
37+
- jupyterlab_pygments=0.1.2=py_0
38+
- jupyterlab_server=2.27.3=py310h06a4308_0
39+
- krb5=1.20.1=h143b758_1
40+
- ld_impl_linux-64=2.38=h1181459_1
41+
- libcublas=11.11.3.6=0
42+
- libcurl=8.7.1=h251f7ec_0
43+
- libedit=3.1.20230828=h5eee18b_0
44+
- libev=4.33=h7f8727e_1
45+
- libfaiss=1.8.0=h5aaf3ed_0_cuda11.4.4
46+
- libffi=3.4.4=h6a678d5_1
47+
- libgcc-ng=11.2.0=h1234567_1
48+
- libgfortran-ng=11.2.0=h00389a5_1
49+
- libgfortran5=11.2.0=h1234567_1
50+
- libgomp=11.2.0=h1234567_1
51+
- libnghttp2=1.57.0=h2d74bed_0
52+
- libsodium=1.0.18=h7b6447c_0
53+
- libssh2=1.11.0=h251f7ec_0
54+
- libstdcxx-ng=11.2.0=h1234567_1
55+
- libuuid=1.41.5=h5eee18b_0
56+
- lz4-c=1.9.4=h6a678d5_1
57+
- lzo=2.10=h7b6447c_2
58+
- matplotlib-inline=0.1.6=py310h06a4308_0
59+
- mkl=2023.1.0=h213fc3f_46344
60+
- mkl-service=2.4.0=py310h5eee18b_1
61+
- mkl_fft=1.3.8=py310h5eee18b_0
62+
- mkl_random=1.2.4=py310hdb19cb5_0
63+
- nbformat=5.9.2=py310h06a4308_0
64+
- ncurses=6.4=h6a678d5_0
65+
- nest-asyncio=1.6.0=py310h06a4308_0
66+
- numexpr=2.8.7=py310h85018f9_0
67+
- numpy=1.26.4=py310h5f9d8c6_0
68+
- numpy-base=1.26.4=py310hb5e798b_0
69+
- openssl=3.0.15=h5eee18b_0
70+
- packaging=24.1=py310h06a4308_0
71+
- pickleshare=0.7.5=pyhd3eb1b0_1003
72+
- prometheus_client=0.14.1=py310h06a4308_0
73+
- prompt_toolkit=3.0.43=hd3eb1b0_0
74+
- ptyprocess=0.7.0=pyhd3eb1b0_2
75+
- pure_eval=0.2.2=pyhd3eb1b0_0
76+
- py-cpuinfo=9.0.0=py310h06a4308_0
77+
- pysocks=1.7.1=py310h06a4308_0
78+
- pytables=3.9.2=py310h0016290_0
79+
- python=3.10.14=h955ad1f_1
80+
- python-dateutil=2.9.0post0=py310h06a4308_2
81+
- python-fastjsonschema=2.16.2=py310h06a4308_0
82+
- python-json-logger=2.0.7=py310h06a4308_0
83+
- pytz=2024.1=py310h06a4308_0
84+
- pyyaml=6.0.1=py310h5eee18b_0
85+
- readline=8.2=h5eee18b_0
86+
- requests=2.32.3=py310h06a4308_0
87+
- rfc3339-validator=0.1.4=py310h06a4308_0
88+
- rfc3986-validator=0.1.1=py310h06a4308_0
89+
- setuptools=69.5.1=py310h06a4308_0
90+
- six=1.16.0=pyhd3eb1b0_1
91+
- soupsieve=2.5=py310h06a4308_0
92+
- sqlite=3.45.3=h5eee18b_0
93+
- stack_data=0.2.0=pyhd3eb1b0_0
94+
- tbb=2021.8.0=hdb19cb5_0
95+
- tk=8.6.14=h39e8969_0
96+
- tomli=2.0.1=py310h06a4308_0
97+
- tornado=6.4.1=py310h5eee18b_0
98+
- traitlets=5.14.3=py310h06a4308_0
99+
- typing_extensions=4.11.0=py310h06a4308_0
100+
- urllib3=2.2.2=py310h06a4308_0
101+
- websocket-client=1.8.0=py310h06a4308_0
102+
- wheel=0.43.0=py310h06a4308_0
103+
- xz=5.4.6=h5eee18b_1
104+
- yaml=0.2.5=h7b6447c_0
105+
- zeromq=4.3.5=h6a678d5_0
106+
- zlib=1.2.13=h5eee18b_1
107+
- zlib-ng=2.0.7=h5eee18b_0
108+
- zstd=1.5.5=hc292b87_2
109+
- pip:
110+
- accelerate==1.1.1
111+
- adjusttext==1.3.0
112+
- aiohttp==3.9.5
113+
- aiosignal==1.3.1
114+
- annotated-types==0.7.0
115+
- anyio==4.4.0
116+
- argon2-cffi==23.1.0
117+
- arrow==1.3.0
118+
- asttokens==2.4.1
119+
- async-timeout==4.0.3
120+
- attrs==23.2.0
121+
- automated-interpretability==0.0.6
122+
- babe==0.0.7
123+
- babel==2.15.0
124+
- beartype==0.14.1
125+
- better-abc==0.0.3
126+
- bidict==0.23.1
127+
- bitsandbytes==0.42.0
128+
- black==23.11.0
129+
- bleach==6.1.0
130+
- blobfile==2.1.1
131+
- boostedblob==0.15.3
132+
- certifi==2024.6.2
133+
- cffi==1.16.0
134+
- cfgv==3.4.0
135+
- chardet==3.0.4
136+
- click==8.1.7
137+
- comm==0.2.2
138+
- config2py==0.1.33
139+
- contourpy==1.2.1
140+
- coverage==7.5.4
141+
- cycler==0.12.1
142+
- dataclasses-json==0.6.7
143+
- datasets==2.21.0
144+
- debugpy==1.8.2
145+
- diffusers==0.30.3
146+
- dill==0.3.8
147+
- distlib==0.3.8
148+
- distro==1.9.0
149+
- docker-pycreds==0.4.0
150+
- docstring-parser==0.16
151+
- dol==0.2.49
152+
- eindex-callum==0.1.1
153+
- einops==0.7.0
154+
- exceptiongroup==1.2.1
155+
- executing==2.0.1
156+
- fancy-einsum==0.0.3
157+
- fastjsonschema==2.20.0
158+
- filelock==3.15.4
159+
- flake8==7.0.0
160+
- fonttools==4.53.0
161+
- fqdn==1.5.1
162+
- frozenlist==1.4.1
163+
- fsspec==2024.5.0
164+
- gitdb==4.0.11
165+
- gitpython==3.1.43
166+
- googletrans==3.0.0
167+
- gprof2dot==2024.6.6
168+
- graze==0.1.17
169+
- h2==3.2.0
170+
- h5py==3.11.0
171+
- hpack==3.0.0
172+
- hstspreload==2024.10.1
173+
- httpcore==1.0.5
174+
- httpx==0.27.2
175+
- huggingface-hub==0.24.7
176+
- hyperframe==5.2.0
177+
- i2==0.1.17
178+
- identify==2.5.36
179+
- idna==2.10
180+
- importlib-metadata==8.5.0
181+
- importlib-resources==6.4.0
182+
- iniconfig==2.0.0
183+
- ipython==8.26.0
184+
- ipywidgets==8.1.3
185+
- isoduration==20.11.0
186+
- isort==5.13.2
187+
- jaxtyping==0.2.36
188+
- jedi==0.19.1
189+
- jiter==0.5.0
190+
- joblib==1.4.2
191+
- json5==0.9.25
192+
- jsonpointer==3.0.0
193+
- jsonschema==4.22.0
194+
- jsonschema-specifications==2023.12.1
195+
- jupyter==1.1.1
196+
- jupyter-client==8.6.2
197+
- jupyter-console==6.6.3
198+
- jupyter-lsp==2.2.5
199+
- jupyter-server-terminals==0.5.3
200+
- jupyterlab==4.2.3
201+
- jupyterlab-pygments==0.3.0
202+
- jupyterlab-server==2.27.2
203+
- jupyterlab-widgets==3.0.11
204+
- kiwisolver==1.4.5
205+
- lxml==4.9.4
206+
- markdown-it-py==3.0.0
207+
- markupsafe==2.1.5
208+
- marshmallow==3.21.3
209+
- matplotlib==3.9.0
210+
- mccabe==0.7.0
211+
- mdurl==0.1.2
212+
- mistune==3.0.2
213+
- mpmath==1.3.0
214+
- multidict==6.0.5
215+
- multiprocess==0.70.16
216+
- mypy-extensions==1.0.0
217+
- natsort==8.4.0
218+
- nbclient==0.10.0
219+
- nbconvert==7.16.4
220+
- networkx==3.3
221+
- nltk==3.8.1
222+
- nnsight==0.3.6
223+
- nodeenv==1.9.1
224+
- notebook==7.2.1
225+
- notebook-shim==0.2.4
226+
- nvidia-cublas-cu12==12.1.3.1
227+
- nvidia-cuda-cupti-cu12==12.1.105
228+
- nvidia-cuda-nvrtc-cu12==12.1.105
229+
- nvidia-cuda-runtime-cu12==12.1.105
230+
- nvidia-cudnn-cu12==9.1.0.70
231+
- nvidia-cufft-cu12==11.0.2.54
232+
- nvidia-curand-cu12==10.3.2.106
233+
- nvidia-cusolver-cu12==11.4.5.107
234+
- nvidia-cusparse-cu12==12.1.0.106
235+
- nvidia-nccl-cu12==2.20.5
236+
- nvidia-nvjitlink-cu12==12.5.40
237+
- nvidia-nvtx-cu12==12.1.105
238+
- openai==1.42.0
239+
- orjson==3.10.5
240+
- overrides==7.7.0
241+
- pandas==2.2.2
242+
- pandocfilters==1.5.1
243+
- parso==0.8.4
244+
- pathspec==0.12.1
245+
- patsy==0.5.6
246+
- pexpect==4.9.0
247+
- pillow==10.3.0
248+
- pip==24.3.1
249+
- platformdirs==4.2.2
250+
- plotly==5.22.0
251+
- plotly-express==0.4.1
252+
- pluggy==1.5.0
253+
- pre-commit==3.6.0
254+
- prometheus-client==0.20.0
255+
- prompt-toolkit==3.0.47
256+
- protobuf==5.27.2
257+
- psutil==6.0.0
258+
- py2store==0.1.20
259+
- pyarrow==16.1.0
260+
- pyarrow-hotfix==0.6
261+
- pycocotools==2.0.8
262+
- pycodestyle==2.11.1
263+
- pycparser==2.22
264+
- pycryptodomex==3.20.0
265+
- pydantic==2.8.2
266+
- pydantic-core==2.20.1
267+
- pyflakes==3.2.0
268+
- pygments==2.18.0
269+
- pyparsing==3.1.2
270+
- pytest==8.2.2
271+
- pytest-cov==4.1.0
272+
- pytest-profiling==1.7.0
273+
- python-dotenv==1.0.1
274+
- python-engineio==4.9.1
275+
- python-socketio==5.11.4
276+
- pyzmq==26.0.3
277+
- qtconsole==5.5.2
278+
- qtpy==2.4.1
279+
- referencing==0.35.1
280+
- regex==2024.5.15
281+
- rfc3986==1.5.0
282+
- rich==13.7.1
283+
- rpds-py==0.18.1
284+
- sae-lens==5.2.0
285+
- sae-vis==0.2.19
286+
- safetensors==0.4.3
287+
- scikit-learn==1.5.0
288+
- scipy==1.14.0
289+
- seaborn==0.13.2
290+
- send2trash==1.8.3
291+
- sentence-transformers==3.1.0
292+
- sentencepiece==0.2.0
293+
- sentry-sdk==2.7.1
294+
- setproctitle==1.3.3
295+
- shellingham==1.5.4
296+
- simple-parsing==0.1.6
297+
- simple-websocket==1.0.0
298+
- smmap==5.0.1
299+
- sniffio==1.3.1
300+
- stack-data==0.6.3
301+
- statsmodels==0.14.2
302+
- sympy==1.12.1
303+
- tenacity==8.4.2
304+
- terminado==0.18.1
305+
- threadpoolctl==3.5.0
306+
- tiktoken==0.6.0
307+
- tinycss2==1.3.0
308+
- tokenizers==0.20.3
309+
- torch==2.4.1
310+
- torchaudio==2.4.0
311+
- torchvision==0.19.1
312+
- tqdm==4.66.4
313+
- transformer-lens==2.10.0
314+
- transformers==4.46.3
315+
- triton==3.0.0
316+
- typeguard==4.4.1
317+
- typer==0.12.3
318+
- types-python-dateutil==2.9.0.20240316
319+
- typing-extensions==4.12.2
320+
- typing-inspect==0.9.0
321+
- tzdata==2024.1
322+
- uri-template==1.3.0
323+
- uvloop==0.19.0
324+
- virtualenv==20.26.3
325+
- wandb==0.17.3
326+
- wcwidth==0.2.13
327+
- webcolors==24.6.0
328+
- webencodings==0.5.1
329+
- widgetsnbextension==4.0.11
330+
- wsproto==1.2.0
331+
- xxhash==3.4.1
332+
- yarl==1.9.4
333+
- zipp==3.20.2
334+
- zstandard==0.22.0
335+
prefix: /home/gridsan/dbaek/.conda/envs/sae-exp

src/run_exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@
5555
'weight_decay':weight_decay
5656
}
5757

58-
results_root = "../results_1"
58+
results_root = "../results"
5959

6060
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
61-
results_root = f"{results_root}/{current_datetime}"
61+
results_root = f"{results_root}/{seed}-{data_id}-{model_id}"
6262
os.mkdir(results_root)
6363

6464
param_dict_json = {k: v for k, v in param_dict.items() if k != 'device'} # since torch.device is not JSON serializable

src/utils/FamilyTreeGenerator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
import h5py
2525
import multiprocessing
2626

27-
#from transformer_lens import *
28-
29-
from transformers import pipeline
3027

3128
from sklearn.decomposition import PCA
3229
from sklearn.utils import shuffle

src/utils/driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def train_single_model(param_dict: dict):
7979
elif data_id == "family_tree":
8080
dataset = family_tree_dataset_2(p=127, num=data_size, seed=seed, device=device)
8181
elif data_id == "equivalence":
82-
input_token = 1
82+
input_token = 2
8383
dataset = mod_equiv_dataset(p=50, num=data_size, seed=seed, device=device)
8484
elif data_id == "circle":
8585
dataset = modular_addition_dataset(p=31, num=data_size, seed=seed, device=device)

0 commit comments

Comments
 (0)