Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,20 @@ def binary_to_net(weights, spm_stream, ind_stream, codebook, num_nz):

# Recover from binary stream
spm = np.zeros(num_nz, np.uint8)
ind = np.zeros(num_nz, np.uint8)
ind = np.zeros(num_nz if num_nz % 2 == 0 else num_nz + 1, np.uint8)
if slots == 2:
spm[np.arange(0, num_nz, 2)] = spm_stream % (2**4)
spm[np.arange(1, num_nz, 2)] = spm_stream / (2**4)
else:
spm = spm_stream
ind[np.arange(0, num_nz, 2)] = ind_stream% (2**4)
ind[np.arange(1, num_nz, 2)] = ind_stream/ (2**4)
ind[np.arange(1, num_nz if num_nz % 2 == 0 else num_nz + 1, 2)] = ind_stream/ (2**4)


# Recover the matrix
ind = np.cumsum(ind+1)-1
if num_nz % 2 == 1:
ind = ind[:-1]
code[ind] = spm
data = np.reshape(codebook[code], weights.shape)
np.copyto(weights, data)
Expand All @@ -112,12 +114,21 @@ def binary_to_net(weights, spm_stream, ind_stream, codebook, num_nz):
bits = 4
codebook_size = 2 ** bits
codebook = np.fromfile(fin, dtype = np.float32, count = codebook_size)
bias = np.fromfile(fin, dtype = np.float32, count = net.params[layer][1].data.size)
np.copyto(net.params[layer][1].data, bias)

# we can't access "net.params[layer][index]" directly
# it is neccesary to use __iter__ or either a loop:
for num, param in enumerate(net.params[layer]):
if num == 0:
net_data = param.data
else:
bias_data = param.data

bias = np.fromfile(fin, dtype = np.float32, count = bias_data.size)
np.copyto(bias_data, bias)

spm_stream = np.fromfile(fin, dtype = np.uint8, count = (nz_num[idx]-1) / (8/bits) + 1)
ind_stream = np.fromfile(fin, dtype = np.uint8, count = (nz_num[idx]-1) / 2+1)

binary_to_net(net.params[layer][0].data, spm_stream, ind_stream, codebook, nz_num[idx])
binary_to_net(net_data, spm_stream, ind_stream, codebook, nz_num[idx])

net.save(target)