fixmatch分类,fixmatch改进

  fixmatch分类,fixmatch改进

  纸张:https://arxiv . org/pdf/2001.07685 . pdf代码:https://github . com/Google-research/fix match

  综述该方法由索恩等提出,结合了伪标签和一致性正则化,极大地简化了整个方法。它在广泛的基准测试中得到了最先进的结果。

  如我们所见,我们在有标签图像上使用交叉熵损失训练一个监督模型。对于每一幅未标记的图像,分别采用弱增强和强增强方法得到两幅图像。弱增强的图像被传递给我们的模型,我们得到预测。把置信度最大的类的概率与阈值进行比较。如果它高于阈值,那么我们将这个类作为标签,即伪标签。然后,将强增强后的图像通过模型进行分类预测。该预测方法与基于交叉熵损失的伪标签的方法进行了比较。把两种损失合并来优化模型。

  固定匹配流程图

  代码train.py

  班级培训师(对象):

  def __init__(self,cfg):

  self.cfg=cfg

  ###########超级参数设置############

  self.net=获取模型(CFG。num _ classes,cfg.model_name).至(设备)

  optimizer=RAdam(params=self。网。参数(),lr=cfg.lr,weight_decay=0.0001)

  self.optimizer=Lookahead(优化器)

  里程碑=[范围(5)中x的5x * 60]

  #打印(f 里程碑:{里程碑} )

  scheduler _ c=CyclicCosAnnealingLR(优化器,里程碑=里程碑,eta_min=5e-5)

  自我。scheduler=learningratewamup(optimizer=optimizer,target_iteration=5,target_lr=0.003

  after_scheduler=scheduler_c)

  self.criterion=ComboLoss().至(设备)

  自我G=GridMask(真,真)

  self.best_acc=-100

  def load_net(self,path):

  self.net=火炬。load(path,map_location=cuda:0)[模型_状态]

  # self.best_acc=torch.load(path,map _ location= cuda:0 )[ best _ ACC ]

  # print(f best _ ACC:{ self。best _ ACC } )

  def train_one_epoch(self,loader):

  num_samples=0

  running_loss=0

  trn_error=0

  self.net.train()

  对于图像,加载器中的遮罩:

  if self.cfg.cutMix:

  图像,遮罩=剪切混合(图像,遮罩)

  if self.cfg.fmix:

  w,h=images.size(-1),images.size(-2)

  images,masks=fmix_seg(images,masks,alpha=1 .decay_power=3 .shape=(w,h))

  images=images.to(device,dtype=torch.float)

  if self.cfg.Grid:

  图像=自我100克(图像)

  面具=火炬。挤(口罩。至(设备))

  #打印(图像大小:{},掩码大小:{})。格式(images.size()、masks.size()))

  num_samples=int(images.size(0))

  self.optimizer.zero_grad()

  输出,cls=self.net(图像)

  损耗=自我标准(输出、屏蔽、cls)

  loss.backward()

  batch_loss=loss.item()

  self.optimizer.step()

  运行损失=批量损失

  pred=get_predictions(输出)

  口罩=口罩。类型(火炬。cuda。长时态)

  masks=masks.data.cpu()

  trn_error=compute_error(pred,masks)

  返回running_loss/len(加载器),trn_error/len(加载器)

  定义验证(自身,加载器):

  num_samples=0

  running_loss=0

  trn_error=0

  self.net.eval()

  对于图像,加载器中的遮罩:

  images=images.to(device,dtype=torch.float)

  面具=火炬。挤(口罩。至(设备))

  num_samples=int(images.size(0))

  输出,cls=self.net(图像)

  损耗=自我标准(输出、屏蔽、cls)

  batch_loss=loss.item()

  运行损失=批量损失

  pred=get_predictions(输出)

  口罩=口罩。类型(火炬。cuda。长时态)

  masks=masks.data.cpu()

  trn_error=compute_error(pred,masks)

  返回running_loss/len(加载器),trn_error/len(加载器)

  定义列车(自身):

  mkdir(self.cfg.model_save_path)

  ############准备数据集###############

  train_loader,val_loader,test _ loader=build _ loader(self。CFG)

  对于范围内的时期(self.cfg.num_epochs):

  打印( Epoch: {}/{} .格式(纪元1,self.cfg.num_epochs))

  # optimizer.step()

  self.scheduler.step(纪元)

  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #火车

  train_loss,train _ error=自我。火车时代(火车_装载机)

  start=time.strftime(%H:%M:%S )

  打印(

  f 纪元:{纪元1 }/{自身。CFG。次数} :{start} ,

  f 培训损失:{train_loss:4f}.

  f 训练Acc: {1 - train_error:4f}.

  )

  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #有效

  val_loss,val _ error=self。验证(val _ loader)

  start=time.strftime(%H:%M:%S )

  打印(

  f 纪元:{纪元1 }/{自身。CFG。次数} :{start} ,

  验证损失:{val_loss:4f}.

  验证符合:{1 - val_error:4f}.

  )

  if 1 - val_error self.best_acc:

  州={

  纪元:纪元1,

  示范州:self.net,

  最佳_acc: 1 - val_error

  }

  check point=f"{ self。CFG。模型名称} _最佳。PTH "

  torch.save(状态,操作系统。路径。加入(自我。CFG。模型保存路径,检查点)#保存模型

  打印("模型已成功保存!")

  self.best_acc=1 - val_error

  def train_one_epoch_semi(self,trainloader,testloader):

  running_loss=0

  trn_error=0

  loader=zip(trainloader,testloader)

  self.net.train()

  对于加载器中的数据x、数据u:

  图像_x,目标_x=数据_x

  图像_u_w图像_u_s=数据_u

  # cpu==gpu

  images_x=images_x.to(device,dtype=torch.float)

  targets _ x=火炬。挤压(目标x到(设备))

  images _ u _ w=images _ u _ w . to(device,dtype=torch.float)

  images _ u _ s=images _ u _ s . to(device,dtype=torch.float)

  if self.cfg.Grid:

  images_x=self .g(图像_x)

  images_u_s=self .g(美国图片)

  #打印(图像大小:{},掩码大小:{})。格式(images.size()、masks.size()))

  self.optimizer.zero_grad()

  outputs_x,cls_x=self.net(images_x)

  outputs_u_w,cls_u_w=self.net(images_u_w)

  outputs_u_s,cls_u_s=self.net(images_u_s)

  #获取伪标签

  target _ u=outputs _ u _ w . ge(self。CFG。阈值).浮动()

  loss _ x=self。标准(输出x,目标x,时钟x)

  loss _ u=(self。criteria(outputs _ u _ s,torch.squeeze(targets_u),cls_x,reduction= none )* torch。挤压(targets _ u))。平均值()

  损失=损失_ x自我。CFG。_ u *损失_ u

  loss.backward()

  batch_loss=loss.item()

  self.optimizer.step()

  运行损失=批量损失

  pred=get_predictions(outputs_x)

  masks=targets _ x . type(火炬。cuda。长时态)

  masks=masks.data.cpu()

  trn_error=compute_error(pred,masks)

  返回running_loss/len(列车装载器),trn_error/len(列车装载器)

  定义列车_半(自身):

  自我。load _ net(f“{ self。CFG。模型保存路径}/{自我。CFG。模型名称} _最佳。PTH’)

  模型_保存_路径=自身。CFG。模型保存路径 _ semi

  mkdir(模型保存路径)

  ############准备数据集###############

  火车装载器,阀门装载器,测试_ loader=build _ loader _ v2(self。CFG)

  对于范围内的时期(self.cfg.num_epochs):

  打印( Epoch: {}/{} .格式(纪元1,self.cfg.num_epochs))

  # optimizer.step()

  self.scheduler.step(纪元)

  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #火车

  train_loss,train _ error=自我。train _ one _ epoch _ semi(训练加载器,测试加载器)

  start=time.strftime(%H:%M:%S )

  打印(

  f 纪元:{纪元1 }/{自身。CFG。次数} :{start} ,

  f 培训损失:{train_loss:4f}.

  f 训练Acc: {1 - train_error:4f}.

  )

  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #有效

  val_loss,val _ error=self。验证(val _ loader)

  start=time.strftime(%H:%M:%S )

  打印(

  f 纪元:{纪元1 }/{自身。CFG。次数} :{start} ,

  验证损失:{val_loss:4f}.

  验证符合:{1 - val_error:4f}.

  )

  if 1 - val_error self.best_acc:

  州={

  纪元:纪元1,

  示范州:self.net,

  最佳_acc: 1 - val_error

  }

  check point=f"{ self。CFG。模型名称} _最佳。PTH "

  torch.save(state,os.path.join(model_save_path,checkpoint)) #保存模型

  打印("模型已成功保存!")

  自我。best _ ACC=1-val _ error数据集。巴拉圭

  从torch.utils.data导入数据集,数据加载器

  从火炬. utils.data.sampler导入水下取样器

  进口火炬

  进口火炬视觉

  来自火炬视觉.转换导入合成

  将数组作为铭牌导入

  将cv2作为简历导入

  导入操作系统

  从随机进口样品中

  从实用工具.转换导入*

  来自utils.randaugment导入*

  从实用工具.网格导入网格

  定义图像到张量(图像):

  张量=火炬。from _ numpy(img。转置((2,0,1)))

  返回张量

  定义为单色(十):

  # x_=x.convert(L )

  x_=np.array(x).astype(np.float32) #将图像转换为单色

  返回x_

  定义张量(十):

  x_=np.expand_dims(x,轴=0)

  x_=torch.from_numpy(x_)

  返回x_

  火炬视觉。转变。托特索尔

  极好的自定义_模糊_演示(图片):

  kernel=np.array([[0,-1,0],[-1,5,-1],[0,-1,0]],np.float32) #锐化

  dst=cv.filter2D(image,-1,kernel=kernel)

  返回夏令时

  SasDataset类(数据集):

  def __init__(self,root,mode=train ,is_ndvi=False):

  self.root=root

  自我模式=模式

  self.is_ndvi=is_ndvi

  自我。img list=sorted(OS。listdir(self。根)中的图片的img)

  self.transform=DualCompose([

  RandomFlip(),

  RandomRotate90(),

  Rotate(),

  Shift(),

  粗糙漏失()

  ])

  自我RA=。随机变量(2,10)

  自我。img变换=合成([img _ to _ tensor])

  self.maskTransforms=Compose([

  torchvision.transforms.Lambda(转换为单色),

  火炬视觉。变换(to _ tensor),

  ])

  def __getitem__(self,idx):

  imgPath=os.path.join(self.root,self.imgList[idx])

  img=np.load(imgPath)

  img=自定义模糊演示(img)

  img name=OS。路径。分割(镜像路径)[-1].拆分(.)[0]

  if self.mode==test :

  batch _ data={ img :self。img转换(img),文件名:imgName}

  返回批处理数据

  标签路径=img路径。替换(“图像”,“标签”).替换( npy , png )

  mask=cv.imread(labelPath)/255

  #数据扩充

  if self.mode==train :

  img,遮罩=self.transform(img,遮罩)

  img=self .RA(img)

  # img,mask=img.astype(np.float),mask.astype(np.float)

  w,h=mask.shape[:2]

  mask=mask[:0]

  mask=NP。形状(掩码,(宽,高,1)).转置((2,0,1))

  返回self.imgTransforms(img),self。遮罩变换(NP。挤压(面具))

  def __len__(self):

  return len(self.imgList)

  类别美国数据集(数据集):

  def __init__(self,root,mode=train ):

  self.root=root

  自我模式=模式

  自我。img list=sorted(OS。listdir(self。根)中的图片的img)

  self.transform=DualCompose([

  RandomFlip(),

  RandomRotate90(),

  Rotate(),

  Shift(),

  #断流器(孔数=20,最大高度=20,最大宽度=20,填充值=0)

  ])

  自我RA=。随机变量(2,10)

  自我。img transforms=Compose([ImageToTensor()])

  def __getitem__(self,idx):

  imgPath=os.path.join(self.root,self.imgList[idx])

  img=np.load(imgPath)

  img=自定义模糊演示(img)

  mask=np.zeros_like(img)

  #弱数据扩充

  img_w,_=self.transform(img,mask)

  #严重的数据扩充

  img_s=self .RA(img_w)

  返回self.imgTransforms(img_w),self.imgTransforms(img_s)

  def __len__(self):

  return len(self.imgList)

郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。

留言与评论(共有 条评论)
   
验证码: