MNIST
When you learn to program, tradition has it that the first thing you do is print "Hello World." Just as programming has Hello World, machine learning has MNIST.
The MNIST Database
The MNIST database of handwritten digits, available from this site, has a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 1 * 28 * 28 image (one channel, 28 rows, 28 columns) with a single label from '0' to '9', for 10 labels in total.
You can download the following four files from this site:
- train-images-idx3-ubyte.gz: training set images (9912422 bytes)
- train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
- t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
- t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
We usually store the data as follows:
- train-images-idx3-ubyte.gz: examples 1 ~ 50,000 go in the X_train variable, which has 50,000 * 1 * 28 * 28 dimensions, and examples 50,001 ~ 60,000 go in the X_val variable, which has 10,000 * 1 * 28 * 28 dimensions
- train-labels-idx1-ubyte.gz: examples 1 ~ 50,000 go in the y_train variable, which has 50,000 * 1 dimensions, and examples 50,001 ~ 60,000 go in the y_val variable, which has 10,000 * 1 dimensions
- t10k-images-idx3-ubyte.gz: all examples go in the X_test variable, which has 10,000 * 1 * 28 * 28 dimensions
- t10k-labels-idx1-ubyte.gz: all examples go in the y_test variable, which has 10,000 * 1 dimensions
How to download
Using Python, we can download the files with the urllib package and decompress them with the gzip package:
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2
import gzip
import numpy as np

source = 'http://yann.lecun.com/exdb/mnist/'
filename = 'train-images-idx3-ubyte.gz'
urlretrieve(source + filename, filename)
# Decompress the file and read the pixels as unsigned bytes,
# skipping the 16-byte header at the start of the image file
with gzip.open(filename, 'rb') as f:
    data = np.frombuffer(f.read(), np.uint8, offset=16)
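The offset of 16 skips the header of an image file. In the IDX format used by these files, an image file starts with four big-endian 32-bit integers (magic number 2051, image count, rows, columns), while a label file starts with only two (magic number 2049, label count), so labels need an offset of 8 instead. The sketch below shows one way to read both kinds of file from already decompressed bytes. The names parse_idx_images and parse_idx_labels are hypothetical helpers, Python 3 is assumed, and a tiny synthetic buffer stands in for the real download.

```python
import struct
import numpy as np

def parse_idx_images(raw):
    """Parse decompressed IDX3 image bytes into (examples, 1, rows, cols)."""
    # Header: magic number, image count, rows, columns (big-endian uint32)
    magic, n, rows, cols = struct.unpack('>IIII', raw[:16])
    if magic != 2051:
        raise ValueError('not an IDX3 image file: magic=%d' % magic)
    pixels = np.frombuffer(raw, np.uint8, offset=16)
    return pixels.reshape(n, 1, rows, cols)

def parse_idx_labels(raw):
    """Parse decompressed IDX1 label bytes into (examples, 1)."""
    # Header: magic number, label count (big-endian uint32)
    magic, n = struct.unpack('>II', raw[:8])
    if magic != 2049:
        raise ValueError('not an IDX1 label file: magic=%d' % magic)
    labels = np.frombuffer(raw, np.uint8, offset=8)
    return labels.reshape(n, 1)

# Synthetic stand-in for real data: two blank 28x28 images with labels 3 and 7
fake_images = struct.pack('>IIII', 2051, 2, 28, 28) + bytes(2 * 28 * 28)
fake_labels = struct.pack('>II', 2049, 2) + bytes([3, 7])

X = parse_idx_images(fake_images)
y = parse_idx_labels(fake_labels)
print(X.shape, y.shape)  # (2, 1, 28, 28) (2, 1)
```

With a real file, the raw bytes would come from f.read() inside the gzip.open block. Checking the magic number first catches the easy mistake of passing a label file to the image parser.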
Because each file is read as a flat vector of bytes, we can reshape it into monochrome 2D images (e.g. the training images), following the shape convention (examples, channels, rows, columns):
data = data.reshape(-1, 1, 28, 28)
# -1 lets NumPy infer the number of examples from the array size
X_train = data[:50000]  # i.e., X_train = data[:-10000]
X_val = data[50000:]  # i.e., X_val = data[-10000:]
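The labels need the same split so that every image stays paired with its label. Here is a minimal sketch, assuming the images and labels are already loaded as NumPy arrays. The helper split_train_val is hypothetical, and zero-filled placeholder arrays stand in for the real data.

```python
import numpy as np

def split_train_val(X, y, n_val=10000):
    """Hold out the last n_val examples as a validation set."""
    if len(X) != len(y):
        raise ValueError('images and labels differ in length')
    if not 0 < n_val < len(X):
        raise ValueError('n_val must be between 1 and len(X) - 1')
    return X[:-n_val], y[:-n_val], X[-n_val:], y[-n_val:]

# Placeholder arrays with the shapes of the MNIST training files
X = np.zeros((60000, 1, 28, 28), dtype=np.uint8)
y = np.zeros((60000, 1), dtype=np.uint8)

X_train, y_train, X_val, y_val = split_train_val(X, y)
print(X_train.shape, X_val.shape)  # (50000, 1, 28, 28) (10000, 1, 28, 28)
print(y_train.shape, y_val.shape)  # (50000, 1) (10000, 1)
```

Slicing from the end with a negative index keeps the validation size fixed at n_val regardless of how many examples the arrays hold.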