The MNIST dataset is a collection of handwritten digit images that is commonly used in TensorFlow applications. In this tutorial, we introduce this dataset to TensorFlow beginners and show how to use it correctly.
Data in MNIST dataset
When loaded with TensorFlow's input_data module, the MNIST dataset contains three parts:
Train data (mnist.train): It contains 55000 images and labels.
We can use the train data to train our model.
Validation data (mnist.validation): It contains 5000 images and labels.
We can use this data to tune the hyperparameters of our model.
Test data (mnist.test): It contains 10000 images and labels.
We can use the test data to evaluate the performance of our model.
The ratio of the three parts is:
train : validation : test = 55000 : 5000 : 10000 = 11 : 1 : 2
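The ratio can be checked with a few lines of Python. This is only a small sketch of the arithmetic; the dictionary name split_sizes is chosen here for illustration and the counts are the ones listed above.

```python
from functools import reduce
from math import gcd

# Sample counts of the three MNIST splits, as listed above.
split_sizes = {"train": 55000, "validation": 5000, "test": 10000}

# Divide every count by the greatest common divisor to get the ratio.
divisor = reduce(gcd, split_sizes.values())
ratio = {name: size // divisor for name, size in split_sizes.items()}

total = sum(split_sizes.values())
print(ratio)   # {'train': 11, 'validation': 1, 'test': 2}
print(total)   # 70000
```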
How to read the MNIST dataset in TensorFlow?
We can use the input_data.read_data_sets() function to load it. Here is an example:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

# Get and load MNIST data
mnist = input_data.read_data_sets(os.getcwd() + "/MNIST-data/", one_hot=True)
After running this script, we will find the MNIST dataset in the MNIST-data folder. It contains four files.
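To confirm that the download finished, we can check the folder for the four files. The sketch below assumes the loader keeps the archive names of the original MNIST distribution; the helper missing_mnist_files is a hypothetical name introduced here, not part of TensorFlow.

```python
import os

# Assumption: the four archives keep the names of the original MNIST release.
MNIST_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]

def missing_mnist_files(data_dir):
    """Return the expected archive names that are not present in data_dir."""
    return [name for name in MNIST_FILES
            if not os.path.isfile(os.path.join(data_dir, name))]

data_dir = os.path.join(os.getcwd(), "MNIST-data")
# An empty list means all four files were found.
print(missing_mnist_files(data_dir))
```

If the list is not empty, running the loading script again should fetch the missing files.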
Print mnist.train, mnist.validation and mnist.test
print("mnist train data")
print(mnist.train)
print("mnist validation data")
print(mnist.validation)
print("mnist test data")
print(mnist.test)
The result is:
mnist train data
<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x000001F40B10FD68>
mnist validation data
<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x000001F413A5CF28>
mnist test data
<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x000001F413A5CF60>
From the result, we can see that mnist.train, mnist.validation and mnist.test are TensorFlow DataSet objects.
Check the data type and dimension of MNIST train, validation and test images and labels
For the train data, we print some information about the images and labels.
mnist_train_images = mnist.train.images
print("mnist train images")
print(type(mnist_train_images))
print(mnist_train_images.shape)
print(mnist_train_images)

mnist_train_labels = mnist.train.labels
print("mnist train labels")
print(type(mnist_train_labels))
print(mnist_train_labels.shape)
print(mnist_train_labels)
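From the result, we can see:
1. We can use mnist.train.images to get the image data and mnist.train.labels to get the label data.
2. The data type of the images and labels in mnist.train is numpy.ndarray.
3. The shape of the train images is 55000 * 784, which means mnist.train contains 55000 images, each a 28 * 28 pixel image flattened into 784 values. There is one label per image, so there are also 55000 labels.
4. Each image is 1 * 784 and each label is 1 * 10. Because we loaded the data with one_hot=True, a label is a one-hot vector with one position for each of the 10 digits.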
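The following sketch shows how such arrays are typically turned back into 28 * 28 images and digit values. It uses small synthetic NumPy arrays with the same layout instead of the real data, so the variable names and contents are assumptions made for illustration.

```python
import numpy as np

# Synthetic stand-ins with the same layout as mnist.train.images / labels:
# 3 flattened 28x28 images and 3 one-hot labels for the digits 7, 0 and 3.
images = np.zeros((3, 784), dtype=np.float32)
labels = np.eye(10)[[7, 0, 3]]

# Recover the 2-D pixel grid from each 784-value row.
images_2d = images.reshape(-1, 28, 28)
print(images_2d.shape)   # (3, 28, 28)

# The digit is the index of the single 1 in each one-hot row.
digits = np.argmax(labels, axis=1)
print(digits)            # [7 0 3]
```

With the real dataset, the same two calls can be applied to mnist.train.images and mnist.train.labels.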
The validation and test data can be inspected in the same way.
Read mnist train/validation/test batch data
In TensorFlow, we often read data in batches to train, validate or test our model. To read a batch, we can use the next_batch(batch_num) function, where batch_num is the number of samples in the batch.
For example, to read a batch of 64 test samples:
test_images_batch, test_labels_batch = mnist.test.next_batch(64)
print(type(test_images_batch))
print(test_images_batch.shape)
print(type(test_labels_batch))
print(test_labels_batch.shape)
The output is:
<class 'numpy.ndarray'>
(64, 784)
<class 'numpy.ndarray'>
(64, 10)
From the result, we can see that the image batch has shape 64 * 784, which means it contains 64 images. The label batch has shape 64 * 10, which means it contains the 64 matching labels.
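To give an idea of what happens when batches are read repeatedly, here is a minimal sketch of batch reading over shuffled data. ToyDataSet is a hypothetical, simplified stand-in written for this tutorial. It is not TensorFlow's implementation, and it simply drops the samples left over at the end of a pass.

```python
import numpy as np

class ToyDataSet:
    """Simplified stand-in for a DataSet object; not TensorFlow's code."""

    def __init__(self, images, labels, seed=0):
        self.images, self.labels = images, labels
        self._rng = np.random.default_rng(seed)
        self._order = self._rng.permutation(len(images))
        self._pos = 0
        self.epochs_completed = 0

    def next_batch(self, batch_size):
        # Start a new shuffled pass once the current one cannot fill a batch.
        if self._pos + batch_size > len(self.images):
            self.epochs_completed += 1
            self._order = self._rng.permutation(len(self.images))
            self._pos = 0
        idx = self._order[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return self.images[idx], self.labels[idx]

# Synthetic data shaped like the MNIST arrays, scaled down to 200 samples.
toy = ToyDataSet(np.zeros((200, 784), dtype=np.float32), np.zeros((200, 10)))
for _ in range(4):
    batch_images, batch_labels = toy.next_batch(64)
print(batch_images.shape, batch_labels.shape)  # (64, 784) (64, 10)
print(toy.epochs_completed)                    # 1
```

In a training loop, next_batch is usually called once per step, so the model sees a different batch of samples each time.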