窝牛号

PaddlePaddle——飞桨深度学习实现手写数字识别任务

通过上一期文章的分享,我们介绍了paddlepaddle以及成功安装了paddlepaddle

百度深度学习平台PaddlePaddle——飞桨基础知识介绍

本期我们就来认识一下paddlepaddle的代码,进行一个简单的手写任务的实现

手写数字识别任务

数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别。手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别、手写邮政编码识别等领域,大大缩短了业务处理时间,提升了工作效率和质量。

MINIST数据

MINIST的数据分为2个部分:55000份训练数据(mnist.train)和10000份测试数据(mnist.test)。这个划分有重要的象征意义,他展示了在机器学习中如何使用数据。在训练的过程中,我们必须单独保留一份没有用于机器训练的数据作为验证的数据,这才能确保训练的结果的可行性。

前面已经提到,每一份MINIST数据都由图片以及标签组成。我们将图片命名为“x”,将标记数字的标签命名为“y”。训练数据集和测试数据集都是同样的结构,例如:训练的图片名为 mnist.train.images 而训练的标签名为 mnist.train.labels。

每一个图片均为28×28像素,我们可以将其理解为一个二维数组的结构:

在处理如下图所示的手写邮政编码的简单图像分类任务时,可以使用基于MNIST数据集的手写数字识别模型。MNIST是深度学习领域标准、易用的成熟数据集,包含50 000条训练样本和10 000条测试样本。

手写数字识别任务示意图

任务输入:一系列手写数字图片,其中每张图片都是28x28的像素矩阵。任务输出:经过了大小归一化和居中处理,输出对应的0~9的数字标签。模型搭建--代码import paddle from paddle.nn import Linear import paddle.nn.functional as F import os import numpy as np import matplotlib.pyplot as plt 39;train& 显示第一batch的第一个图像 plt.figure(&34;) 39;on& 关掉坐标轴为 off plt.title(&39;) 34;图像数据形状和对应数据为:&34;图像标签形状和对应数据为:&34;n打印第一个batch的第一个图像,对应标签数字为{}& 定义mnist数据识别网络结构 class MNIST(paddle.nn.Layer): def __init__(self): super(MNIST, self).__init__() 定义网络结构的前向计算过程 def forward(self, inputs): outputs = self.fc(inputs) return outputs

这里我们首先使用一个单层且没有非线性变换的模型,看看此简单的神经网络模型能否正确识别出手写数字

这里由于是一层的神经网络,其跟CNN还是有所差距,这个我们后期进行此方面的代码优化

图像预处理 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28] assert len(img.shape) == 3 batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2] 将图像形式reshape为[batch_size, 784] img = paddle.reshape(img, [batch_size, img_h*img_w]) return img

这里我们通过norm函数进行输出图像的处理,把范围为0--255归一化到0--1

神经网络的训练39;cv2& 声明网络结构 model = MNIST() def train(model): 加载训练集 batch_size 设为 16 train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode=&39;), batch_size=16, shuffle=True) 39;float32&39;float32&前向计算的过程 predicts = model(images) 每训练了1000批次的数据,打印下当前Loss的情况 if batch_id % 1000 == 0: print(&34;.format(epoch, batch_id, avg_loss.numpy())) 39;model/mnist.pdparams&39;model/mnist.pdparams')来保存已经训练好的模型,以便后期进行数字的识别,通过训练loss一直下不去,这样我们进行后期优化。

epoch_id: 0, batch_id: 0, loss is: [26.44594] epoch_id: 0, batch_id: 1000, loss is: [1.5970355] epoch_id: 0, batch_id: 2000, loss is: [3.3931825] epoch_id: 0, batch_id: 3000, loss is: [3.6991172] epoch_id: 1, batch_id: 0, loss is: [2.8589249] epoch_id: 1, batch_id: 1000, loss is: [4.109815] epoch_id: 1, batch_id: 2000, loss is: [5.390366] epoch_id: 1, batch_id: 3000, loss is: [4.619067] epoch_id: 2, batch_id: 0, loss is: [2.4614942] epoch_id: 2, batch_id: 1000, loss is: [2.2340536] epoch_id: 2, batch_id: 2000, loss is: [2.4032607] epoch_id: 2, batch_id: 3000, loss is: [1.7580397] epoch_id: 3, batch_id: 0, loss is: [2.8790944] epoch_id: 3, batch_id: 1000, loss is: [1.5264606] epoch_id: 3, batch_id: 2000, loss is: [3.5032237] epoch_id: 3, batch_id: 3000, loss is: [3.4746733] epoch_id: 4, batch_id: 0, loss is: [2.6894484] epoch_id: 4, batch_id: 1000, loss is: [2.1867495] epoch_id: 4, batch_id: 2000, loss is: [3.2445798] epoch_id: 4, batch_id: 3000, loss is: [8.163915] epoch_id: 5, batch_id: 0, loss is: [1.6421318] epoch_id: 5, batch_id: 1000, loss is: [3.7984645] epoch_id: 5, batch_id: 2000, loss is: [2.2743425] epoch_id: 5, batch_id: 3000, loss is: [2.3635402] epoch_id: 6, batch_id: 0, loss is: [5.423148] epoch_id: 6, batch_id: 1000, loss is: [4.778616] epoch_id: 6, batch_id: 2000, loss is: [3.4756808] epoch_id: 6, batch_id: 3000, loss is: [3.926146] epoch_id: 7, batch_id: 0, loss is: [3.7117333] epoch_id: 7, batch_id: 1000, loss is: [3.4605653] epoch_id: 7, batch_id: 2000, loss is: [4.286289] epoch_id: 7, batch_id: 3000, loss is: [3.027922] epoch_id: 8, batch_id: 0, loss is: [3.116638] epoch_id: 8, batch_id: 1000, loss is: [2.687238] epoch_id: 8, batch_id: 2000, loss is: [4.823868] epoch_id: 8, batch_id: 3000, loss is: [2.307558] epoch_id: 9, batch_id: 0, loss is: [1.770024] epoch_id: 9, batch_id: 1000, loss is: [1.5893741] epoch_id: 9, batch_id: 2000, loss is: [4.77549] epoch_id: 9, batch_id: 3000, loss is: [2.1042237]

本站所发布的文字与图片素材为非商业目的改编或整理,版权归原作者所有,如侵权或涉及违法,请联系我们删除

窝牛号 wwww.93ysy.com   沪ICP备2021036305号-1