MNIST
When one learns how to program, there's a tradition that the first thing you do is print "Hello World." Just like programming has Hello World, machine learning has MNIST.
The MNIST Database
The MNIST database of handwritten digits, available from Yann LeCun's website (http://yann.lecun.com/exdb/mnist/), has a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28 × 28-pixel grayscale image (a single channel, so 1 * 28 * 28 values) with one label from '0' to '9', for 10 classes in total.
You can download the four gzip-compressed files from the same site:
- train-images-idx3-ubyte.gz: training set images (9912422 bytes)
- train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
- t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
- t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
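The suffixes idx3-ubyte and idx1-ubyte refer to the IDX file format. After decompression, each file starts with a short header (two zero bytes, a data-type code where 0x08 means unsigned byte, the number of dimensions, and then each dimension size as a big-endian 32-bit integer), followed by the raw data. A minimal sketch for inspecting that header with the standard `struct` module; `parse_idx_header` is a hypothetical helper name, not part of any library:

```python
import struct


def parse_idx_header(raw):
    """Return (type_code, shape) parsed from the start of an IDX file.

    Hypothetical helper. Header layout: two zero bytes, a data-type code
    (0x08 = unsigned byte), the number of dimensions, then one big-endian
    unsigned 32-bit integer per dimension.
    """
    zero1, zero2, type_code, ndim = struct.unpack('>BBBB', raw[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError('not an IDX file: first two bytes must be zero')
    # One big-endian uint32 per dimension follows the 4-byte prefix
    shape = struct.unpack('>' + 'I' * ndim, raw[4:4 + 4 * ndim])
    return type_code, shape
```

For train-images-idx3-ubyte.gz, the first 16 decompressed bytes should parse to type code 8 and shape (60000, 28, 28). The header occupies 4 + 4 * ndim bytes: 16 for the image files and 8 for the label files, which is how many bytes to skip before the pixel or label data begins.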
We usually hold out the last 10,000 training examples as a validation set and store the data in six variables:
- `X_train`: examples 1 ~ 50,000 of train-images-idx3-ubyte.gz, with shape 50,000 * 1 * 28 * 28
- `X_val`: examples 50,001 ~ 60,000 of train-images-idx3-ubyte.gz, with shape 10,000 * 1 * 28 * 28
- `y_train`: examples 1 ~ 50,000 of train-labels-idx1-ubyte.gz, one label per example (50,000 labels)
- `y_val`: examples 50,001 ~ 60,000 of train-labels-idx1-ubyte.gz, one label per example (10,000 labels)
- `X_test`: all of t10k-images-idx3-ubyte.gz, with shape 10,000 * 1 * 28 * 28
- `y_test`: all of t10k-labels-idx1-ubyte.gz, one label per example (10,000 labels)
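The train/validation split amounts to slicing off the last 10,000 training examples. A minimal sketch, assuming the images and labels have already been loaded as NumPy arrays with the same number of examples; `split_train_val` and `n_val` are hypothetical names:

```python
import numpy as np


def split_train_val(X, y, n_val=10000):
    """Hold out the last n_val examples of (X, y) as a validation set.

    Hypothetical helper; X and y must contain the same number of examples.
    """
    if len(X) != len(y):
        raise ValueError('X and y must contain the same number of examples')
    n_train = len(X) - n_val  # 60,000 - 10,000 = 50,000 for MNIST
    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]
    return X_train, y_train, X_val, y_val


if __name__ == '__main__':
    # Toy stand-in for MNIST: 60 blank images with labels 0..59
    X = np.zeros((60, 1, 28, 28), dtype=np.uint8)
    y = np.arange(60)
    X_train, y_train, X_val, y_val = split_train_val(X, y, n_val=10)
    print(X_train.shape, X_val.shape)  # (50, 1, 28, 28) (10, 1, 28, 28)
```

On the full training set, `split_train_val(X, y)` gives 50,000 training and 10,000 validation examples. Computing `n_train` explicitly avoids the edge case where a slice such as `X[:-n_val]` silently returns an empty array when `n_val` is 0.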
How to download
In Python, we can download the files with `urlretrieve` from `urllib`, decompress them with `gzip`, and decode the bytes with NumPy:
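```python
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2
import gzip
import os
import numpy as np

source = 'http://yann.lecun.com/exdb/mnist/'

def load_mnist_images(filename):
    # Download the file unless a local copy already exists
    if not os.path.exists(filename):
        urlretrieve(source + filename, filename)
    with gzip.open(filename, 'rb') as f:
        # Skip the 16-byte IDX header (magic number, count, rows, columns)
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    return data

data = load_mnist_images('train-images-idx3-ubyte.gz')
```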
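The label files can be read the same way, except that their IDX header is only 8 bytes (a magic number of 2049 and the item count). A minimal sketch, assuming the .gz file has already been downloaded, for example with `urlretrieve` as in the snippet above; `load_mnist_labels` is a hypothetical helper name:

```python
import gzip
import struct

import numpy as np


def load_mnist_labels(filename):
    """Read an MNIST label file (IDX1 format) into a 1D uint8 array.

    Hypothetical helper. `filename` may be a path or a binary file object,
    since gzip.open accepts either.
    """
    with gzip.open(filename, 'rb') as f:
        raw = f.read()
    # 8-byte header: big-endian magic number (2049 for labels) and item count
    magic, count = struct.unpack('>II', raw[:8])
    if magic != 2049:
        raise ValueError('not an IDX1 label file (magic=%d)' % magic)
    labels = np.frombuffer(raw, np.uint8, offset=8)
    if labels.size != count:
        raise ValueError('expected %d labels, found %d' % (count, labels.size))
    return labels
```

Called on train-labels-idx1-ubyte.gz, this should return 60,000 labels, which can then be split into `y_train` and `y_val` in the same way as the images.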
Because `np.frombuffer` returns the training images as one flat vector of pixel values, we reshape it into monochrome 2D images following the shape convention (examples, channels, rows, columns), and then hold out the last 10,000 examples for validation:
```python
# -1 lets NumPy infer the number of examples from the array length
data = data.reshape(-1, 1, 28, 28)
X_train = data[:50000]  # i.e., X_train = data[:-10000]
X_val = data[50000:]    # i.e., X_val = data[-10000:]
```
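Note that `reshape` returns a new array (usually a view of the same memory) rather than changing `data` in place, so its result has to be assigned. A small sketch that makes the (examples, channels, rows, columns) convention explicit; `to_images` is a hypothetical helper name:

```python
import numpy as np


def to_images(flat_pixels, channels=1, rows=28, cols=28):
    """Reshape a flat pixel buffer to (examples, channels, rows, columns).

    Hypothetical helper. The -1 lets NumPy infer the number of examples;
    reshape raises ValueError if the buffer length is not a multiple of
    channels * rows * cols.
    """
    return flat_pixels.reshape(-1, channels, rows, cols)


if __name__ == '__main__':
    # Toy buffer holding 3 images' worth of pixels (3 * 784 values)
    flat = np.zeros(3 * 28 * 28, dtype=np.uint8)
    images = to_images(flat)
    print(images.shape)  # (3, 1, 28, 28)
    print(flat.shape)    # (2352,), the original array keeps its shape
```

Applied to the 60,000 decoded training images, `to_images(data)` gives shape (60000, 1, 28, 28), ready to be sliced into `X_train` and `X_val`.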