If you plan to use PyTorch to load data when training a model, you should wrap your data in a Dataset class. In this tutorial, we will show you how to create a custom Dataset class for data loading.
How to create a custom Dataset?
Here is the basic structure of a custom Dataset class.
from torch.utils.data import dataset

class CustomDataset(dataset.Dataset):
    def __init__(self):
        super(CustomDataset, self).__init__()

    def __getitem__(self, index):
        return None

    def __len__(self):
        return 0
Here CustomDataset is our custom Dataset class; it is a child class of PyTorch's dataset.Dataset.
We should override three basic methods: __init__(), __getitem__() and __len__().
Here:
__init__(): we load or prepare all data here.
__getitem__(): it returns one row (sample) by index.
__len__(): it returns how many rows are in our dataset.
For example:
class CustomDataset(dataset.Dataset):
    def __init__(self):
        super(CustomDataset, self).__init__()
        # load all data for training or test
        self.all_data = [i for i in range(0, 100)]

    def __getitem__(self, index):
        # return one sample: an input value and its target (2x)
        return self.all_data[index], 2 * self.all_data[index]

    def __len__(self):
        return len(self.all_data)
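Before wiring this dataset into a data loader, a quick sanity check (a minimal sketch using only the class above) confirms that __len__() and __getitem__() behave as expected:

dataset = CustomDataset()
print(len(dataset))   # 100, returned by __len__()
print(dataset[5])     # (5, 10), returned by __getitem__(5)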
Then we can start to get batch samples from this dataset.
How to load batch data from custom Dataset?
We can use PyTorch's dataloader.DataLoader(). It is defined as:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
It allows us to load batch data from a dataset easily.
For example:
from torch.utils.data import dataloader

train_dataset = CustomDataset()
train_loader = dataloader.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True
)
print(train_loader)

for i_batch, batch_data in enumerate(train_loader):
    print(i_batch, batch_data, type(batch_data))
    print("batch end")
Run this code; we will see:
<torch.utils.data.dataloader.DataLoader object at 0x000002AB31EA9B70>
0 [tensor([85, 62, 12, 35, 67, 52, 60, 0]), tensor([170, 124, 24, 70, 134, 104, 120, 0])] <class 'list'>
batch end
1 [tensor([ 7, 21, 99, 41, 32, 23, 82, 45]), tensor([ 14, 42, 198, 82, 64, 46, 164, 90])] <class 'list'>
batch end
2 [tensor([34, 1, 36, 43, 78, 10, 56, 98]), tensor([ 68, 2, 72, 86, 156, 20, 112, 196])] <class 'list'>
batch end
3 [tensor([ 8, 92, 46, 44, 37, 33, 91, 19]), tensor([ 16, 184, 92, 88, 74, 66, 182, 38])] <class 'list'>
batch end
4 [tensor([ 6, 65, 81, 47, 17, 9, 29, 39]), tensor([ 12, 130, 162, 94, 34, 18, 58, 78])] <class 'list'>
batch end
5 [tensor([24, 30, 27, 28, 18, 4, 40, 51]), tensor([ 48, 60, 54, 56, 36, 8, 80, 102])] <class 'list'>
batch end
6 [tensor([16, 57, 93, 54, 22, 48, 71, 38]), tensor([ 32, 114, 186, 108, 44, 96, 142, 76])] <class 'list'>
batch end
7 [tensor([50, 58, 20, 59, 88, 55, 69, 25]), tensor([100, 116, 40, 118, 176, 110, 138, 50])] <class 'list'>
batch end
8 [tensor([72, 76, 90, 73, 53, 42, 63, 70]), tensor([144, 152, 180, 146, 106, 84, 126, 140])] <class 'list'>
batch end
9 [tensor([83, 96, 66, 75, 5, 77, 49, 61]), tensor([166, 192, 132, 150, 10, 154, 98, 122])] <class 'list'>
batch end
10 [tensor([94, 79, 68, 26, 31, 2, 74, 14]), tensor([188, 158, 136, 52, 62, 4, 148, 28])] <class 'list'>
batch end
11 [tensor([89, 97, 64, 11, 15, 84, 13, 3]), tensor([178, 194, 128, 22, 30, 168, 26, 6])] <class 'list'>
batch end
12 [tensor([95, 80, 87, 86]), tensor([190, 160, 174, 172])] <class 'list'>
batch end
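Notice that the last batch (index 12) only contains 4 samples, because 100 is not divisible by the batch size 8. If an incomplete final batch is unwanted, the drop_last parameter from the signature above can discard it; a minimal sketch:

train_loader = dataloader.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True  # discard the final incomplete batch of 4 samples
)
# now len(train_loader) == 12 full batches instead of 13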
From the output above, we can find:
Each batch is saved in a python list, and each element in it is a pytorch tensor: the first tensor holds the inputs and the second holds the targets, because __getitem__() returns two values per sample.
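Since each batch is a two-element list, we can unpack it directly in the loop. Here is a minimal sketch of how these batches are typically consumed:

for inputs, targets in train_loader:
    # inputs and targets are both tensors with shape [batch_size]
    # (the last batch may be smaller unless drop_last=True)
    print(inputs.shape, targets.shape)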