Tensorflow入门——卷积神经网络MNIST手写数字识别
Image source: unsplash.com by Paweł Czerwiński
之前的文章我们介绍了如何用单层和多的全连接层神经网络识别手写数字,虽然识别率能够达到98%,但是由于全链接神经网络本身的局限性,其识别率已经很难再往上提升了。我们需要改进神经网络的结构,采用卷积神经网络(CNN)的结构来进一步提高的识别率。
关于CNN的原理,我在之前的文章中已经介绍,这篇文章就不过多赘述,我们直接进入实战阶段。
同样的,为了方便与读者交流,所有的代码都放在了这里:
Repository:
https://github.com/zht007/tensorflow-practice
1. 初始化W和B
卷积神经网络中权重W的shape尤其重要,CNN中的W实际上就是一个四维的filter,这个四维的filter由n个三维filter堆叠而成,n的大小等于输出channel的深度。当然三维filter又是由m个二维filter堆叠的,m的大小等于输入Channel的深度。
动画效果可以参见这里。
W的shape为[filter[0], filter[1], input_channel_depth, output_channel_depth]
例如W[6,6,3,4] 表示:二维的filter的size是(6,6), 输入的图片有3个Channel, 输出的图片有4个Channel
偏置B的Shape与output_channel保持一致即可,tensorflow会自动broadcast成正确的维度,B在这里与多层神经网络的的初始化相同。
神经网络的结构一共5层,3层CNN,2层全链接,最后一层与单层神经网络一样,10个神经元输出识别结果:数字是是0-9的概率。
# three convolutional layers with their channel counts, and a
# fully connected layer (the last layer has 10 softmax neurons)
K = 12 # first convolutional layer output depth
L = 24 # second convolutional layer output depth
M = 48 # third convolutional layer
N = 200 # fully connected layer
W1 = tf.Variable(tf.truncated_normal([6,6,1,K], stddev=0.1))
B1 = tf.Variable(tf.ones([K])/10)
W2 = tf.Variable(tf.truncated_normal([5,5,K,L], stddev=0.1))
B2 = tf.Variable(tf.ones([L])/10)
W3 = tf.Variable(tf.truncated_normal([4,4,L,M], stddev=0.1))
B3 = tf.Variable(tf.ones([M])/10)
W4 = tf.Variable(tf.truncated_normal([7*7*M,N], stddev=0.1))
B4 = tf.Variable(tf.ones([N])/10)
W5 = tf.Variable(tf.truncated_normal([N, 10], stddev=0.1))
B5 = tf.Variable(tf.zeros([10]))
该部分代码部分参考[2][3] with Apache License 2.0
2. 神经网络搭建
CNN的部分,我们用tensorflow自带的tf.nn.conv2d()方法:
tf.nn.conv2d(
input,
filter,
strides,
padding,
use_cudnn_on_gpu=True,
data_format='NHWC',
dilations=[1, 1, 1, 1],
name=None
)
用Tensorflow搭建神经网络的时候注意以下几点:
- Padding 这里使用的是'SAME',也就是步长(stride)为1的时候输入与输出图片的shape保持一致。
- 这里没有使用Max-Pooling层来"压缩"图片,而是增加stride(第二层和第三层Stride 为2)的方式,效果是一样的。28x28的图片经过两层CNN之后,压缩成了14x14和7x7的图片。
- CNN与全连接神经网络连接之前,需要将CNN输出的图片拆开拼接成一维的向量(Flatten or Reshape)。
Y1 = tf.nn.relu(tf.nn.conv2d(X, W1, strides = [1,1,1,1], padding='SAME') + B1)
Y2 = tf.nn.relu(tf.nn.conv2d(Y1,W2, strides = [1,2,2,1], padding='SAME') + B2)
Y3 = tf.nn.relu(tf.nn.conv2d(Y2,W3, strides = [1,2,2,1], padding='SAME') + B3)
#flat the inputs for the fully connected nn
YY3 = tf.reshape(Y3, shape = (-1,7*7*M))
Y4 = tf.nn.relu(tf.matmul(YY3, W4) + B4)
Y4d = tf.nn.dropout(Y4,rate = drop_rate)
Ylogits = tf.matmul(Y4d, W5) + B5
Y = tf.nn.softmax(Ylogits)
该部分代码部分参考[2][3] with Apache License 2.0
3. 识别效果
在其他参数都没改变的情况下,仅仅改变了神经网络的结构,可以看出识别率已经超出99%了。
目前我通过CNN的神经网络训练出来的分类器参加Kaggle的比赛,最好成绩是识别率99.3,全球排名第792名。
4. CNN结构的Keras实现
如果用Keras这个高级的API搭建CNN就更加简单了,无需初始化W和B,只需要关心神经网络的结构本身就行了。
使用Keras的layers.Conv2D()方法,注意其中的参数filters 是输出Channel的depth,Kernel_size 是二维filter的shape,实现相同结构的代码如下:
model = models.Sequential()
model.add(layers.Conv2D(filters = 12, kernel_size=(6,6), strides=(1,1),
padding = 'same', activation = 'relu',
input_shape = (28,28,1)))
model.add(layers.Conv2D(filters = 24,kernel_size=(5,5),strides=(2,2),
padding = 'same', activation = 'relu'))
model.add(layers.Conv2D(filters = 48,kernel_size=(4,4),strides=(2,2),
padding = 'same', activation = 'relu'))
model.add(layers.Flatten())
model.add(layers.Dense(units=200, activation='relu'))
model.add(layers.Dropout(0.25))
model.add(layers.Dense(units=10, activation='softmax'))
参考资料
[1]https://www.kaggle.com/c/digit-recognizer/data
[2]https://codelabs.developers.google.com/codelabs/cloud-tensorflow-mnist/#0
[3]https://github.com/GoogleCloudPlatform/tensorflow-without-a-phd.git
[4]https://www.tensorflow.org/api_docs/
相关文章
Tensorflow入门——单层神经网络识别MNIST手写数字
Tensorflow入门——多层神经网络MNIST手写数字识别
Tensorflow入门——分类问题cross_entropy的选择
同步到我的简书
你好鸭,hongtao!
@zerofive给您叫了一份外卖!
由 @cecilian 粥粥 迎着台风 骑着飞鸽 念着软哥金句:"大哥,把你脸上的分辨率调低点好吗?" 给您送来
新鲜出炉的薯条
吃饱了吗?跟我猜拳吧! 石头,剪刀,布~
如果您对我的服务满意,请不要吝啬您的点赞~
@onepagex
This post has received a free upvote by @OnePageX
OnePageX.com is the fastest way to convert STEEM and other assets to over 140 cryptocurrencies!
Great Rates, Low Fees & Fast Crypto Exchanges!
Check out our most recent news update.
真的是饿了
帅哥/美女!这是哪里?你是谁?我为什么会来这边?你不要给我点赞不要点赞,哈哈哈哈哈哈。假如我的留言打扰到你,请回复“取消”。
@teamcn-shop
Posted using Partiko Android
This post has been voted on by the SteemSTEM curation team and voting trail. It is elligible for support from @curie.
If you appreciate the work we are doing, then consider supporting our witness stem.witness. Additional witness support to the curie witness would be appreciated as well.
For additional information please join us on the SteemSTEM discord and to get to know the rest of the community!
Please consider setting @steemstem as a beneficiary to your post to get a stronger support.
Please consider using the steemstem.io app to get a stronger support.
Congratulations @hongtao! You have completed the following achievement on the Steem blockchain and have been rewarded with new badge(s) :
You can view your badges on your Steem Board and compare to others on the Steem Ranking
If you no longer want to receive notifications, reply to this comment with the word
STOP