Create a Custom Dataset for Loading Data in PyTorch – PyTorch Tutorial

By | April 15, 2022

If you plan to use PyTorch to load data for training a model, you should wrap your data in a Dataset class. In this tutorial, we will show you how to create a custom Dataset class for loading data.

How to create a custom Dataset?

Here is the basic structure of a custom Dataset class.

from torch.utils.data import dataset

class CustomDataset(dataset.Dataset):

    def __init__(self):
        super(CustomDataset, self).__init__()

    def __getitem__(self, index):
        # return the sample at position `index`
        return None

    def __len__(self):
        # return the total number of samples
        return 0

Here CustomDataset is our custom Dataset class; it is a child class of PyTorch's dataset.Dataset.

We should override three basic methods: __init__(), __getitem__() and __len__().

Here:

__init__(): load or prepare all the data for the dataset

__getitem__(): return one sample by index

__len__(): return how many samples are in our dataset

For example:

class CustomDataset(dataset.Dataset):

    def __init__(self):
        super(CustomDataset, self).__init__()
        # load all data for training or test
        self.all_data = [i for i in range(0, 100)]

    def __getitem__(self, index):
        # each sample is a (feature, label) pair
        return self.all_data[index], 2 * self.all_data[index]

    def __len__(self):
        return len(self.all_data)
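A quick way to sanity-check the class is to index it directly, since __getitem__() and __len__() are ordinary Python methods. The snippet below uses a torch-free stand-in with the same logic (so it runs even without PyTorch installed); with the real class, ds[5] and len(ds) behave the same way.

```python
# Torch-free stand-in for CustomDataset: same __getitem__/__len__ logic.
class PlainDataset:
    def __init__(self):
        # load all data for training or test
        self.all_data = [i for i in range(0, 100)]

    def __getitem__(self, index):
        # each sample is a (feature, label) pair
        return self.all_data[index], 2 * self.all_data[index]

    def __len__(self):
        return len(self.all_data)

ds = PlainDataset()
print(len(ds))   # 100
print(ds[5])     # (5, 10)
```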

Then we can start to get batch samples from this dataset.

How to load batch data from custom Dataset?

We can use PyTorch's torch.utils.data.DataLoader. It is defined as:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

It allows us to load batch data from a dataset easily.
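Conceptually, batch_size and drop_last control how the dataset's indices are chunked into batches. The sketch below reproduces that chunking in plain Python; it is a simplification that ignores shuffling, samplers, workers, and collation, and is only meant to show the arithmetic.

```python
def batch_indices(n, batch_size, drop_last=False):
    """Split indices 0..n-1 into consecutive batches, the way DataLoader
    does (ignoring shuffle/sampler/num_workers)."""
    batches = []
    for start in range(0, n, batch_size):
        batch = list(range(start, min(start + batch_size, n)))
        if drop_last and len(batch) < batch_size:
            break  # drop the final incomplete batch
        batches.append(batch)
    return batches

# 100 samples with batch_size=8 -> 12 full batches plus a final batch of 4
print(len(batch_indices(100, 8)))                  # 13
print(len(batch_indices(100, 8)[-1]))              # 4
print(len(batch_indices(100, 8, drop_last=True)))  # 12
```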

For example:

from torch.utils.data import dataloader

train_dataset = CustomDataset()

train_loader = dataloader.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
)

print(train_loader)

for i_batch, batch_data in enumerate(train_loader):
    print(i_batch, batch_data, type(batch_data))
    print("batch end")

Run this code and we will see:

<torch.utils.data.dataloader.DataLoader object at 0x000002AB31EA9B70>
0 [tensor([85, 62, 12, 35, 67, 52, 60,  0]), tensor([170, 124,  24,  70, 134, 104, 120,   0])] <class 'list'>
batch end
1 [tensor([ 7, 21, 99, 41, 32, 23, 82, 45]), tensor([ 14,  42, 198,  82,  64,  46, 164,  90])] <class 'list'>
batch end
2 [tensor([34,  1, 36, 43, 78, 10, 56, 98]), tensor([ 68,   2,  72,  86, 156,  20, 112, 196])] <class 'list'>
batch end
3 [tensor([ 8, 92, 46, 44, 37, 33, 91, 19]), tensor([ 16, 184,  92,  88,  74,  66, 182,  38])] <class 'list'>
batch end
4 [tensor([ 6, 65, 81, 47, 17,  9, 29, 39]), tensor([ 12, 130, 162,  94,  34,  18,  58,  78])] <class 'list'>
batch end
5 [tensor([24, 30, 27, 28, 18,  4, 40, 51]), tensor([ 48,  60,  54,  56,  36,   8,  80, 102])] <class 'list'>
batch end
6 [tensor([16, 57, 93, 54, 22, 48, 71, 38]), tensor([ 32, 114, 186, 108,  44,  96, 142,  76])] <class 'list'>
batch end
7 [tensor([50, 58, 20, 59, 88, 55, 69, 25]), tensor([100, 116,  40, 118, 176, 110, 138,  50])] <class 'list'>
batch end
8 [tensor([72, 76, 90, 73, 53, 42, 63, 70]), tensor([144, 152, 180, 146, 106,  84, 126, 140])] <class 'list'>
batch end
9 [tensor([83, 96, 66, 75,  5, 77, 49, 61]), tensor([166, 192, 132, 150,  10, 154,  98, 122])] <class 'list'>
batch end
10 [tensor([94, 79, 68, 26, 31,  2, 74, 14]), tensor([188, 158, 136,  52,  62,   4, 148,  28])] <class 'list'>
batch end
11 [tensor([89, 97, 64, 11, 15, 84, 13,  3]), tensor([178, 194, 128,  22,  30, 168,  26,   6])] <class 'list'>
batch end
12 [tensor([95, 80, 87, 86]), tensor([190, 160, 174, 172])] <class 'list'>
batch end

From the output above, we can see:

Each batch is a Python list of two PyTorch tensors: the first tensor holds the batched features and the second holds the batched labels. Because the dataset has 100 samples and batch_size=8, we get 12 full batches plus a final batch of 4 samples.
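The reason each batch is a [features, labels] pair rather than a list of (feature, label) tuples is that the default collate function transposes the samples: it zips the eight (x, y) tuples into one column of x values and one column of y values, then stacks each column into a tensor. Here is a torch-free sketch of that transpose step (real PyTorch produces tensors, this sketch produces plain lists):

```python
# Torch-free sketch of the transpose done by DataLoader's default collate:
# a list of (x, y) samples becomes [all_x, all_y].
def collate_sketch(samples):
    return [list(column) for column in zip(*samples)]

samples = [(85, 170), (62, 124), (12, 24)]  # three (feature, label) pairs
print(collate_sketch(samples))  # [[85, 62, 12], [170, 124, 24]]
```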
