pytorch的自定义数据集/DataLoader和Dataset重写
背景介绍
做Modulation Recognition的时候需要加载自定义的数据集,这就涉及到DataLoader和Dataset类中的方法重写了。
DataLoader介绍
源码中的介绍是:
*Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.*
也就是说,我们可以通过输入一个数据集,及常用参数如:batch_size、shuffle,就可以得到一个打包好的迭代器。这个迭代器包含了batch_size的序号及根据batch_size分割好的数据块。
Dataset 介绍
源码中的介绍是:
An abstract class representing a :class:`Dataset`.
很短,但是很经典。这是一个抽象类。所谓抽象类就是类的抽象化,而类本身就是不存在的,所以抽象类无法实例化。它存在的意义就是被继承。而且继承抽象类的类必须要重写抽象类的方法。
简单的说,我们构造一个MyDataset数据类,需要继承Dataset,并重写Dataset中的方法。
去掉源码中的注释,Dataset抽象类的定义就五行代码,两个方法:
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
根据我们的需要,我们会重写 getitem方法,以及len方法。
工作原理
首先,我们要定义自己的数据集类,例如叫做MyDataset,则代码片段应该为:
class MyDataSet(Dataset):
def __init__(self, data, label):
self.data = data
self.label = label
self.length = data.shape[0]
def __getitem__(self, mask):
label = self.label[mask]
data = self.data[mask]
return label, data
def __len__(self):
return self.length
继承
很简单
class MyDataSet(Dataset):
表示我们MyDataSet类继承了抽象类Dataset。该MyDataSet类中的有三个方法。
init方法
init方法是python中的构造方法(java中是叫构造方法,不知道python是不是这么叫,如果不是请大家指正),构造方法会在实例化对象时调用。其传入参数就是我们的==数据集==(data)和==标签集==(label)。
def __init__(self, data, label):
self.data = data
self.label = label
self.length = data.shape[0]
getitem方法
getitem方法是获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在自己使用时,比如想从data = [100, 99, 98, …, 0]的集合中选出下标为[0, 2, 4]的集合,则index/mask 就取[0, 2, 4],返回data[index]即可。
其实在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。
def __getitem__(self, mask):
return self.label[mask], self.data[mask]
len方法
偷了个懒没有去看源码。听说不给返回length的话pytorch会一脸xx。
def __len__(self):
return self.length
使用
完成了MyDataSet,就可以通过DataLoader使用了。例如此处我已经有了一个Xtrain,其中的数据的每一个batch都代表了一个信号。Ytrain当中都是X_train对应的标签。
于是我的代码就是:train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100 # number of epochs to train on
batch_size = 1024 # training batch size
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
for epoch in range(num_epoch ):
model.train()
for batchsz, (label, data) in enumerate(train_data):
# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))
DataLoader会根据设置的batch_size来产生index/mask,然后调用Datase的__getitem方法取出数据。
输出结果如下:
接下来就可以愉快的写模型了!!!
总结
其实看起来很简单的一个Dataset抽象类重写和DataLoader使用,包含了面向对象编程的三大特点:==封装==、==继承==、==多态==。
- 封装体现在Dataset抽象类的封装及我们的MyDataSet类的封装上。
- 继承体现在我们MyDataSet继承Dataset抽象类上。
多态体现在DataLoader对数据集的操作上(这点纯属个人理解,感觉有点像java中的向上转型,但python好像没有这一概念)。
本文作者: Joffrey-Luo Cheng
本文链接: http://lcjoffrey.top/2021/12/04/datasetloader/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!