基于mnist数据集的手势识别
参考别人的博客搭建出了基于cifar数据集的分类神经网络,接下来打算使用自己尝试搭建个手势识别的神经网络,接下来分为三个部分介绍:1. 读取数据,2. 搭建神经网络,3. 训练与评估
读取数据
通过下载网上公开的mnist数据,发现数据分为test和train部分,且保存为csv格式如下图,因此读取数据集的时候需要对其进行处理和包装,将其转化成能识别的DataSet格式。
通过定义如下新类,来读取和转化格式
搭建神经网络
神经网络与上一个差不多由卷积–>池化–>再卷积–>再池化–>全连接层组成。与训练过程都与上篇博客一致上一篇博客
结果
下图为训练结果,可以看出,已经产生过拟合了,解决过拟合的方法可以调节weight_decay参数。或者增加正则化项。