在训练 深度学习 模型之前,样本集的制作是非常重要的环节。在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,按照对应格式自定义数据集,才可以使用DataLoader加载数据,下面是自定义样本集的整个流程。
“三步走”的策略
Pytorch输入数据PipeLine一般遵循“三步走”的策略,一般pytorch 的数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象。必须实现__len__()、__getitem__()这两个方法,这里面会用到transform对数据集进行扩充。
② 创建一个 DataLoader 对象。它是对DataSet对象进行迭代的,一般不需要事先里面的其他方法了。
③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练。
代码示例
1、基本格式
from torch.utils import data
class MyDataset(data.Dataset): # 需要继承data.Dataset
def __init__(self):
pass
def __len__(self):
# 获取数据集的大小
pass
def __getitem__(self, index):
# 通过索引获取数据
# 如果是图片,可以返回PIL.image、numpy数组或者Tensor
# 如果是PIL.image,需要使用transform转化成Tensor
pass
2、初始化和获取数据集长度
def __init__(self, root):
self.train_data = []
train_file = os.path.join(root, 'train/data.txt')
with open(train_file) as f:
for line in f:
path, label = line.strip().split(' ')
file_path = os.path.join(root, 'train', path)
self.train_data.append((file_path, label))
def __len__(self):
# 获取数据集的大小
return len(self.train_data)
创建文件 data.txt,并编辑文件路径和label,例如 cat/1.png 0,之后导入图片。
3、迭代读取文件def __getitem__(self, index):
from PIL import Image
file_path, label = self.train_data[index]
# 读取图片,并转化为np数组
img = Image.open(file_path)
return np.array(img), label
4、添加transform参数
dataset = MyDataset(file_path, transform=transforms.ToTensor())
loader = data.DataLoader(dataset, batch_size=10)
for l in loader:
print(l)
# class MyDataset():
def __init__(self, root, transform):
self.transform = transform
def __getitem__(self, index):
......
if self.transform:
img = self.transform(img)
return img, label
参考文档
https://www.jianshu.com/p/2d9927a70594
https://pytorch.org/docs/stable/data.html
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
本文为 陈华 原创,欢迎转载,但请注明出处:http://www.chenhuax.com/read/236