当前位置: 首页>編程日記>正文

Tensorflow手写数字识别

Tensorflow手写数字识别

Tensorflow手写数字识别

  • 前言
  • 一、关于mnist数据集
  • 二、搭建过程
    • 1.导入数据集
    • 2.数据集预处理
    • 3.构建全连接层模型
    • 4.梯度下降求最小Loss
    • 5.测试集查看模型训练精度
    • 6.多进程提高训练速度
    • 7.绘制训练效果图
  • 三、完整代码
  • 总结


前言

机器学习过程记录:
DL陆续学习了一年,读过了书看过了很多源码和视频,目前感觉碰到瓶颈,希望输出能提高输入,早日熟练的掌握相关技巧。


最开始study的源码,看不懂了可以交流

一、关于mnist数据集

  MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.

二、搭建过程

1.导入数据集

Tensorflow自带的数据集里有多个都可以导入:
dataset中相关的数据集
代码如下:

#加载手写数据集 dataset.mnist.load_data()
(x_train,y_train),(x_test,y_test) = datasets.mnist.load_data()

2.数据集预处理

代码如下:

#训练集格式化
#tf.convert_to_tensor便于GPU加速,转换数据格式
x_train = tf.convert_to_tensor(x_train,dtype= tf.float32)/255.
y_train = tf.convert_to_tensor(y_train,dtype= tf.int32)
# one-hot 编码
y_train = tf.one_hot(y_train,depth= 10)
#测试集格式化
x_test = tf.convert_to_tensor(x_test,dtype=tf.float32)/255.
# 利用GPU并行加速能力,生成一个Dataset,一次运算多张图片(batch)
# bitch(100) 即将训练集分成60K/100 = 600个Dataset
#对训练和测试数据集分别做切片处理
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(128)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(128)

该处Slices的操作降低了整个数据集的熵,方便后续的参数运算和求解。

3.构建全连接层模型

代码如下:

# 搭建多层神经网络,512,256,10 为输出
#使用激活函数防止梯度弥散或消失
model = Sequential([layers.Dense(512,activation='relu'),layers.Dense(256,activation='relu'),layers.Dense(128,activation='relu'),layers.Dense(64,activation='relu'),layers.Dense(32,activation='relu'),layers.Dense(10)])

自己玩所以构建了很多层的Dense,力争更高的准确度。

4.梯度下降求最小Loss

代码如下:

for step,(x,y) in enumerate(train_dataset):with tf.GradientTape() as tape:# [b,28,28] => [b,28*28]x = tf.reshape(x,(-1,28*28))# [b,784] => [b,10]out = model(x)#compute loss {(out-y)**2/n }loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]#model.trainable_variables => 该层跟踪的所有可训练重量列表。grads = tape.gradient(loss,model.trainable_variables)# Auto update w1,b1,w2,b2,w3,b3# list(zip('abcdefg', range(3), range(4)))>>> [('a', 0, 0), ('b', 1, 1), ('c', 2, 2)]# zip(grads,model.trainable_variables)>>> [(grade,w1),(grade,b1),(grade,w2)...]optimizer.apply_gradients(zip(grads,model.trainable_variables))if step % 100 == 0:print(epoch,step,'Loss:',loss.numpy())

因为要求Loss的最小值,所以这里是梯度下降,如果求最大成功率应该是梯度上升。

5.测试集查看模型训练精度

代码如下:

for step,(x,y) in enumerate(test_dataset):# [b,28,28] => [b,28*28]x = tf.reshape(x, (-1, 28 * 28))#model.summary()#查看模型结构out = model(x)print(out.numpy()[0])predictions = tf.argmax(out,axis=1)labels = tf.cast(y,tf.int64)equalValue =tf.cast(tf.equal(predictions,labels),tf.float32)#如果想直接求测试集的正确率,也可以用reduce_mean直接对整个数据集计算# correct = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))correct = tf.reduce_sum(tf.cast(tf.equal(predictions,labels),tf.float32))total_correct +=int(correct)total_count += x.shape[0]

如果测试集不大的话可以直接对整个数据集求解,不需要循环注入,代码中对每个epoch都计算准确率,如果想降低运算压力也可以多个epoch计算一次准确率。

