文章

理解 PyTorch 中 DataLoader 与 Dataset 的使用

理解 PyTorch 中 DataLoader 与 Dataset 的使用

为了提高代码的可读性与模块化特性,我们希望数据集代码与模型训练代码分离。于是 PyTorch 提供了两个原始类型(Data Primitive):torch.utils.data.DataLoadertorch.utils.data.Dataset,分别用于定义数据集对象、迭代读取数据条目。

下面将先介绍如何快速上手,之后对两个原始类型的参数作详细解释。

快速上手

实现 Dataset 子类

Dataset 是抽象基类,需要以它为基类编写子类,接着将子类实例化。

这部分代码需要根据实际的数据形式进行调整。下面只是一个简单的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, source_arr, target_arr):
        super().__init__()

        num_sample = len(source_arr)

        self.source_arr = source_arr
        self.target_arr = target_arr
        self.num_sample = num_sample

    def __getitem__(self, idx):
        return self.source_arr[idx], self.source_arr[idx]
    
    def __len__(self):
        return self.num_sample

导入类

导入刚刚写的 Dataset 子类与 DataLoader

1
2
from dataloader.dataset import MyDataset as Dataset
from torch.utils.data import DataLoader

实例化 Dataset 的子类对象

在主程序中,调用已经实现的 Dataset 子类来将其实例化。

1
train_dataset = Dataset(source_arr, target_arr)

实例化 DataLoader 对象

在主程序中调用 DataLoader 类进行实例化,并传入 Dataset 对象。

1
train_dataloader = DataLoader(train_dataset, arg1, arg2, ...)

与其在创建 DataLoader 对象的时候填写参数,更推荐的是提前将参数打包成字典,在创建对象时进行解包:

1
2
dataloader_args = {"batch_size": 256, "shuffle": True, "num_workers": 8}
train_dataloader = DataLoader(train_dataset, **dataloader_args)

迭代 DataLoader 对象

1
2
for sample in train_dataloader:
    ...

DataLoader 的迭代返回值取决于 collate_fn 参数。

如果 collate_fn == None,则对原始 batch 值使用默认函数进行处理并返回。否则,将原始 batch 值使用 collate_fn 指定的函数进行处理并返回。

具体来看,对于下面的最小测试单元:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dataloader.dataset import MyDataset as Dataset
from torch.utils.data import DataLoader

source_arr = [1, 2, 3, 4, 5]
target_arr = [1, 2, 3, 4, 5]

dataset = Dataset(source_arr, target_arr)
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

for i, (source, target) in enumerate(dataloader):
    print(f'Batch {i + 1}:')
    print(f'Source: {source}')
    print(f'Target: {target}')
    print()

处理得到的原始 batch 值形如:

1
[(1, 1), (2, 2), (3, 3)]

因为没有指定 collate_fn,控制台的输出内容是默认函数处理的结果。

1
2
3
4
5
6
7
Batch 1:
Source: tensor([1, 2, 3])
Target: tensor([1, 2, 3])

Batch 2:
Source: tensor([4, 5])
Target: tensor([4, 5])

创建 collate_fn 函数

如果希望对数据进行进一步处理,可以创建一个 collate_fn() 函数,这样迭代返回值就是原始 batch 值经过这个函数处理后的内容。

一般将这个函数写在 MyDataset 类里,并使用 @staticmethod 装饰器。

1
2
3
4
5
6
7
class MyDataset(Dataset):
    ...
    
    @staticmethod
    def collate_fn(batch):
        # Your code here
        return ...

如果希望在未指定 collate_fn 时的默认结果基础上处理,可以通过下面的函数实现。

1
2
3
4
5
6
7
8
9
    @staticmethod
    def collate_fn(batch):
        source, target = zip(*batch)
        source = np.stack(source)
        target = np.stack(target)

        # Your code here

        return source, target

介绍

Dataset

Dataset 的作用是传入 DataLoader 供其包装、迭代。但它是一个抽象类(Abstract Class),意味着需要以 Dataset 为基类创建一个新的子类。

PyTorch 要求子类必须重写 __getitem__() 方法。

可选实现 __len__()__getitems__() 方法,来提高性能表现。

DataLoader

DataLoader 接收 Dataset 的子类对象作为数据集,自身为可迭代对象。其提供了较多可选参数,下面进行介绍。

  1. dataset (Dataset): 加载数据的数据集。
  2. batch_size (int, optional): 每个 batch 有多少个样本。默认为 1
  3. shuffle (bool, optional): 在每个 epoch 对数据重新排序。默认为 False
  4. sampler (Sampler or Iterable, optional): 自定义抽取样本的策略。与 shuffle 互斥。
  5. batch_sampler (Sampler or Iterable, optional): 类似于 sampler,但一次返回一批索引。与 batch_size, shuffle, sampler, drop_last等参数互斥。
  6. num_workers (int, optional): 用于数据加载的子进程(worker)数。设置为 0 表示数据将在主进程中加载,大于 0 表示使用多个子进程加载数据。默认为 0
  7. collate_fn (Callable, optional): 这个函数合并一个 list 类型样本来形成一个 mini-batch
  8. pin_memory (bool, optional): 如果为 True,将在返回 Tensors 之前将其复制到设备/CUDA 固定内存中。
  9. drop_last (bool, optional): 设置为 True 会丢弃最后一个不完整的批次。默认为 False
  10. timeout (numeric, optional): 从工作进程收集批次的超时值。默认为 0
  11. worker_init_fn (Callable, optional): 每个 worker 初始化的函数,一般不需要自己设置。默认为 None
  12. multiprocessing_context (str or multiprocessing.context.BaseContext, optional): 多进程上下文,一般使用操作系统的默认上下文即可。
  13. generator (torch.Generator, optional): 用于生成随机索引的随机数生成器,一般不需要自己设置。
  14. prefetch_factor (int, optional, keyword-only arg): 每个工作进程提前加载的批次数,可以提高数据加载效率。如果 num_workers=0 默认值为 None。否则,如果 num_workers > 0 默认值为 2
  15. persistent_workers (bool, optional): 如果为 True,不会在数据集被使用一次后关闭工作进程。这允许保持工作数据集实例处于活动状态。默认为 False
  16. pin_memory_device (str, optional): 如果 pin_memoryTrue,指定将数据固定到的设备。

具体了解 DataLoader 的数据处理流程,可以参考这两篇文章:

阿里云:PyTorch 小课堂开课啦!带你解析数据处理全流程(一)

阿里云:PyTorch 小课堂开课啦!带你解析数据处理全流程(二)

References

[1] PyTorch - torch.utils.data. Link.

[2] PyTorch - Datasets & DataLoaders. Link.

本文由作者按照 CC BY 4.0 进行授权