pytorch读取自己的数据集,pytorch读取图像

  pytorch读取自己的数据集,pytorch读取图像

  本文主要介绍如何使用PyTorch实现免费数据读取的相关信息。通过示例代码详细介绍,对您的学习或工作有一定的参考价值。有需要的朋友可以参考一下。

  00-1010前言PyTorch数据读取功能介绍ImageFolderDatasetDataLoader问题源码示例实现自定义数据读取总结

  

目录

  很多前辈说过,深度学习就像炼丹术,框架是炼丹炉,网络结构和算法是药方,数据集是原材料。要炼好炼丹,首先需要一个炼丹炉,有好的配方和原料,最后需要炼丹师有足够的经验和技巧,掌握好火候和时机,才能炼出绝世仙丹。

  对于刚进入炼丹行业的炼丹师来说,网上有一些前人总结的炼丹技巧,也有很多炼丹师的心路历程,以及他们对整个炼丹过程的记录。有了这些,他们无疑可以很快知道如何炼丹。但是现在市面上入门级炼丹的手册,往往是把原料放进炼丹炉给你看。你只需要打开炼丹炉,然后简单调试一下,就可以生产出炼丹了。这无疑降低了大家的入门难度,但是真正炼丹的时候,就会不知所措,不知道怎么把原料放进炉里。

  这篇炼金术入门指南是用PyTorch做熔炉,教你如何把原材料放进熔炉。这一步虽然不涉及太多的算法,但是对于炼金术的开始来说,是非常重要的一步。

  

前言

  

PyTorch数据读入函数介绍

  PyTorch中有一个现成的数据读取方法,就是torchwision。这个api是模仿keras编写的,它主要处理分类问题。举个例子,如果有10个类别,那么一个大文件夹下会建立10个子文件夹,每一个子文件夹里都会放入相同种类的数据。

  这个函数可以很容易地建立一个数据I/O,但是问题来了。如果我要处理的数据不是这么简单的分类问题,比如说我要做机器翻译,那么我的输入输出都是句子,那怎么读入数据呢?

  这个问题很好解决。我们可以看一下ImageFolder的实现,发现它是torch.utils.data.Dataset的一个子类,那么下面我们来介绍一下torch.utils.data.dataset这个类。

  

ImageFolder

  我们可以发现数据集的定义如下

  这里的注释意味着这是一个表示数据集的抽象类。所有关于数据集的类都可以定义为它的子类,只需重写__getitem__和__len__。让我们回过头来看看ImageFolder的实现。确实是这样,所以现在问题就变得很简单了。对于机器翻译问题,我们只需要定义整个数据集的长度,同时定义取出其中一个索引的元素。

  那么我们在定义完数据集之后就不能把所有的数据集都放入内存,这样内存肯定会爆仓。我们需要定义一个迭代器来在每一步生成一个批处理。这里已经为我们实现了PyTorch,下面是torch.utils.data.DataLoader。

  

Dataset

  DataLoader可以自动为我们生成一个多线程迭代器,只要传入几个参数。第一个参数是上面定义的数据集,最后一个参数是批处理大小的大小、数据是否被加扰、读取数据的线程数量等等。这样,我们就建立了一个多线程I/O。

  看完这个,你可能会觉得PyTorch真的很方便。这个丹炉真的很好用,然后就迫不及待的想试试了。然后,如果可能的话,您将报告一个错误。而且,你已经一步步落实了。如何报告错误?别急,先来说说为什么报错,以及这个pyhon实现的解读,让你真正知道如何读取自定义数据。

  