6.多进程提高训练速度

代码如下(示例):

def Test2():p = Pool(5)for i in range(100):print("数据集训练第:" + str(i + 1) + "次")p.apply_async(train_epoch,args=(i,))p.close()p.join()programRunTime = datetime.datetime.now() - programBeginTimeprint("程序运行时间:" + str(programRunTime))

python线程和进程区别的原因,所以这里对于矩阵这种的密集运算,用多进程直接拉满CPU提高运行速度。

7.绘制训练效果图

代码如下(示例):

import matplotlib.pyplot as plt
x1 = np.linspace(1,len(losslist),len(losslist))
y1 = losslist
plt.xlabel('Epoch')
plt.ylabel('MSE-loss')
plt.title('Mean square deviation curve plot')
plt.plot(x1,y1,'bo-',label='Mean square',markersize=2)
#图像说明-'Mean square'展示
#注:loc=(‘best’, ‘upper right’, ‘upper left’, ‘lower left’, ‘lower right’, ‘right’, ‘center left’, ‘center , right’, ‘lower center’, ‘upper center’, ‘center’) 
plt.legend(loc="upper right")
#展示图像
plt.show()

这块儿没写进完整代码里,不过画图看训练效果调 matplotlib就可以了。

三、完整代码

import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential
from  tensorflow.keras import datasets,optimizers,layers
import datetime
from multiprocessing.pool import PoolprogramBeginTime = datetime.datetime.now()#加载手写数据集 dataset.mnist.load_data()
(x_train,y_train),(x_test,y_test) = datasets.mnist.load_data()
#训练集格式化
#tf.convert_to_tensor便于GPU加速,转换数据格式
x_train = tf.convert_to_tensor(x_train,dtype= tf.float32)/255.
y_train = tf.convert_to_tensor(y_train,dtype= tf.int32)
# one-hot 编码
y_train = tf.one_hot(y_train,depth= 10)
#测试集格式化
x_test = tf.convert_to_tensor(x_test,dtype=tf.float32)/255.
# 利用GPU并行加速能力,生成一个Dataset,一次运算多张图片(batch)
# bitch(100) 即将训练集分成60K/100 = 600个Dataset#slices :
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(128)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(128)
# 搭建多层神经网络,512,256,10 为输出# model :模型 /dense : 稠密的 / layers :/ Sequential : 按次序的;顺序的;序列的
model = Sequential([layers.Dense(512,activation='relu'),layers.Dense(256,activation='relu'),layers.Dense(128,activation='relu'),layers.Dense(64,activation='relu'),layers.Dense(32,activation='relu'),layers.Dense(10)])# 自动更新W和G的参数
# SGP Gradient descent (with momentum) optimizer
# optimizer:优化器
optimizer = optimizers.SGD(learning_rate= 0.01)def train_epoch(epoch):#Return an enumerate object#enumerate : 枚举for step,(x,y) in enumerate(train_dataset):with tf.GradientTape() as tape:# [b,28,28] => [b,28*28]x = tf.reshape(x,(-1,28*28))# [b,784] => [b,10]out = model(x)#compute loss {(out-y)**2/n }loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]#model.trainable_variables => 该层跟踪的所有可训练重量列表。grads = tape.gradient(loss,model.trainable_variables)# Auto update w1,b1,w2,b2,w3,b3# list(zip('abcdefg', range(3), range(4)))>>> [('a', 0, 0), ('b', 1, 1), ('c', 2, 2)]optimizer.apply_gradients(zip(grads,model.trainable_variables))if step % 100 == 0:print(epoch,step,'Loss:',loss.numpy())total_correct,total_count = 0,0for step,(x,y) in enumerate(test_dataset):# [b,28,28] => [b,28*28]x = tf.reshape(x, (-1, 28 * 28))#model.summary()#查看模型结构out = model(x)print(out.numpy()[0])predictions = tf.argmax(out,axis=1)labels = tf.cast(y,tf.int64)equalValue =tf.cast(tf.equal(predictions,labels),tf.float32)#如果想直接求测试集的正确率,也可以用reduce_mean直接对整个数据集计算# correct = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))correct = tf.reduce_sum(tf.cast(tf.equal(predictions,labels),tf.float32))total_correct +=int(correct)total_count += x.shape[0]print(epoch,'测试集正确率Acc:', total_correct/total_count)
def Test1():for i in range(100):print("数据集训练第:" + str(i+1) + "次")train_epoch(i)programRunTime = datetime.datetime.now() - programBeginTimeprint("程序运行时间:"+str(programRunTime))
def Test2():p = Pool(5)for i in range(100):print("数据集训练第:" + str(i + 1) + "次")p.apply_async(train_epoch,args=(i,))p.close()p.join()programRunTime = datetime.datetime.now() - programBeginTimeprint("程序运行时间:" + str(programRunTime))
if __name__ == '__main__':#Test1()Test2()

