In this tutorial, we will introduce how to load the MNIST dataset for training with PyTorch. It is useful for PyTorch beginners.
Preliminary
We can use torchvision to load the MNIST dataset in PyTorch. It already provides an MNIST class for loading the data:
class torchvision.datasets.MNIST(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
Here:
root: string – Root directory of dataset where MNIST/raw/train-images-idx3-ubyte and MNIST/raw/t10k-images-idx3-ubyte exist.
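train: if True, creates the dataset from train-images-idx3-ubyte (60,000 training images); otherwise from t10k-images-idx3-ubyte (10,000 test images).
transform and target_transform: optional callables applied to the image and to the target (label), respectively, before a sample is returned.
download: if True, downloads the dataset from the internet and puts it in root. If the dataset is already downloaded, it is not downloaded again.
The __getitem__ method of the MNIST class shows how these two callables are used: it takes the image and label at the given index, converts the image to a PIL image, applies transform to the image and target_transform to the label (each only if it is not None), and returns the pair.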
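The transform / target_transform pattern can be sketched with a hypothetical, simplified dataset class. ToyDigits, scale and one_hot below are illustrative names, not part of torchvision; the class keeps only the "apply the callable if it is not None" logic and skips files, PIL images and downloading, so it runs without any data.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class ToyDigits:
    """Hypothetical, simplified stand-in for torchvision.datasets.MNIST.

    It keeps only the part of __getitem__ that applies the optional
    transform / target_transform callables (no files, no PIL, no download).
    """

    def __init__(
        self,
        images: Sequence[Any],
        targets: Sequence[int],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.images = images
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = self.images[index], int(self.targets[index])
        # Each callable is applied only if it was given.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


def scale(img: List[List[int]]) -> List[List[float]]:
    """Hypothetical transform: 0..255 pixel values -> floats in [0, 1]."""
    return [[p / 255 for p in row] for row in img]


def one_hot(label: int, num_classes: int = 10) -> List[int]:
    """Hypothetical target_transform: class index -> one-hot list."""
    return [1 if i == label else 0 for i in range(num_classes)]


# One 2x2 "image" with label 3.
ds = ToyDigits([[[0, 255], [51, 0]]], [3], transform=scale, target_transform=one_hot)
img, target = ds[0]
print(img)     # [[0.0, 1.0], [0.2, 0.0]]
print(target)  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```

With the real MNIST class you would pass callables such as transforms.ToTensor() in the same way; leaving both arguments as None returns the raw image and the int label unchanged.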
In PyTorch, images are usually converted to tensors with a transform, while the integer labels are turned into tensors later, when batches are built.
Use torchvision to load MNIST data
First, we import the required modules and create the training and test datasets.
Here is the example code:
```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_dt = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dt = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
print(type(train_dt))
print(type(test_dt))
```
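Run this code. On the first run, the files are downloaded into the data directory, and both print statements output <class 'torchvision.datasets.mnist.MNIST'>.
Here we use transforms.ToTensor() to convert each image into a float tensor of shape [1, 28, 28] with pixel values scaled to [0, 1]. Because target_transform is not set, each label is returned as a plain Python int; the DataLoader turns the labels into a tensor when it builds a batch.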
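Then we can find a data directory. The downloaded files are stored under data/MNIST/raw, including train-images-idx3-ubyte, train-labels-idx1-ubyte, t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte.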
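The files under data/MNIST/raw use the simple IDX binary format: two zero bytes, a one-byte type code (0x08 means unsigned byte), a one-byte number of dimensions, then one big-endian 32-bit size per dimension, followed by the raw data. A minimal header reader, as a sketch (read_idx_header is a hypothetical helper, not a torchvision function; torchvision parses these files for you):

```python
import struct
from typing import Tuple

IDX_UBYTE = 0x08  # IDX type code for unsigned bytes, used by MNIST


def read_idx_header(raw: bytes) -> Tuple[int, Tuple[int, ...]]:
    """Parse the header of an IDX file and return (type_code, dims)."""
    if len(raw) < 4:
        raise ValueError("too short for an IDX header")
    zero_a, zero_b, type_code, ndim = struct.unpack(">BBBB", raw[:4])
    if zero_a != 0 or zero_b != 0:
        raise ValueError("not an IDX file: first two bytes must be zero")
    end = 4 + 4 * ndim
    if len(raw) < end:
        raise ValueError("truncated IDX header")
    # One big-endian unsigned 32-bit size per dimension.
    dims = struct.unpack(">" + "I" * ndim, raw[4:end])
    return type_code, dims


# Synthetic header shaped like the MNIST training images file.
header = struct.pack(">I", 0x00000803) + struct.pack(">III", 60000, 28, 28)
print(read_idx_header(header))  # (8, (60000, 28, 28))

# On a real download you could instead read the first 16 bytes:
# with open("data/MNIST/raw/train-images-idx3-ubyte", "rb") as f:
#     print(read_idx_header(f.read(16)))
```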
Next, we can iterate over all images.
Read MNIST images and labels in batches
We create a DataLoader to read the training set in batches; the same approach works for the test set.
Here is an example:
```python
batch_size = 32
gen_train = DataLoader(dataset=train_dt, batch_size=batch_size, shuffle=True, num_workers=0)

# Each batch is a list [images, labels]
for iteration, batch in enumerate(gen_train):
    print(iteration, type(batch))
    print(len(batch))
    print(batch[0].shape, batch[1].shape)
```
This code iterates over the training set in batches of 32 images (batch_size = 32). With shuffle=True, the data is reshuffled at every epoch.
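Running this code prints output like the following (only the first three iterations are shown):
```
0 <class 'list'>
2
torch.Size([32, 1, 28, 28]) torch.Size([32])
1 <class 'list'>
2
torch.Size([32, 1, 28, 28]) torch.Size([32])
2 <class 'list'>
2
torch.Size([32, 1, 28, 28]) torch.Size([32])
```
Each batch is a list of two tensors: the images, with shape [32, 1, 28, 28] (batch, channel, height, width), and the labels, with shape [32].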
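With the full training set (60,000 images) and batch_size = 32, the loop above runs 60000 / 32 = 1875 times. For the 10,000-image test set the same settings give 313 batches, the last one holding only 16 images (DataLoader's drop_last=True skips such a partial batch). The sketch below checks this batching behaviour with a hypothetical FakeMNIST dataset that returns the same kinds of samples as the real one (a [1, 28, 28] float tensor and an int label), so nothing is downloaded.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeMNIST(Dataset):
    """Hypothetical stand-in for datasets.MNIST(..., transform=transforms.ToTensor()).

    Each sample is (float tensor of shape [1, 28, 28], int label), the same
    kinds of objects the real dataset returns, so nothing is downloaded.
    """

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int):
        return torch.zeros(1, 28, 28), index % 10


loader = DataLoader(FakeMNIST(100), batch_size=32, shuffle=True, num_workers=0)

# Number of batches = ceil(100 / 32) = 4; the last one holds the 4 leftovers.
batch_sizes = [images.shape[0] for images, labels in loader]
print(len(loader), batch_sizes)  # 4 [32, 32, 32, 4]

# The default collate function stacks the images and turns the int labels into a tensor.
batch = next(iter(loader))
print(type(batch), len(batch))         # <class 'list'> 2
print(batch[0].shape, batch[1].dtype)  # torch.Size([32, 1, 28, 28]) torch.int64
```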
Now we can use the training set to train a PyTorch model.
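A minimal training loop over such batches might look like the sketch below. It assumes a deliberately tiny linear classifier, cross-entropy loss and plain SGD (all illustrative choices, not a recommended MNIST architecture), and uses random tensors with MNIST's shapes so it runs without the download; for real training, replace train_loader with gen_train.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Random stand-in data with MNIST's shapes; in practice use gen_train instead.
fake_images = torch.rand(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=32, shuffle=True)

# A deliberately tiny model: flatten 1x28x28 -> 784 features -> 10 class scores.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_one_epoch(model, loader, loss_fn, optimizer) -> float:
    """Run one pass over loader and return the average loss per sample."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        optimizer.zero_grad()
        logits = model(images)          # shape [batch, 10]
        loss = loss_fn(logits, labels)  # labels: int64 class indices
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / total_samples


epoch_losses = []
for epoch in range(2):
    avg_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
    epoch_losses.append(avg_loss)
    print(f"epoch {epoch}: average loss {avg_loss:.4f}")
```

The loss is weighted by the batch size so that a smaller final batch does not skew the epoch average.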