pytorch的自定义数据集/DataLoader和Dataset重写
背景介绍
做Modulation Recognition的时候需要加载自定义的数据集,这就涉及到DataLoader和Dataset类中的方法重写了。 # DataLoader介绍
源码中的介绍是: 1
*Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.*1
An abstract class representing a :class:`Dataset`.
去掉源码中的注释,Dataset抽象类的定义就五行代码,两个方法: 1
2
3
4
5
6
7
8
9
10class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
工作原理
首先,我们要定义自己的数据集类,例如叫做MyDataset,则代码片段应该为: 1
2
3
4
5
6
7
8
9
10
11
12
13class MyDataSet(Dataset):
def __init__(self, data, label):
self.data = data
self.label = label
self.length = data.shape[0]
def __getitem__(self, mask):
label = self.label[mask]
data = self.data[mask]
return label, data
def __len__(self):
return self.length
继承
很简单 1
class MyDataSet(Dataset):1
2
3
4def __init__(self, data, label):
self.data = data
self.label = label
self.length = data.shape[0]
__getitem__方法
__getitem__方法是获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在自己使用时,比如想从data = [100, 99, 98, …, 0]的集合中选出下标为[0, 2, 4]的集合,则index/mask 就取[0, 2, 4],返回data[index]即可。 其实在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。 1
2def __getitem__(self, mask):
return self.label[mask], self.data[mask]1
2def __len__(self):
return self.length1
2
3
4
5
6
7
8
9train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100 # number of epochs to train on
batch_size = 1024 # training batch size
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
for epoch in range(num_epoch ):
model.train()
for batchsz, (label, data) in enumerate(train_data):
# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))
接下来就可以愉快的写模型了!!!
总结
其实看起来很简单的一个Dataset抽象类重写和DataLoader使用,包含了面向对象编程的三大特点:==封装==、==继承==、==多态==。 - 封装体现在Dataset抽象类的封装及我们的MyDataSet类的封装上。 - 继承体现在我们MyDataSet继承Dataset抽象类上。 - 多态体现在DataLoader对数据集的操作上(这点纯属个人理解,感觉有点像java中的向上转型,但python好像没有这一概念)。