If you plan to use PyTorch to load data for training a model, you should use a Dataset class. In this tutorial, we will show you how to create a custom Dataset class for data loading.
How to create a custom Dataset?
Here is the basic structure of a custom Dataset class.
from torch.utils.data import dataset


class CustomDataset(dataset.Dataset):
    """Minimal skeleton of a custom PyTorch Dataset.

    Every custom dataset overrides three methods:
    __init__() to load/prepare data, __getitem__() to fetch one
    sample by index, and __len__() to report the sample count.
    """

    def __init__(self):
        # Python 3 zero-argument form of super() — equivalent to
        # super(CustomDataset, self).__init__() but less repetitive.
        super().__init__()

    def __getitem__(self, index):
        # Placeholder: no real data yet, so every index yields None.
        return None

    def __len__(self):
        # Placeholder: the skeleton dataset is empty.
        return 0
- from torch.utils.data import dataset
- class CustomDataset(dataset.Dataset):
- def __init__(self):
- super(CustomDataset, self).__init__()
- def __getitem__(self, index):
- return None
- def __len__(self):
- return 0
from torch.utils.data import dataset


class CustomDataset(dataset.Dataset):
    # Bare-bones Dataset subclass: a custom dataset provides
    # __init__, __getitem__ and __len__; this one holds no data yet.

    def __init__(self):
        super(CustomDataset, self).__init__()

    def __getitem__(self, index):
        # No samples loaded — any lookup yields None.
        return None

    def __len__(self):
        # Zero rows in this placeholder dataset.
        return 0
Here CustomDataset is our custom Dataset class; it is a child class of PyTorch's dataset.Dataset.
We should override three basic methods: __init__(), __getitem__() and __len__().
Here:
__getitem__(): returns the row at the given index
__len__(): returns how many rows are in our dataset.
For example:
class CustomDataset(dataset.Dataset):
    """Toy dataset of 100 rows, where row i is the pair (i, 2*i)."""

    def __init__(self):
        super().__init__()
        # Load all data for training or test.
        # list(range(100)) is the idiomatic form of the
        # copy-comprehension [i for i in range(0, 100)].
        self.all_data = list(range(100))

    def __getitem__(self, index):
        # Return one (feature, label) pair for the given row.
        x = self.all_data[index]
        return x, 2 * x

    def __len__(self):
        # Number of rows in the dataset.
        return len(self.all_data)
- class CustomDataset(dataset.Dataset):
- def __init__(self):
- super(CustomDataset, self).__init__()
- # load all data for training or test
- self.all_data = [i for i in range(0, 100)]
- def __getitem__(self, index):
- return self.all_data[index], 2* self.all_data[index]
- def __len__(self):
- return len(self.all_data)
class CustomDataset(dataset.Dataset):
    # Example dataset: row i holds the pair (i, 2*i) for i in 0..99.

    def __init__(self):
        super(CustomDataset, self).__init__()
        # All samples (training or test data) are loaded up front.
        self.all_data = [i for i in range(0, 100)]

    def __getitem__(self, index):
        # Fetch one row and pair it with its doubled value.
        value = self.all_data[index]
        return value, 2 * value

    def __len__(self):
        # Row count of the in-memory data.
        return len(self.all_data)
Then we can start to get batch samples from this dataset.
How to load batch data from custom Dataset?
We can use PyTorch's dataloader.DataLoader(). It is defined as:
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)
- DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)
It allows us to load batch data from a dataset easily.
For example:
from torch.utils.data import dataloader

# Wrap the dataset in a DataLoader that draws shuffled mini-batches of 8.
train_dataset = CustomDataset()
train_loader = dataloader.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
)
print(train_loader)
# batch_data is a list [inputs, targets]; each element is a tensor
# holding up to batch_size items.
for i_batch, batch_data in enumerate(train_loader):
    print(i_batch, batch_data, type(batch_data))
    print("batch end")
