In PyTorch, TensorDataset lets us zip a series of tensors together along their first dimension, much like Python's zip() function. In this tutorial, we will use some examples to show you how to use it.
Syntax
TensorDataset is defined as:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
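From __getitem__(), we can see that it returns a tuple containing the element at the given index from each tensor. From __init__(), all tensors must have the same size in the first dimension, and __len__() returns that size.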
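The behaviour of these three methods can be checked with a small sketch. The names features and labels and their values are made up for illustration. The mismatch check relies on the assert shown in the definition above, so it would not fire if Python runs with assertions disabled (for example with -O).

```python
import torch
from torch.utils.data import TensorDataset

# Two illustrative tensors sharing the same first dimension (3 samples).
features = torch.arange(6).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]
labels = torch.tensor([10, 20, 30])

dataset = TensorDataset(features, labels)

# __len__ reports the size of the first dimension.
print(len(dataset))

# __getitem__ returns one tuple holding the index-th slice of each tensor.
first_features, first_label = dataset[1]
print(first_features, first_label)

# Tensors whose first dimensions differ trip the assert in __init__.
try:
    TensorDataset(features, torch.tensor([1, 2]))
    mismatch_rejected = False
except AssertionError as err:
    mismatch_rejected = True
    print(err)
```

Here dataset[1] pairs the second row of features with the second entry of labels, which is the usual way to keep inputs and targets aligned.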
How to use TensorDataset?
For example:
import torch
from torch.utils.data import TensorDataset

x1 = torch.randn(4, 5)
x2 = torch.randn(4, 10)
d = TensorDataset(x1, x2)
print(d)
for e in d:
    print(e)
Running this code, we will see one tuple per row, each holding a row of x1 (5 values) and the matching row of x2 (10 values):
<torch.utils.data.dataset.TensorDataset object at 0x0000021DBA92FB00>
(tensor([ 0.0644, -0.8627, -0.9599, -0.3772, -0.0840]), tensor([ 2.5645, 0.3732, -0.3954, -1.9667, 0.5432, -0.3737, -0.2884, -0.7295, -1.8462, 0.2305]))
(tensor([-0.9827, -1.5631, 0.0772, 0.7499, 0.3318]), tensor([ 0.9658, 1.9472, -0.1003, -1.1146, 0.7413, -1.0945, -0.0801, -0.3975, -0.6289, -0.4536]))
(tensor([-0.1613, -0.3813, -0.2677, -1.0164, -0.4861]), tensor([ 0.0522, 1.7411, 0.2216, -1.0339, 0.2794, 1.4683, -1.1677, -0.4825, -1.6060, -1.4113]))
(tensor([ 0.1819, -0.1794, 0.3319, -0.0702, 0.6290]), tensor([ 0.4300, 1.3721, -0.5497, -0.1086, 0.6109, -0.7664, 1.2882, -1.6521, -0.5760, -0.4642]))
Example 2:
import torch
from torch.utils.data import TensorDataset

x1 = torch.randn(4, 2)
x2 = torch.randn(4, 3)
x3 = torch.randn(4, 2)
d = TensorDataset(x1, x2, x3)
print(d)
for e in d:
    print(e)
print(x1)
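Then, we will see the following (the object line printed by print(d) is omitted). The final print(x1) shows that the first element of each tuple is the corresponding row of x1: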
(tensor([ 0.2332, -0.4129]), tensor([0.0416, 0.0045, 0.9388]), tensor([0.5951, 0.4067]))
(tensor([ 1.0542, -0.9473]), tensor([-1.2757, 0.0499, -0.7282]), tensor([ 1.0244, -0.4466]))
(tensor([-0.4178, 0.9416]), tensor([-1.0980, -0.2778, -0.0483]), tensor([-2.4482, 1.3482]))
(tensor([ 3.0787, -0.5175]), tensor([ 0.2855, 1.2509, -1.4400]), tensor([ 1.1726, -0.0982]))
tensor([[ 0.2332, -0.4129],
        [ 1.0542, -0.9473],
        [-0.4178,  0.9416],
        [ 3.0787, -0.5175]])
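To make the comparison with zip() concrete, the sketch below checks that iterating a TensorDataset yields the same row pairs as zip() applied to the tensors directly. The names a, b and pairs_match are illustrative only. Unlike zip(), the dataset also supports len() and indexing by position.

```python
import torch
from torch.utils.data import TensorDataset

a = torch.randn(4, 2)
b = torch.randn(4, 3)

dataset = TensorDataset(a, b)

# Iterating a tensor walks its first dimension, so zip(a, b) pairs up rows.
# Compare each dataset sample with the pair produced by zip().
pairs_match = all(
    torch.equal(da, za) and torch.equal(db, zb)
    for (da, db), (za, zb) in zip(dataset, zip(a, b))
)
print(pairs_match)

# A zip object has no length and cannot be indexed; the dataset supports both.
print(len(dataset))
print(dataset[2][0].shape)
```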