-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
87 lines (66 loc) · 2.74 KB
/
__init__.py
File metadata and controls
87 lines (66 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Download the MNIST dataset files into a local data directory and parse them into NumPy arrays.
import os
import functools
import operator
import gzip
import struct
import array
import numpy as np
from urllib.request import urlretrieve
import sys
# Local directory where downloaded MNIST files are cached (a relative path,
# so it is resolved against the current working directory).
DATASET_DIRECTORY = "data/"
# Base URL; each MNIST file name is appended to it to form the download URL.
# NOTE(review): this host has been reported to refuse automated downloads
# (HTTP 403); a mirror may be required — confirm before relying on it.
URL = "http://yann.lecun.com/exdb/mnist/"
class IdxDecodeError(ValueError):
    """Raised when a stream cannot be decoded as a valid IDX file."""


def parse_idx(fd):
    """Parse an IDX-format binary stream into a NumPy array.

    IDX layout: two zero bytes, a one-byte data-type code, a one-byte
    dimension count ``n``, then ``n`` big-endian unsigned 32-bit dimension
    sizes, followed by the items themselves stored big-endian.

    Args:
        fd: A binary file-like object positioned at the start of the IDX
            data (e.g. a file opened in ``'rb'`` mode or via ``gzip.open``).

    Returns:
        numpy.ndarray: The items reshaped to the dimension sizes read from
        the header.

    Raises:
        IdxDecodeError: If the header or dimension sizes are truncated, the
            magic zero bytes are missing, the data-type code is unknown, or
            the item data does not match the declared dimensions.
    """
    # IDX data-type codes -> array.array type codes.
    data_types = {0x08: 'B',  # unsigned byte
                  0x09: 'b',  # signed byte
                  0x0b: 'h',  # signed short
                  0x0c: 'i',  # signed int
                  0x0d: 'f',  # float
                  0x0e: 'd'}  # double

    header = fd.read(4)
    if len(header) != 4:
        raise IdxDecodeError('Invalid IDX file, file empty or does not contain a full header.')
    zeros, data_type, num_dimensions = struct.unpack('>HBB', header)
    if zeros != 0:
        raise IdxDecodeError('Invalid IDX file, file must start with two zero bytes. '
                             'Found 0x%02x' % zeros)
    try:
        type_code = data_types[data_type]
    except KeyError:
        raise IdxDecodeError('Unknown data type 0x%02x in IDX file' % data_type) from None

    size_bytes = fd.read(4 * num_dimensions)
    if len(size_bytes) != 4 * num_dimensions:
        raise IdxDecodeError('Invalid IDX file, dimension sizes are truncated.')
    dimension_sizes = struct.unpack('>' + 'I' * num_dimensions, size_bytes)

    try:
        data = array.array(type_code, fd.read())
    except ValueError:
        # array.array rejects a byte string whose length is not a multiple
        # of the item size.
        raise IdxDecodeError('IDX data length is not a multiple of the item size.') from None
    # IDX items are big-endian while array.array uses the host's native
    # order, so a swap is needed only on little-endian machines.
    if sys.byteorder == 'little':
        data.byteswap()

    # Initial value 1 so a zero-dimensional file (a single item) is valid.
    expected_items = functools.reduce(operator.mul, dimension_sizes, 1)
    if len(data) != expected_items:
        raise IdxDecodeError('IDX file has wrong number of items. '
                             'Expected: %d. Found: %d' % (expected_items, len(data)))
    return np.array(data).reshape(dimension_sizes)
def print_download_progress(count, block_size, total_size):
    """Write download progress to stdout on a single, overwritten line.

    Designed for use as the ``reporthook`` of ``urllib.request.urlretrieve``.

    Args:
        count: Number of blocks transferred so far.
        block_size: Size of each block in bytes.
        total_size: Total download size in bytes, or a non-positive value
            when the server did not report one.
    """
    downloaded = count * block_size
    if total_size > 0:
        # The last block may overshoot the total, so clamp at 100%.
        pct_complete = min(int(downloaded * 100 / total_size), 100)
        msg = "\r- Download progress: %d" % pct_complete + "%"
    else:
        # Unknown size (urlretrieve reports -1): a percentage would be
        # negative or divide by zero, so report the byte count instead.
        msg = "\r- Downloaded: %d bytes" % downloaded
    sys.stdout.write(msg)
    sys.stdout.flush()
def download_and_parse_mnist_file(fname, target_dir=None, force=False):
    """Download an MNIST IDX file if it is not cached, then parse it.

    Args:
        fname: File name on the MNIST server (e.g.
            ``'train-images-idx3-ubyte.gz'``); also used as the local name.
        target_dir: Directory in which to cache the file. ``None`` (the
            default) uses ``DATASET_DIRECTORY``.
        force: If true, download the file even when a cached copy exists.

    Returns:
        The array returned by ``parse_idx`` for the file's contents.

    Raises:
        Errors from the download (e.g. ``urllib.error.URLError``) and from
        ``parse_idx`` (e.g. ``IdxDecodeError`` for malformed data) propagate.
    """
    if target_dir is None:
        target_dir = DATASET_DIRECTORY
    if target_dir:
        # exist_ok avoids a race if another process creates the directory.
        os.makedirs(target_dir, exist_ok=True)
    file_path = os.path.join(target_dir, fname)

    if force or not os.path.exists(file_path):
        print('Downloading ' + fname)
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that later
        # calls would mistake for a complete cached copy.
        tmp_path = file_path + '.part'
        try:
            urlretrieve(url=URL + fname, filename=tmp_path,
                        reporthook=print_download_progress)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        os.replace(tmp_path, file_path)
        print("\nDownload finished.")

    # Files ending in .gz are transparently decompressed.
    fopen = gzip.open if os.path.splitext(file_path)[1] == '.gz' else open
    with fopen(file_path, 'rb') as fd:
        return parse_idx(fd)
def train_images():
    """Return the parsed MNIST training-set images.

    Delegates to ``download_and_parse_mnist_file``, which fetches the file
    when no local copy exists and parses its IDX contents.
    """
    archive_name = 'train-images-idx3-ubyte.gz'
    return download_and_parse_mnist_file(archive_name)
def test_images():
    """Return the parsed MNIST test-set images.

    Delegates to ``download_and_parse_mnist_file``, which fetches the file
    when no local copy exists and parses its IDX contents.
    """
    archive_name = 't10k-images-idx3-ubyte.gz'
    return download_and_parse_mnist_file(archive_name)
def train_labels():
    """Return the parsed MNIST training-set labels.

    Delegates to ``download_and_parse_mnist_file``, which fetches the file
    when no local copy exists and parses its IDX contents.
    """
    archive_name = 'train-labels-idx1-ubyte.gz'
    return download_and_parse_mnist_file(archive_name)
def test_labels():
    """Return the parsed MNIST test-set labels.

    Delegates to ``download_and_parse_mnist_file``, which fetches the file
    when no local copy exists and parses its IDX contents.
    """
    archive_name = 't10k-labels-idx1-ubyte.gz'
    return download_and_parse_mnist_file(archive_name)