窝牛号

pytorch利用CNN卷积神经网络来识别手写数字

pytorch——人工智能的开源深度学习框架pytorch深度学习框架之tensor张量计算机视觉的基石——读懂 CNN卷积神经网络使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络

上期文章我们分享了使用MINIST数据集训练第一个CNN卷积神经网络,并保存了预训练模型,本期我们基于上期的模型,进行神经网络的识别

MINIST数据

MINIST的数据分为2个部分:55000份训练数据(mnist.train)和10000份测试数据(mnist.test)。这个划分有重要的象征意义,他展示了在机器学习中如何使用数据。在训练的过程中,我们必须单独保留一份没有用于机器训练的数据作为验证的数据,它能确保训练的结果的可行性。

前面已经提到,每一份MINIST数据都由图片以及标签组成。我们将图片命名为“x”,将标记数字的标签命名为“y”。训练数据集和测试数据集都是同样的结构,例如:训练的图片名为 mnist.train.images 而训练的标签名为 mnist.train.labels。

每一个图片均为28×28像素,我们可以将其理解为一个二维数组的结构,这里重点强调一下MINIST数据集是黑底白字,这里在识别数字的时候需要转换一下自己的数据图片,一般我们的照片时白底黑字,这里在调试代码时也遇到了类似的问题,神经网络总是识别错误

MINIST

CNN卷积神经网络搭建

首先按照上期的代码搭建一下我们的CNN 卷积神经网络

import torch import torch.nn as nn from PIL import Image 定义神经网络 class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( 输入通道数 out_channels=16, 卷积核大小 stride=1, 如果想要 con2d 出来的图片长宽没有变化, output shape (16, 28, 28) nn.ReLU(), 在 2x2 空间里向下采样, output shape (16, 14, 14) ) self.conv2 = nn.Sequential( output shape (32, 14, 14) nn.ReLU(), output shape (32, 7, 7) ) self.out = nn.Linear(32 * 7 * 7, 10) 前向反馈 def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) n_filters(输出通道数),第三个参数为卷积核大小,第四个参数为卷积步数,最后一个为pading,此参数为保证输入输出图片的尺寸大小一致

self.conv2 = nn.Sequential( output shape (32, 14, 14) nn.ReLU(), output shape (32, 7, 7) )

全连接层,最后使用nn.linear()全连接层进行数据的全连接数据结构(32*7*7,10)以上便是整个卷积神经网络的结构,

大致为:input-卷积-Relu-pooling-卷积-Relu-pooling-linear-output

卷积神经网络建完后,使用forward()前向传播神经网络进行输入图片的训练

初始化图片数据file_name = &39; 39;L&39;L&39;./model/CNN_NO1.pk&39;cpu&39;0&39;1&39;2&39;3&39;4&39;5&39;6&39;7&39;8&39;9&39;./model/CNN_NO1.pk&39;cpu'))

由于MNIST数据集只有0-9个数字,我们设置一个list存放这10个值

Python 有多种机制可以在本地禁用梯度计算:要在整个代码块中禁用渐变,有上下文管理器,如 no-grad 模式和推理模式。为了从梯度计算中更细粒度地排除子图,可以设置requires_grad张量。这里主要的好处是可以 节省大量的内存,我们只是为了测试我们的神经网络,可以采用类似做法。

然后我们把图片传入神经网络进行预测,我们可以计算出神经网络的预测结果与置信度

tensor([2.3575e-05, 2.2012e-03, 2.9602e-02, 1.7745e-02, 1.5333e-01, 3.1148e-03, 1.2766e-03, 2.0447e-02, 4.6207e-02, 7.2605e-01]) 9 9 0.7260508

运行以上代码,神经网络会输出每个值的预测置信度,我们从中挑选出置信度最大的。

本站所发布的文字与图片素材为非商业目的改编或整理,版权归原作者所有,如侵权或涉及违法,请联系我们删除

窝牛号 wwww.93ysy.com   沪ICP备2021036305号-1