- from torch.utils.data import dataloader
- train_dataset = CustomDataset()
- train_loader = dataloader.DataLoader(
- dataset=train_dataset,
- batch_size=8,
- shuffle=True
- )
- print(train_loader)
- for i_batch ,batch_data in enumerate(train_loader):
- print(i_batch, batch_data, type(batch_data))
- print("batch end")
from torch.utils.data import dataloader

train_dataset = CustomDataset()
# Mini-batches of 8 samples, drawn in random order each epoch.
loader_options = {"dataset": train_dataset, "batch_size": 8, "shuffle": True}
train_loader = dataloader.DataLoader(**loader_options)
print(train_loader)
i_batch = 0
for batch_data in train_loader:
    print(i_batch, batch_data, type(batch_data))
    print("batch end")
    i_batch += 1
Run this code, we will see:
<torch.utils.data.dataloader.DataLoader object at 0x000002AB31EA9B70>
0 [tensor([85, 62, 12, 35, 67, 52, 60, 0]), tensor([170, 124, 24, 70, 134, 104, 120, 0])] <class 'list'>
batch end
1 [tensor([ 7, 21, 99, 41, 32, 23, 82, 45]), tensor([ 14, 42, 198, 82, 64, 46, 164, 90])] <class 'list'>
batch end
2 [tensor([34, 1, 36, 43, 78, 10, 56, 98]), tensor([ 68, 2, 72, 86, 156, 20, 112, 196])] <class 'list'>
batch end
3 [tensor([ 8, 92, 46, 44, 37, 33, 91, 19]), tensor([ 16, 184, 92, 88, 74, 66, 182, 38])] <class 'list'>
batch end
4 [tensor([ 6, 65, 81, 47, 17, 9, 29, 39]), tensor([ 12, 130, 162, 94, 34, 18, 58, 78])] <class 'list'>
batch end
5 [tensor([24, 30, 27, 28, 18, 4, 40, 51]), tensor([ 48, 60, 54, 56, 36, 8, 80, 102])] <class 'list'>
batch end
6 [tensor([16, 57, 93, 54, 22, 48, 71, 38]), tensor([ 32, 114, 186, 108, 44, 96, 142, 76])] <class 'list'>
batch end
7 [tensor([50, 58, 20, 59, 88, 55, 69, 25]), tensor([100, 116, 40, 118, 176, 110, 138, 50])] <class 'list'>
batch end
8 [tensor([72, 76, 90, 73, 53, 42, 63, 70]), tensor([144, 152, 180, 146, 106, 84, 126, 140])] <class 'list'>
batch end
9 [tensor([83, 96, 66, 75, 5, 77, 49, 61]), tensor([166, 192, 132, 150, 10, 154, 98, 122])] <class 'list'>
batch end
10 [tensor([94, 79, 68, 26, 31, 2, 74, 14]), tensor([188, 158, 136, 52, 62, 4, 148, 28])] <class 'list'>
batch end
11 [tensor([89, 97, 64, 11, 15, 84, 13, 3]), tensor([178, 194, 128, 22, 30, 168, 26, 6])] <class 'list'>
batch end
12 [tensor([95, 80, 87, 86]), tensor([190, 160, 174, 172])] <class 'list'>
batch end
- <torch.utils.data.dataloader.DataLoader object at 0x000002AB31EA9B70>
- 0 [tensor([85, 62, 12, 35, 67, 52, 60, 0]), tensor([170, 124, 24, 70, 134, 104, 120, 0])] <class 'list'>
- batch end
- 1 [tensor([ 7, 21, 99, 41, 32, 23, 82, 45]), tensor([ 14, 42, 198, 82, 64, 46, 164, 90])] <class 'list'>
- batch end
- 2 [tensor([34, 1, 36, 43, 78, 10, 56, 98]), tensor([ 68, 2, 72, 86, 156, 20, 112, 196])] <class 'list'>
- batch end
- 3 [tensor([ 8, 92, 46, 44, 37, 33, 91, 19]), tensor([ 16, 184, 92, 88, 74, 66, 182, 38])] <class 'list'>
- batch end
- 4 [tensor([ 6, 65, 81, 47, 17, 9, 29, 39]), tensor([ 12, 130, 162, 94, 34, 18, 58, 78])] <class 'list'>
- batch end
- 5 [tensor([24, 30, 27, 28, 18, 4, 40, 51]), tensor([ 48, 60, 54, 56, 36, 8, 80, 102])] <class 'list'>
- batch end
- 6 [tensor([16, 57, 93, 54, 22, 48, 71, 38]), tensor([ 32, 114, 186, 108, 44, 96, 142, 76])] <class 'list'>
- batch end
- 7 [tensor([50, 58, 20, 59, 88, 55, 69, 25]), tensor([100, 116, 40, 118, 176, 110, 138, 50])] <class 'list'>
- batch end
- 8 [tensor([72, 76, 90, 73, 53, 42, 63, 70]), tensor([144, 152, 180, 146, 106, 84, 126, 140])] <class 'list'>
- batch end
- 9 [tensor([83, 96, 66, 75, 5, 77, 49, 61]), tensor([166, 192, 132, 150, 10, 154, 98, 122])] <class 'list'>
- batch end
- 10 [tensor([94, 79, 68, 26, 31, 2, 74, 14]), tensor([188, 158, 136, 52, 62, 4, 148, 28])] <class 'list'>
- batch end
- 11 [tensor([89, 97, 64, 11, 15, 84, 13, 3]), tensor([178, 194, 128, 22, 30, 168, 26, 6])] <class 'list'>
- batch end
- 12 [tensor([95, 80, 87, 86]), tensor([190, 160, 174, 172])] <class 'list'>
- batch end
<torch.utils.data.dataloader.DataLoader object at 0x000002AB31EA9B70>
0 [tensor([85, 62, 12, 35, 67, 52, 60, 0]), tensor([170, 124, 24, 70, 134, 104, 120, 0])] <class 'list'>
batch end
1 [tensor([ 7, 21, 99, 41, 32, 23, 82, 45]), tensor([ 14, 42, 198, 82, 64, 46, 164, 90])] <class 'list'>
batch end
2 [tensor([34, 1, 36, 43, 78, 10, 56, 98]), tensor([ 68, 2, 72, 86, 156, 20, 112, 196])] <class 'list'>
batch end
3 [tensor([ 8, 92, 46, 44, 37, 33, 91, 19]), tensor([ 16, 184, 92, 88, 74, 66, 182, 38])] <class 'list'>
batch end
4 [tensor([ 6, 65, 81, 47, 17, 9, 29, 39]), tensor([ 12, 130, 162, 94, 34, 18, 58, 78])] <class 'list'>
batch end
5 [tensor([24, 30, 27, 28, 18, 4, 40, 51]), tensor([ 48, 60, 54, 56, 36, 8, 80, 102])] <class 'list'>
batch end
6 [tensor([16, 57, 93, 54, 22, 48, 71, 38]), tensor([ 32, 114, 186, 108, 44, 96, 142, 76])] <class 'list'>
batch end
7 [tensor([50, 58, 20, 59, 88, 55, 69, 25]), tensor([100, 116, 40, 118, 176, 110, 138, 50])] <class 'list'>
batch end
8 [tensor([72, 76, 90, 73, 53, 42, 63, 70]), tensor([144, 152, 180, 146, 106, 84, 126, 140])] <class 'list'>
batch end
9 [tensor([83, 96, 66, 75, 5, 77, 49, 61]), tensor([166, 192, 132, 150, 10, 154, 98, 122])] <class 'list'>
batch end
10 [tensor([94, 79, 68, 26, 31, 2, 74, 14]), tensor([188, 158, 136, 52, 62, 4, 148, 28])] <class 'list'>
batch end
11 [tensor([89, 97, 64, 11, 15, 84, 13, 3]), tensor([178, 194, 128, 22, 30, 168, 26, 6])] <class 'list'>
batch end
12 [tensor([95, 80, 87, 86]), tensor([190, 160, 174, 172])] <class 'list'>
batch end
From above, we can find:
Each batch is returned as a Python list, and each element of that list is a PyTorch tensor.