tiny-cnn是一个基于CNN的开源库,它的License是BSD 3-Clause。作者也一直在维护更新,对进一步掌握CNN非常有帮助,因此以下介绍下tiny-cnn在windows7 64bit vs2013的编译及使用。

1.      从https://github.com/nyanp/tiny-cnn下载源代码:

$ git clone https://github.com/nyanp/tiny-cnn.git  版本为77d80a8,更新日期2016.01.22

2.      源文件里已经包括了vs2013project,vc/tiny-cnn.sln,默认是win32的,examples/main.cpp须要OpenCV的支持。这里新建一个x64的控制台projecttiny-cnn。

3.      仿照源project,将对应.h文件加入到新控制台project中。新加一个test_tiny-cnn.cpp文件;

4.      将examples/mnist中test.cpp和train.cpp文件里的代码拷贝到test_tiny-cnn.cpp文件里;

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <tiny_cnn/tiny_cnn.h>
#include <opencv2/opencv.hpp> using namespace tiny_cnn;
using namespace tiny_cnn::activation; // rescale output to 0-100
template <typename Activation>
double rescale(double x)
{
Activation a;
return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);
} void construct_net(network<mse, adagrad>& nn);
void train_lenet(std::string data_dir_path);
// convert tiny_cnn::image to cv::Mat and resize
cv::Mat image2mat(image<>& img);
void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, vec_t& data);
void recognize(const std::string& dictionary, const std::string& filename, int target); int main()
{
//train
std::string data_path = "D:/Download/MNIST";
train_lenet(data_path); //test
std::string model_path = "D:/Download/MNIST/LeNet-weights";
std::string image_path = "D:/Download/MNIST/";
int target[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; for (int i = 0; i < 10; i++) {
char ch[15];
sprintf(ch, "%d", i);
std::string str;
str = std::string(ch);
str += ".png";
str = image_path + str; recognize(model_path, str, target[i]);
} std::cout << "ok!" << std::endl;
return 0;
} void train_lenet(std::string data_dir_path) {
// specify loss-function and learning strategy
network<mse, adagrad> nn; construct_net(nn); std::cout << "load models..." << std::endl; // load MNIST dataset
std::vector<label_t> train_labels, test_labels;
std::vector<vec_t> train_images, test_images; parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte",
&train_labels);
parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte",
&train_images, -1.0, 1.0, 2, 2);
parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte",
&test_labels);
parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte",
&test_images, -1.0, 1.0, 2, 2); std::cout << "start training" << std::endl; progress_display disp(train_images.size());
timer t;
int minibatch_size = 10;
int num_epochs = 30; nn.optimizer().alpha *= std::sqrt(minibatch_size); // create callback
auto on_enumerate_epoch = [&](){
std::cout << t.elapsed() << "s elapsed." << std::endl;
tiny_cnn::result res = nn.test(test_images, test_labels);
std::cout << res.num_success << "/" << res.num_total << std::endl; disp.restart(train_images.size());
t.restart();
}; auto on_enumerate_minibatch = [&](){
disp += minibatch_size;
}; // training
nn.train(train_images, train_labels, minibatch_size, num_epochs,
on_enumerate_minibatch, on_enumerate_epoch); std::cout << "end training." << std::endl; // test and show results
nn.test(test_images, test_labels).print_detail(std::cout); // save networks
std::ofstream ofs("D:/Download/MNIST/LeNet-weights");
ofs << nn;
} void construct_net(network<mse, adagrad>& nn) {
// connection table [Y.Lecun, 1998 Table.1]
#define O true
#define X false
static const bool tbl[] = {
O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
};
#undef O
#undef X // construct nets
nn << convolutional_layer<tan_h>(32, 32, 5, 1, 6) // C1, 1@32x32-in, 6@28x28-out
<< average_pooling_layer<tan_h>(28, 28, 6, 2) // S2, 6@28x28-in, 6@14x14-out
<< convolutional_layer<tan_h>(14, 14, 5, 6, 16,
connection_table(tbl, 6, 16)) // C3, 6@14x14-in, 16@10x10-in
<< average_pooling_layer<tan_h>(10, 10, 16, 2) // S4, 16@10x10-in, 16@5x5-out
<< convolutional_layer<tan_h>(5, 5, 5, 16, 120) // C5, 16@5x5-in, 120@1x1-out
<< fully_connected_layer<tan_h>(120, 10); // F6, 120-in, 10-out
} void recognize(const std::string& dictionary, const std::string& filename, int target) {
network<mse, adagrad> nn; construct_net(nn); // load nets
std::ifstream ifs(dictionary.c_str());
ifs >> nn; // convert imagefile to vec_t
vec_t data;
convert_image(filename, -1.0, 1.0, 32, 32, data); // recognize
auto res = nn.predict(data);
std::vector<std::pair<double, int> > scores; // sort & print top-3
for (int i = 0; i < 10; i++)
scores.emplace_back(rescale<tan_h>(res[i]), i); std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>()); for (int i = 0; i < 3; i++)
std::cout << scores[i].second << "," << scores[i].first << std::endl; std::cout << "the actual digit is: " << scores[0].second << ", correct digit is: "<<target<<std::endl; // visualize outputs of each layer
//for (size_t i = 0; i < nn.depth(); i++) {
// auto out_img = nn[i]->output_to_image();
// cv::imshow("layer:" + std::to_string(i), image2mat(out_img));
//}
//// visualize filter shape of first convolutional layer
//auto weight = nn.at<convolutional_layer<tan_h>>(0).weight_to_image();
//cv::imshow("weights:", image2mat(weight)); //cv::waitKey(0);
} // convert tiny_cnn::image to cv::Mat and resize
cv::Mat image2mat(image<>& img) {
cv::Mat ori(img.height(), img.width(), CV_8U, &img.at(0, 0));
cv::Mat resized;
cv::resize(ori, resized, cv::Size(), 3, 3, cv::INTER_AREA);
return resized;
} void convert_image(const std::string& imagefilename,
double minv,
double maxv,
int w,
int h,
vec_t& data) {
auto img = cv::imread(imagefilename, cv::IMREAD_GRAYSCALE);
if (img.data == nullptr) return; // cannot open, or it's not an image cv::Mat_<uint8_t> resized;
cv::resize(img, resized, cv::Size(w, h)); // mnist dataset is "white on black", so negate required
std::transform(resized.begin(), resized.end(), std::back_inserter(data),
[=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
}

5.      编译时会提示几个错误,解决方法是:

(1)、error C4996。解决方法:将宏_SCL_SECURE_NO_WARNINGS加入到属性的预处理器定义中;

(2)、调用for_函数时,error C2668,对重载函数的调用不明教,解决方法:将for_中的第三个參数强制转化为size_t类型;

6.      执行程序,train时,执行结果例如以下图所看到的:

watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="" />

7.      对生成的model进行測试,通过绘图工具,每一个数字生成一张图像,共10幅,例如以下图:

通过导入train时生成的model。对这10张图像进行识别,识别结果例如以下图,当中6和9被误识为5和1:

GitHub:https://github.com/fengbingchun/NN

tiny-cnn开源库的使用(MNIST)的更多相关文章

  1. 深度学习开源库tiny-dnn的使用(MNIST)

    tiny-dnn是一个基于DNN的深度学习开源库,它的License是BSD 3-Clause.之前名字是tiny-cnn是基于CNN的,tiny-dnn与tiny-cnn相关又增加了些新层.此开源库 ...

  2. 站在巨人的肩膀上,C++开源库大全

    程序员要站在巨人的肩膀上,C++拥有丰富的开源库,这里包括:标准库.Web应用框架.人工智能.数据库.图片处理.机器学习.日志.代码分析等. 标准库 C++ Standard Library:是一系列 ...

  3. GitHub C 和 C++ 开源库的清单(含示例代码)

    内容包括:标准库.Web应用框架.人工智能.数据库.图片处理.机器学习.日志.代码分析等. 标准库 C++标准库,包括了STL容器,算法和函数等. C++ Standard Library:是一系列类 ...

  4. C++开源库大全(转)

    程序员要站在巨人的肩膀上,C++拥有丰富的开源库,这里包括:标准库.Web应用框架.人工智能.数据库.图片处理.机器学习.日志.代码分析等. 标准库 C++ Standard Library:是一系列 ...

  5. C++开源库大全

        标准库 C++ Standard Library:是一系列类和函数的集合,使用核心语言编写,也是C++ISO自身标准的一部分. Standard Template Library:标准模板库 ...

  6. 【踩坑速记】开源日历控件,顺便全面解析开源库打包发布到Bintray/Jcenter全过程(新),让开源更简单~

    一.写在前面 自使用android studio开始,就被它独特的依赖方式:compile 'com.android.support:appcompat-v7:25.0.1'所深深吸引,自从有了它,麻 ...

  7. Java下好用的开源库推荐

    作者:Jack47 转载请保留作者和原文出处 欢迎关注我的微信公众账号程序员杰克,两边的文章会同步,也可以添加我的RSS订阅源. 本文想介绍下自己在Java下做开发使用到的一些开源的优秀编程库,会不定 ...

  8. 第三方开源库和jar包的区别

    jar包和第三方开源库的根本区别在于,开源库的功能比jar包功能更强大,通过引入库项目可以访问java文件以及该开源库项目下的资源文件,例如图片,layout等文件 jar包中只能放class文件 引 ...

  9. 【转】用JitPack发布开源库时附加文档和源码

    来自:http://www.gcssloop.com/course/jitpack-sources-javadoc 用JitPack发布开源库时附加文档和源码 很早之前写过一篇用JitPack发布An ...

随机推荐

  1. 〖Linux〗bash和expect执行ssh命令行sshcmd.exp

    #!/usr/bin/expect -f # sudo apt-get install expect # ./ssh.exp user passwd server set user [lrange $ ...

  2. 【DB2】delete大表不记录日志的正确操作

    一.原始方法 在删除大表的时候,经常会由于数据量太大,造成日志文件满了,接着无法删除数据. 以下是删除大表不记录日志的具体步骤: 1.临时设置自动提交关闭 (使用命令db2 list command ...

  3. .NET/C#中对自定义对象集合进行自定义排序的方法

    一个集合可否排序,要看系统知不知道排序的规则,像内建的系统类型,int ,string,short,decimal这些,系统知道怎么排序,而如果一个集合里面放置的是自定义类型,比如自己定义了一个Car ...

  4. SpringBoot之actuator

    在springBoot中集成actuator可以很方便的管理和监控应用的状态. 暴露的Restful接口有: HTTP方法 路径 描述 鉴权 GET /autoconfig 查看自动配置的使用情况 t ...

  5. 图解Win7如何手动添加受信任证书

    点击开始—>运行,如下图所示:   弹出“控制台”窗口如下,如下图所示:   点击“文件—添加/删除管理单元”,如下图所示:   选择“证书”,并点击“添加”,如下图所示:   在弹出的窗口上选 ...

  6. 三种分布式对象主流技术——COM、Java和COBRA

    既上一遍,看到还有一遍将关于 对象的, 分布式对象, 故摘抄入下: 目前国际上,分布式对象技术有三大流派——COBRA.COM/DCOM和Java.CORBA技术是最早出现的,1991年OMG颁布了C ...

  7. 输出控制台信息到日志 并 通过cronolog对tomcat进行日志切分

    windows下tomcat默认并不会把控制台输出的信息都记录进日志文件.但是在生产环境中,出现问题时,控制台的日志输出是无法查据的,因此需要将日志记录下来. 解决方法: 输出日志到文件 修改tomc ...

  8. numpy中的argpartition

    numpy.argpartition(a, kth, axis=-1, kind='introselect', order=None) 在快排算法中,有一个典型的操作:partition.这个操作指: ...

  9. 什么是分表和分区 MySql数据库分区和分表方法

    1.为什么要分表和分区 日常开发中我们经常会遇到大表的情况,所谓的大表是指存储了百万级乃至千万级条记录的表.这样的表过于庞大,导致数据库在查询和插入的时候耗时太长,性能低下,如果涉及联合查询的情况,性 ...

  10. IT技术需求建立时需考虑的因素

      2012-11-13 内容存档在evernote,笔记名"IT技术需求建立时需考虑的因素"