Skip to content
This repository was archived by the owner on Jan 26, 2024. It is now read-only.

Commit 0e88d2c

Browse files
committed
Initial commit
0 parents  commit 0e88d2c

File tree

18 files changed

+1117
-0
lines changed

18 files changed

+1117
-0
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
.vscode
2+
*.pyc
3+
.DS_Store
4+
*.mid
5+
*.midi
6+
*.sess

README.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Performance RNN - PyTorch
2+
3+
PyTorch implementation of Performance RNN, inspired by
4+
[https://magenta.tensorflow.org/performance-rnn](https://magenta.tensorflow.org/performance-rnn).
5+
6+
Note: this is not an official implementation, and the model may differ from Magenta's.
7+
8+
9+
## Directory Structure
10+
11+
```
12+
.
13+
├── dataset/
14+
│   ├── midi/
15+
│   │ ├── dataset1/
16+
│   │   │ └── *.mid
17+
│   │ └── dataset2/
18+
│   │   └── *.mid
19+
│   ├── processed/
20+
│   │ └── dataset1/
21+
│   │   └── *.data (generated with preprocess.py)
22+
│   └── scripts/
23+
│   └── *.sh (dataset download scripts)
24+
├── generated/
25+
│   └── *.mid (generated with generate.py)
26+
└── runs/ (tensorboard logdir)
27+
```
28+
29+
30+
## Getting Started
31+
32+
- Download datasets
33+
```
34+
cd dataset/
35+
bash scripts/NAME_scraper.sh midi/NAME
36+
```
37+
38+
- Preprocessing
39+
```shell
40+
# Will preprocess all MIDI files under dataset/midi/NAME
41+
python3 preprocess.py dataset/midi/NAME dataset/processed/NAME
42+
```
43+
44+
- Training
45+
```shell
46+
# Train on .data files in dataset/processed/MYDATA,
47+
# and save to myModel.sess every 10s.
48+
python3 train.py -s myModel.sess -d dataset/processed/MYDATA -i 10
49+
```
50+
51+
- Generating
52+
```shell
53+
python3 generate.py \
54+
myModel.sess \ # load trained model from myModel.sess
55+
generated/ \ # save to generated/
56+
10 \ # generate 10 event sequences
57+
2000 \ # generate 2000 event steps
58+
0.9 \ # 90% sampling with argmax and 10% multinomial
59+
'1,0,1,0,1,1,0,1,0,1,0,1' \ # pitch histogram ([12] or [0])
60+
3 # note density (0-5)
61+
```
62+
63+
# Requirements
64+
65+
```
66+
pretty_midi
67+
numpy
68+
pytorch
69+
tensorboardX
70+
progress
71+
```

config.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Hyperparameter configuration for the Performance RNN model."""
import torch
from sequence import EventSeq, ControlSeq

# pylint: disable=E1101

# Run on the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network architecture settings.
model = dict(
    init_dim=32,
    event_dim=EventSeq.dim(),
    control_dim=ControlSeq.dim(),
    hidden_dim=512,
    gru_layers=3,
    gru_dropout=0.3,
)

# Training loop settings.
train = dict(
    learning_rate=0.001,
    batch_size=64,
    window_size=200,
    stride_size=10,
    use_transposition=False,
)

data.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
import torch
3+
import itertools
4+
import numpy as np
5+
import random
6+
7+
import config
8+
import utils
9+
from sequence import EventSeq, ControlSeq
10+
11+
# pylint: disable=E1101
12+
# pylint: disable=W0101
13+
14+
class Dataset:
    """In-memory collection of preprocessed (event, control) sequence pairs
    loaded from ``.data`` files under a root directory."""

    def __init__(self, root):
        # Load every (eventseq, controlseq) pair from each .data file.
        self.samples = []
        self.seqlens = []
        for path in utils.find_files_by_extensions(root, ['.data']):
            for evseq, ctlseq in torch.load(path):
                ctlseq = ControlSeq.recover_compressed_array(ctlseq)
                self.samples.append((evseq, ctlseq))
        # Every event sequence must align 1:1 with its control sequence.
        for evseq, ctlseq in self.samples:
            assert len(evseq) == len(ctlseq)
            self.seqlens.append(len(evseq))

    def batches(self, batch_size, window_size, stride_size):
        """Yield an endless stream of random batches.

        Each batch is a pair of arrays stacked along axis 1 (batch axis),
        built from fixed-size windows sliced out of the stored sequences.
        """
        # Precompute every (sample index, window range) the data admits.
        windows = [(sample_idx, range(start, start + window_size))
                   for sample_idx, seqlen in enumerate(self.seqlens)
                   for start in range(0, seqlen - window_size, stride_size)]
        while True:
            ev_batch = []
            ctl_batch = []
            # Visit the windows in a fresh random order each epoch.
            for w in np.random.permutation(len(windows)):
                sample_idx, rng = windows[w]
                evseq, ctlseq = self.samples[sample_idx]
                ev_batch.append(evseq[rng.start:rng.stop])
                ctl_batch.append(ctlseq[rng.start:rng.stop])
                if len(ev_batch) == batch_size:
                    yield (np.stack(ev_batch, axis=1),
                           np.stack(ctl_batch, axis=1))
                    ev_batch = []
                    ctl_batch = []

dataset/midi/.keep

Whitespace-only changes.

dataset/processed/.keep

Whitespace-only changes.

dataset/scripts/.keep

Whitespace-only changes.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
# Scraper for the Classical Piano Midi Page (piano-midi.de).
# Usage: bash <script> OUTPUT_DIR
# Downloads all format-0 MIDI files into OUTPUT_DIR.

# Require the output directory and fail with a non-zero status
# (the original exited 0 and printed the error to stdout).
if [ -z "${1:-}" ]; then
  echo 'Error: please specify output dir' >&2
  exit 1
fi
dir=$1
base=http://www.piano-midi.de

# Collect the per-composer page links from the site index.
pages=$(curl -s --max-time 5 "$base/midi_files.htm" \
  | grep '<tr class="midi"><td class="midi"><a href="' \
  | grep -E -o '[^"]+\.htm')
echo "Pages: $pages"

mkdir -p "$dir"
# $pages is deliberately unquoted: word-splitting yields one page per item.
for page in $pages; do
  # Extract the format-0 MIDI links from each composer page.
  midis=$(curl -s --max-time 5 "$base/$page" | grep -E -o '[^"]+format0\.mid')
  for midi in $midis; do
    echo "http://www.piano-midi.de/$midi"
  done | tee /dev/stderr | wget -P "$dir" -i -
done
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash
# Scraper for the Yamaha e-Piano Competition dataset.
# Usage: bash <script> OUTPUT_DIR
# Downloads every MIDI linked from the competition index pages into OUTPUT_DIR.

# Require the output directory and fail with a non-zero status
# (the original exited 0 and printed the error to stdout).
if [ -z "${1:-}" ]; then
  echo 'Error: please specify output dir' >&2
  exit 1
fi
dir=$1

# Index pages listing the MIDI files for each competition year.
# NOTE(review): the original listed midi_20011.asp, which appears to be a
# typo for the 2011 competition page; corrected to midi_2011.asp — verify
# against the live site.
pages='http://www.piano-e-competition.com/ecompetition/midi_2002.asp
http://www.piano-e-competition.com/ecompetition/midi_2004.asp
http://www.piano-e-competition.com/ecompetition/midi_2006.asp
http://www.piano-e-competition.com/ecompetition/midi_2008.asp
http://www.piano-e-competition.com/ecompetition/midi_2009.asp
http://www.piano-e-competition.com/ecompetition/midi_2011.asp
'

mkdir -p "$dir"
# $pages is deliberately unquoted: word-splitting yields one URL per item.
for page in $pages; do
  # Pull every .mid link out of the page and normalize it to a
  # root-relative path (collapse any leading slashes to exactly one).
  for midi in $(curl -s "$page" | grep -E -i -o '[^"]+\.mid' | sed 's/^\/*/\//g'); do
    echo "http://www.piano-e-competition.com$midi"
  done
done | wget -P "$dir" -i -

# Drop wget's numbered duplicate downloads (file.mid.1, file.mid.2, ...).
rm -f "$dir"/*.{1,2,3,4,5}

dataset/scripts/touhou_scraper.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
# Scraper for official Touhou MIDI files listed on thwiki.cc.
# Usage: bash <script> OUTPUT_DIR

# Require the output directory and fail with a non-zero status
# (the original exited 0 and printed the error to stdout).
if [ -z "${1:-}" ]; then
  echo 'Error: please specify output dir' >&2
  exit 1
fi
dir=$1

# The category page links to one wiki page per MIDI file; each wiki page
# in turn links to the actual .mid download.
for url in $(curl -s https://thwiki.cc/%E5%88%86%E7%B1%BB:%E5%AE%98%E6%96%B9MIDI \
    | grep -E -o '[^"]+?\.mid' \
    | grep -E '^/' \
    | sed 's/^/https:\/\/thwiki.cc/g' \
    | uniq); do
  # Resolve the wiki page to the direct (protocol-relative) file URL.
  url=$(curl -s "$url" \
    | grep -E -o '[^"]+?\.mid' \
    | grep -E '^/' \
    | grep -v '%' \
    | sed 's/^/https:/g')
  # Quote $url: the original's unquoted echo collapsed multiple URLs onto
  # one space-separated line, which breaks wget -i (one URL per line).
  echo "$url" | tee /dev/stderr
done | uniq | wget -P "$dir" -i -

0 commit comments

Comments
 (0)