DataLoader

  通过上面的实现,你可能会遇到各种问题。数据集很简单,一般不会出错。仅仅

  要Dataset实现正确,那么问题的来源只有一个,那就是torch.utils.data.DataLoader中的一个参数collate_fn,这里我们需要找到DataLoader的源码进行查看这个参数到底是什么。

  可以看到collate_fn默认是等于default_collate,那么这个函数的定义如下。

  

  是不是看着有点头大,没有关系,我们先搞清楚他的输入是什么。这里可以看到他的输入被命名为batch,但是我们还是不知道到底是什么,可以猜测应该是一个batch size的数据。我们继续往后找,可以找到这个地方。

  

  我们可以从这里看到collate_fn在这里进行了调用,那么他的输入我们就找到了,从这里看这就是一个list,list中的每个元素就是self.data[i],如果你在往上看,可以看到这个self.data就是我们需要预先定义的Dataset,那么这里self.data[i]就等价于我们在Dataset里面定义的__getitem__这个函数。

  所以我们知道了collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。

  这时我们再去看看collate_fn这个函数,其实可以看到非常简单,就是通过对一些情况的排除,然后最后输出结果,比如第一个if,如果我们的输入是一个tensor,那么最后会将一个batch size的tensor重新stack在一起,比如输入的tensor是一张图片,3x30x30,如果batch size是32,那么按第一维stack之后的结果就是32x3x30x30,这里stack和concat有一点区别就是会增加一维。

  所以通过上面的源码解读我们知道了数据读入具体是如何操作的,那么我们就能够实现自定义的数据读入了,我们需要自己按需要重新定义collate_fn这个函数,下面举个例子。

  

  

自定义数据读入的举例实现

  下面我们来举一个麻烦的例子,比如做文本识别,需要将一张图片上的字符识别出来,比如下面这些图片

  

  那么这个问题的输入就是一张一张的图片,他的label就是一串字符,但是由于长度是变化的,所以这个问题比较麻烦。

  下面我们就来简单实现一下。

  我们有一个train.txt的文件,上面有图片的名称和对应的label,首先我们需要定义一个Dataset。

  

class custom_dset(Dataset):

   def __init__(self,

   img_path,

   txt_path,

   img_transform=None,

   loader=default_loader):

   with open(txt_path, r) as f:

   lines = f.readlines()

   self.img_list = [

   os.path.join(img_path, i.split()[0]) for i in lines

   ]

   self.label_list = [i.split()[1] for i in lines]

   self.img_transform = img_transform

   self.loader = loader

   def __getitem__(self, index):

   img_path = self.img_list[index]

   label = self.label_list[index]

   # img = self.loader(img_path)

   img = img_path

   if self.img_transform is not None:

   img = self.img_transform(img)

   return img, label

   def __len__(self):

   return len(self.label_list)

  这里非常简单,就是将txt文件打开,然后分别读取图片名和label,由于存放图片的文件夹我并没有放上去,因为数据太大,所以读取图片以及对图片做一些变换的操作就不进行了。

  接着我们自定义一个collate_fn,这里可以使用任何名字,只要在DataLoader里面传入就可以了。

  

def collate_fn(batch):

   batch.sort(key=lambda x: len(x[1]), reverse=True)

   img, label = zip(*batch)

   pad_label = []

   lens = []

   max_len = len(label[0])

   for i in range(len(label)):

   temp_label = [0] * max_len

   temp_label[:len(label[i])] = label[i]

   pad_label.append(temp_label)

   lens.append(len(label[i]))

   return img, pad_label, lens

  代码的细节就不详细说了,总体来讲就是先按label长度进行排序,然后进行长度的pad,最后输出图片,label以及每个label的长度的list。

  下面我们可以验证一下,得到如下的结果。

  

  具体的操作大家可以去玩一下,改一改,能够实现任何你想要的输出,比如图片输出为一个32x3x30x30的tensor,将label中的字母转化为数字标示,然后也可以输出为tensor,任何你想要的操作都可以在上面显示的程序中执行。

  以上就是本文所有的内容,后面的例子不是很完整,讲得也不是很详细,因为图片数据太大,不好传到github上,当然通过看代码能够更快的学习。通过本文的阅读,大家应该都能够掌握任何需要的数据读入,如果有问题欢迎评论留言。

  完整代码

  

  

总结

  到此这篇关于如何使用PyTorch实现自由数据读取的文章就介绍到这了,更多相关PyTorch数据读取内容请搜索盛行IT软件开发工作室以前的文章或继续浏览下面的相关文章希望大家以后多多支持盛行IT软件开发工作室!

郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。

留言与评论(共有 条评论)
   
验证码: