pytorch源码解析,pytorch实现图像识别

  pytorch源码解析,pytorch实现图像识别

  torch/utils/data/_ utils/data loader . py通常在使用pytorch训练神经网络时,data loader模块是整个网络训练的基础。根据输入界面的参数,将训练集划分为几个批量大小和它们的主要角色。典型的数据加载和批量训练过程包括:((后面会详细解释args。)

  loader=torch . utils . data . data loader(args)for data,label in loader: training这次主要是数据加载器模块的源代码(py torch bar

  从直译的角度来说,iterable和iterator的区别是可重复和可重复。两者的概念非常接近,但在底层实现上略有不同。

  3358www。Sina.com/表示一个对象可以被迭代,而下层表示__iter__方法3358www.Sina.com/表示一个对象是迭代器,而下层

  比如python中的list、dict、str是可重复的,也就是可以使用For循环。对于迭代器,除了For循环之外,还可以使用next()函数检索以下元素:要将迭代器转换为迭代器,请使用iterator()函数。iterable:

  #先获取迭代器对象it=ITER ([1,2,3,4,5]) while true: try: #)获取下一个值x=next(it);当except StopIteration: #遇到StopIteration时,会退出Dataset对象,在break pytorch中加载了循环数据的概念。在使用pytorch训练网络之前,必须将数据集转换成torch支持的数据集格式。这意味着数据集格式的采样器对象必须用数据集实例化。采样器的主要作用是决定如何从数据集中检索数据,比如SequentialSampler和RandomSampler。DataLoader对象可以理解为扮演一个统一的角色。根据输入接口的参数(如批量大小)将采样器对象采样的数据逐个封装成批,然后将数据加载到模型中。iterator:源代码解释此源代码解释针对pytorch版。让我们从下一段代码开始吧!

  Loader=torch . utils . data . data loader(args)对于数据,loader中的标签:培训首先介绍数据加载器需要的参数及其含义。

  参数的含义是数据集dataset (dataset对象)是否在每个历元之前重排数据集。采样器采样器定义了如何从数据集中采样数据(采样器与批量数据num_workers的区别在于一次处理数据加载的进程数量。0表示在将单进程collate_fn采样的样本合并到batchpin_memory并返回张量之前,是否将其加载到GPU的固定内存中。drop_last数据集的大小不是batchpin_memory。设置最后一个比较小的批次放弃超时数据加载的超时。在边界预取因子persistent_workers多进程环境中,调用tDataLoader是可迭代的,不是迭代器。确定进程的生命周期,然后逐步深入到以下级别:数据加载器

  数据加载器。(这里没有理解没关系,源码是需要反复回味的)

  大家可以看到,分为多流程和单流程(评论写错一定是多流程)。之所以出现两种情况,是因为作者实际上是在使用注释。一般来说,在单进程环境中,每次返回迭代器,都要重新创建和重置状态。在多进程环境中,返回的迭代器必须始终存在于整个数据加载器声明循环中。这样,迭代器可以在多个进程中重用。仔细查看代码,可以看到该方法调用了另一个method _get_iterator()。请注意,__iter__方法返回* **_BaseDataLoaderIter** *对象。后面我会详细讲解这个类。

  2.orch.utils.data.DataLoader

  如上所述

  的DataLoader对象是迭代的,其中没有__next__方法。所以你需要在这个类中定义一个迭代器来达到上层用户使用的效果(即迭代器的效果)。由于单进程和多进程的处理逻辑不同,返回的迭代器也不同。值得注意的是,_SingleProcessDataLoaderIter和_MultiProcessingDataLoaderIter均继承了_BaseDataLoaderIter类

  3._BaseDataLoaderIter

  与DataLoader不同,_BaseDataLoaderIter是一个迭代器类,因此它不仅实现了__iter__方法,还实现了**next**方法。__next__方法用于获取下一个元素并遍历数据。

  仔细查看代码,可以发现整个__next__方法中最重要、最关键的是

  数据=自身。_next_data(),而这个方法在这个基类中没有实现,所以需要继承它的子类来实现。

  同样,你也可以看到代码中有self._sampler_iter,这是一个采样器迭代器,用来获取一批数据的索引。

  4.单进程的处理逻辑_SingleProcessDataLoaderIter

  从上面的分析可以看出_ SingleprocessDataLoader是_ BaseDataLoader的子类,它的主要作用是实现函数**_next_data**在单个进程中加载数据。

  _SingleProcessDataLoaderIter.

  分析_next_data函数,先得到一个批量数据的索引,然后用一个取数器根据批量数据的索引取一个批量数据,然后把这些数据整合成一个批量,最后返回一个批量数据。(self._next_index由其父类实现)我们来看看这个函数是如何实现的。

  self._sampler_iter是采样器对象对应的迭代器,根据采样器对象的类型返回数据。

  至此,单进程数据加载过程已经完成。我们来过一遍整个过程的逻辑。这里只是粗略描述一下整个逻辑过程,还有很多细节。如果你想了解这些细节,建议你看看源代码(如果你不明白,请问我)

  多进程处理逻辑根据上面第二步的进程数决定返回哪个迭代器,所以在多进程环境下,返回的迭代器是_ multiprocessing dataloader,也是从_ basedataloader类继承的。与单进程不同,多进程使用多个进程同时加载数据,以加快数据加载速度。

  首先,您需要了解一些数据结构:

  index_queue:用于存储每批的索引。每个worker都有一个index_queue。_data_queue:存储处理好的批次,取出后直接返还给上层用户使用。多个工作人员共享该队列。_worker_result_queue:存储已处理的批次。当pin_memory为false时,与_data_queue相同。在这里,我们只需要关注_data_queue。_send_idx: idx_rcvd_idx,用于要处理的下一批;idx_task_info为下一批要取出的;因为每个工人的执行速度不同,所以采用这种数据结构来保证加工和领料的批次顺序一致。_tasks_outstanding:表示当前有多少批次可用。_MultiProcessingDataLoaderIter.__init

  这里只列出了重要的代码片段。上面的代码片段主要启动num_workers个进程,每个进程有一个index_queue,每个进程运行function _worker_loop。这个函数的主要功能是从index_queue中取出索引,读取数据,处理数据,返回数据,将数据插入data_queue。注意在这里一个worker每次只处理一个batch

  _MultiProcessingDataLoaderIter.__reset

  在这个函数中,需要为每个worker预取几个批处理,并把它们放入它的index_queue中,以防止在一个进程启动时出现空队列。

  如前所述,每个从_ BaseDataLoader继承的类都需要实现_next_data函数。让我们看看这个函数是如何在多进程环境中实现的。

  _MultiProcessingDataLoaderIter._next_data

  先看关键代码片段。

  首先,从self获取批次的idx和数据。_get_data()。此时,在得到数据后,需要将_tasks_outstanding的值减一。为了保证从_data_queue得到的批次与期望批次相同,引入了if判断。如果与预期批次不同,该批次的idx和数据将临时保存在_task_info中,以防止重复数据检索。如果相同,则从_task_info中删除该记录,并在下一步中处理获得的批处理。

  在取数据之前,我们需要使用_task_info来确定这个数据是否需要调用_get_data。也在_next_data函数中。

  第一个while循环用来判断下一批要取出的idx即_rcvd_idx是否还有数据。简而言之,就是得到一个有效的_rcvd_idx。

  _MultiProcessingDataLoaderIter._get_data

  分析这段代码,我们可以看到这个函数实际上是在其内部实现中调用了function _try_get_data。我们来看看_try_get_data是如何实现的。

  _MultiProcessingDataLoaderIter._try_get_data

  可以看到,这里直接从_data_queue中取数据,并返回取数据的状态和数据。

  _MultiProcessingDataLoaderIter._process_data

  从上面的_next_data可以看出,当取出的批次与预期取出的批次相同时,可以处理该批次的数据。上面的代码显示了整个处理逻辑。可以看出,每处理一个批次,_rcvd_idx递增下一次要取出的批次的索引。同时,每处理完一个批处理后,需要将一个待处理的批处理放在_index_queue中等待一个进程来处理它。所以这里调用_try_put_index。

  _MultiProcessingDataLoaderIter._try_put_index

  首先从sampler迭代器中获取下一批的索引,然后将索引放入第一个存活的worker的_index_queue中。最后,更新一些数据结构的状态。

  至此,多流程数据加载流程已经解释完毕。多进程和单进程的处理逻辑唯一的区别就是将单进程中的_ singleprocessdataloader改为_ multiprocessing dataloader。它的internal _next_data方法差别很大,需要仔细体验。有兴趣的同学建议看源代码。

  总结数据加载模块在整个神经网络训练过程中非常重要,所以了解数据加载模块的底层实现有助于我们写出更高效的代码。其次,整个数据加载模块的代码都是用python写的,因为涉及到很多底层的东西,所以会相对容易读懂。

  参考[1]torch . utils . data-py torch 1 . 9 . 0文档

  [2]Pytorch Dataloader研究笔记大专栏(dazhuanlan.com)

  [3]PyTorch学习笔记(6)——DataLoader源代码分析_博客之旅_g11d111 -CSDN博客_dataloader返回值

郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。

留言与评论(共有 条评论)
   
验证码: