经典卷积神经网络——resnet( 二 )


在这里可以看一下对比图,发现添加学习率自动衰减,loss下降速度会快一些,这说明模型拟合效果比较好 。
6.加载数据集,数据增强
这里我们仍然选择数据集,首先对数据进行增强,增加模型的泛华能力 。
【经典卷积神经网络——resnet】transs=trans.Compose([trans.Resize(256),trans.RandomHorizontalFlip(),trans.RandomCrop(64),trans.ColorJitter(brightness=0.5,contrast=0.5,hue=0.3),trans.ToTensor(),trans.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
函数中(亮度)(对比度)
(饱和度)hue(色调)
加载数据集:
train=tv.datasets.CIFAR10(root=r'E:\桌面\资料\cv3\数据集\cifar-10-batches-py',train=True,download=True,transform=transs)trainloader=data.DataLoader(train,num_workers=4,batch_size=8,shuffle=True,drop_last=True)
7.训练数据
for i in range(3):running_loss=0for index,data in enumerate(trainloader):x,y=datax=x.cuda()y=y.cuda()x=Variable(x)y=Variable(y)opt.zero_grad()h=model(x)loss1=loss(h,y)loss1.backward()opt.step()running_loss+=loss1.item()if index%100==99:avg_loos=running_loss/100running_loss=0print("avg_loss",avg_loos)
8.保存模型
torch.save(model.state_dict(),'resnet18.pth')
9.加载测试集数据,进行模型测试
首先加载训练好的模型
model.load_state_dict(torch.load('resnet18.pth'),False)
读取数据
test = tv.datasets.ImageFolder(root=r'E:\桌面\资料\cv3\数据',transform=transs,)testloader = data.DataLoader(test,batch_size=16,shuffle=False,)
测试数据
acc=0total=0for data in testloader:inputs,indel=dataout=model(inputs.cuda())_,prediction=torch.max(out.cpu(),1)total+=indel.size(0)b=(prediction==indel)acc+=b.sum()print("准确率%d %%"%(100*acc/total))
四、深层对比
上面提到VGG网络层次越深,准确率越低,为了解决这一问题,才提出了残差网络(),那么在网络中,到底会不会出现这一问题 。
如图所示:随着,训练层次不断提高,模型越来越好,成功解决了VGG网络的问题,到现在为止,残差网络还是被大多数人使用 。