总结

未完待续:
后续整理一下 再更新,模型训练完成后的使用以及优化<压缩、剪枝>,多平台部署后面完善。


https://www.fengoutiyan.com/post/14858.html

相关文章:

  • 手写识别算法
  • 手写数字识别原理
  • 手写英文字母识别
  • 手写字母识别
  • 手写数字的识别
  • 手写数字1
  • 手写字体识别代码
  • 识别手写数字的软件
  • 鏡像模式如何設置在哪,圖片鏡像操作
  • 什么軟件可以把圖片鏡像翻轉,C#圖片處理 解決左右鏡像相反(旋轉圖片)
  • 手機照片鏡像翻轉,C#圖像鏡像
  • 視頻鏡像翻轉軟件,python圖片鏡像翻轉_python中鏡像實現方法
  • 什么軟件可以把圖片鏡像翻轉,利用PS實現圖片的鏡像處理
  • 照片鏡像翻轉app,java實現圖片鏡像翻轉
  • 什么軟件可以把圖片鏡像翻轉,python圖片鏡像翻轉_python圖像處理之鏡像實現方法
  • matlab下載,matlab如何鏡像處理圖片,matlab實現圖像鏡像
  • 圖片鏡像翻轉,MATLAB:鏡像圖片
  • 鏡像翻轉圖片的軟件,圖像處理:實現圖片鏡像(基于python)
  • canvas可畫,JavaScript - canvas - 鏡像圖片
  • 圖片鏡像翻轉,UGUI優化:使用鏡像圖片
  • Codeforces,CodeForces 1253C
  • MySQL下載安裝,Mysql ERROR: 1253 解決方法
  • 勝利大逃亡英雄逃亡方案,HDU - 1253 勝利大逃亡 BFS
  • 大一c語言期末考試試題及答案匯總,電大計算機C語言1253,1253《C語言程序設計》電大期末精彩試題及其問題詳解
  • lu求解線性方程組,P1253 [yLOI2018] 扶蘇的問題 (線段樹)
  • c語言程序設計基礎題庫,1253號C語言程序設計試題,2016年1月試卷號1253C語言程序設計A.pdf
  • 信奧賽一本通官網,【信奧賽一本通】1253:抓住那頭牛(詳細代碼)
  • c語言程序設計1253,1253c語言程序設計a(2010年1月)
  • 勝利大逃亡英雄逃亡方案,BFS——1253 勝利大逃亡
  • 直流電壓測量模塊,IM1253B交直流電能計量模塊(艾銳達光電)
  • c語言程序設計第三版課后答案,【渝粵題庫】國家開放大學2021春1253C語言程序設計答案
  • 18轉換為二進制,1253. 將數字轉換為16進制
  • light-emitting diode,LightOJ-1253 Misere Nim
  • masterroyale魔改版,1253 Dungeon Master
  • codeformer官網中文版,codeforces.1253 B
  • c語言程序設計考研真題及答案,2020C語言程序設計1253,1253計算機科學與技術專業C語言程序設計A科目2020年09月國家開 放大學(中央廣播電視大學)
  • c語言程序設計基礎題庫,1253本科2016c語言程序設計試題,1253電大《C語言程序設計A》試題和答案200901
  • 肇事逃逸車輛無法聯系到車主怎么辦,1253尋找肇事司機