Skip to content

Commit 668bac6

Browse files
committed
Switch backend to TensorFlow
1 parent ccee4e5 commit 668bac6

File tree

8 files changed

+287
-112
lines changed

8 files changed

+287
-112
lines changed

common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import librosa as lbr
7-
import keras.backend as K
7+
import tensorflow.keras.backend as K
88

99
GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal',
1010
'pop', 'reggae', 'rock']
@@ -20,8 +20,8 @@
2020
def get_layer_output_function(model, layer_name):
2121
input = model.get_layer('input').input
2222
output = model.get_layer(layer_name).output
23-
f = K.function([input, K.learning_phase()], output)
24-
return lambda x: f([x, 0]) # learning_phase = 0 means test
23+
f = K.function([input, K.learning_phase()], [output])
24+
return lambda x: f([x, 0])[0] # learning_phase = 0 means test
2525

2626
def load_track(filename, enforce_shape=None):
2727
new_input, sample_rate = lbr.load(filename, mono=True)

create_data_pickle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import numpy as np
44
from math import pi
5-
from cPickle import dump
5+
from pickle import dump
66
import os
77
from optparse import OptionParser
88

@@ -30,10 +30,10 @@ def collect_data(dataset_path):
3030
track_paths = {}
3131

3232
for (genre_index, genre_name) in enumerate(GENRES):
33-
for i in xrange(TRACK_COUNT // len(GENRES)):
33+
for i in range(TRACK_COUNT // len(GENRES)):
3434
file_name = '{}/{}.000{}.au'.format(genre_name,
3535
genre_name, str(i).zfill(2))
36-
print 'Processing', file_name
36+
print('Processing', file_name)
3737
path = os.path.join(dataset_path, file_name)
3838
track_index = genre_index * (TRACK_COUNT // len(GENRES)) + i
3939
x[track_index], _ = load_track(path, default_shape)
@@ -55,5 +55,5 @@ def collect_data(dataset_path):
5555
(x, y, track_paths) = collect_data(options.dataset_path)
5656

5757
data = {'x': x, 'y': y, 'track_paths': track_paths}
58-
with open(options.output_pkl_path, 'w') as f:
58+
with open(options.output_pkl_path, 'wb') as f:
5959
dump(data, f)

extract_filters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from common import get_layer_output_function, WINDOW_SIZE, WINDOW_STRIDE
2-
from keras.models import model_from_yaml
2+
from tensorflow.keras.models import model_from_yaml
33
import librosa as lbr
44
import numpy as np
55
from functools import partial
@@ -56,7 +56,7 @@ def extract_filters(model, data, filters_path, count0):
5656
if not os.path.exists(layer_path):
5757
os.makedirs(layer_path)
5858

59-
print 'Computing outputs for layer', conv_layer_names[layer_index]
59+
print('Computing outputs for layer', conv_layer_names[layer_index])
6060
output = output_fun(x)
6161

6262
# matrices of shape n_tracks x time x n_filters
@@ -70,9 +70,9 @@ def extract_filters(model, data, filters_path, count0):
7070

7171
undoer = conv_layer_undoers[layer_index]
7272

73-
for filter_index in xrange(argmax_over_track.shape[1]):
74-
print 'Processing layer', conv_layer_names[layer_index], \
75-
'filter', filter_index
73+
for filter_index in range(argmax_over_track.shape[1]):
74+
print('Processing layer', conv_layer_names[layer_index], \
75+
'filter', filter_index)
7676

7777
track_indices = argmax_over_track[:, filter_index]
7878
time_indices = argmax_over_time[track_indices, filter_index]

genre_recognizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from common import load_track, get_layer_output_function
22
import numpy as np
3-
from keras.layers import Input
4-
from keras.models import model_from_yaml, Model
5-
from keras import backend as K
3+
from tensorflow.keras.layers import Input
4+
from tensorflow.keras.models import model_from_yaml, Model
5+
from tensorflow.keras import backend as K
66

77
class GenreRecognizer():
88

@@ -11,10 +11,10 @@ def __init__(self, model_path, weights_path):
1111
model = model_from_yaml(f.read())
1212
model.load_weights(weights_path)
1313
self.pred_fun = get_layer_output_function(model, 'output_realtime')
14-
print 'Loaded model.'
14+
print('Loaded model.')
1515

1616
def recognize(self, track_path):
17-
print 'Loading song', track_path
17+
print('Loading song', track_path)
1818
(features, duration) = load_track(track_path)
1919
features = np.reshape(features, (1,) + features.shape)
2020
return (self.pred_fun(features), duration)

0 commit comments

Comments
 (0)