pytorch 读取图片,pytorch自带数据集
在数据预处理解决深度学习问题的过程中,往往需要花费大量的时间和精力。下面这篇文章主要介绍pytorch加载自己的图像数据集的两种方法。通过示例代码详细介绍,有需要的朋友可以参考一下。
目录
ImageFolder加载数据集使用pytorch提供的数据集类创建您自己的数据集。数据集加载数据集摘要pytorch加载图片数据集有两种方式。1.ImageFolder适用于分类数据集,每个类别的图片都在同一个文件夹中。对于ImageFolder加载的数据集,训练数据是files下的图片,训练标签是对应的文件夹,每个文件夹是一个类别。
导入ImageFolder()包
从torchvision.datasets导入图像文件夹
在Flower_Orig_dataset文件夹下,有Flower_Orig和向日葵两个文件夹,这两个文件夹包含相同类别的图片。用ImageFolder加载的图片会返回图片信息和对应的标签信息,但是标签信息是根据文件夹给出的,比如flower_orig是标签0,向日葵是标签1。
ImageFolder 加载数据集
1.导入包并设置转换
进口火炬
从torchvision导入变换、数据集
将torch.nn作为nn导入
从torch.utils.data导入数据加载器
转换=转换。撰写([
转变。调整大小(256),#将图片的短边缩放到256,保持纵横比不变:
转变。CenterCrop(224),#将图片从中间剪切成3*224*224大小的图片。
转变。ToTensor() #将图片归一化,并将数据转换为张量类型。
])
2.加载数据集3360将分类图片的父目录作为路径传递给ImageFolder(),并将其传递给transform。这样,您就有了要加载的数据集。之后,可以使用DataLoader加载数据,构建网络训练。
path=r d : \ dataset \ Flower _ Orig _ dataset
data_train=数据集。ImageFolder(路径,转换=转换)
数据加载器=数据加载器(数据训练,批量大小=64,洗牌=真)
对于I,枚举数据(data_loader):
图像、标签=数据
打印(图像.形状)
打印(标签.形状)
破裂
使用pytorch提供的Dataset类创建自己的数据集。
具体步骤:
1.首先你要有一个txt文件。此文件格式为:图像路径标签。这样的格式,所以用os库,遍历自己的镜像名,把标签和镜像路径写入txt文件。
2.有了这个txt文件,我们可以在类中构造我们的数据集。
2.1从图片标签上划分图片路径。有两个列表,一个是图片路径名,一个是标签号。有一点是第I个图片列表对应第I个标签。
3.重写__len__方法和__getitem__方法
3.1在getitem方法中,获取相应的图片路径,用PIL库读取文件转换图片后,在GetItem函数中返回读取的图片和标签。
4.您可以构建数据集实例并加载数据集。
定义一个函数来生成一个类似[图片路径标签]的txt文件
def make_txt(根,文件名,标签):
path=os。
path.join(root, file_name)
data = os.listdir(path)
f = open(path+\\+f.txt, w)
for line in data:
f.write(line+ +str(label)+\n)
f.close()
#调用函数生成两个文件夹下的txt文件
make_txt(path, file_name=flower_orig, label=0)
make_txt(path, file_name=sunflower, label=1)
将连个txt文件合并成一个txt文件,表示数据集所有的图片和标签
def link_txt(file1, file2):txt_list = []
path = rD:\数据集\Flower_Orig_dataset\data.txt
f = open(path, a)
f1 = open(file1, r)
data1 = f1.readlines()
for line in data1:
txt_list.append(line)
f2 = open(file2, r)
data2 = f2.readlines()
for line in data2:
txt_list.append(line)
for line in txt_list:
f.write(line)
f.close()
f1.close()
f2.close()
#调用函数, 将两个文件夹下的txt文件合并
file1 = rD:\数据集\Flower_Orig_dataset\flower_orig\f.txt
file2 = rD:\数据集\Flower_Orig_dataset\sunflower\f.txt
link_txt(file1=file1, file2=file2)
现在我们已经有了我们制作数据集所需要的txt文件, 接下来要做的即使继承Dataset类, 来构建自己的数据集 , 别忘了前面说的 构建数据集步骤, 在__getitem__函数中, 需要拿到图片路径和标签, 并且用PIL库方法读取图片,对图片进行transform转换后,返回图片信息和标签信息
Dataset加载数据集
我们读取图片的根目录, 在根目录下有所有图片的txt文件, 拿到txt文件后, 先读取txt文件, 之后遍历txt文件中的每一行, 首先去除掉尾部的换行符, 在以空格切分,前半部分是图片名称, 后半部分是图片标签, 当图片名称和根目录结合,就得到了我们的图片路径class MyDataset(Dataset):
def __init__(self, img_path, transform=None):
super(MyDataset, self).__init__()
self.root = img_path
self.txt_root = self.root + data.txt
f = open(self.txt_root, r)
data = f.readlines()
imgs = []
labels = []
for line in data:
line = line.rstrip()
word = line.split()
imgs.append(os.path.join(self.root, word[1], word[0]))
labels.append(word[1])
self.img = imgs
self.label = labels
self.transform = transform
def __len__(self):
return len(self.label)
def __getitem__(self, item):
img = self.img[item]
label = self.label[item]
img = Image.open(img).convert(RGB)
#此时img是PIL.Image类型 label是str类型
if transforms is not None:
img = self.transform(img)
label = np.array(label).astype(np.int64)
label = torch.from_numpy(label)
return img, label
加载我们的数据集:
path = rD:\数据集\Flower_Orig_datasetdataset = MyDataset(path, transform=transform)
data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
接下来我们就可以构建我们的网络架构:
class Net(nn.Module):def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3,16,3)
self.maxpool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(16,5,3)
self.relu = nn.ReLU()
self.fc1 = nn.Linear(55*55*5, 1200)
self.fc2 = nn.Linear(1200,64)
self.fc3 = nn.Linear(64,2)
def forward(self,x):
x = self.maxpool(self.relu(self.conv1(x))) #113
x = self.maxpool(self.relu(self.conv2(x))) #55
x = x.view(-1, self.num_flat_features(x))
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features
训练我们的网络:
model = Net()criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 10
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(data_loader):
images, label = data
out = model(images)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if(i+1)%10 == 0:
print([%d %5d] loss: %.3f%(epoch+1, i+1, running_loss/100))
running_loss = 0.0
print(finished train)
保存网络模型(这里不止是保存参数,还保存了网络结构)
#保存模型torch.save(net, model_name.pth) #保存的是模型, 不止是w和b权重值
# 读取模型
model = torch.load(model_name.pth)
总结
到此这篇关于pytorch加载自己的图片数据集的2种方法的文章就介绍到这了,更多相关pytorch加载图片数据集内容请搜索盛行IT软件开发工作室以前的文章或继续浏览下面的相关文章希望大家以后多多支持盛行IT软件开发工作室!
郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。