https://paulswith.github.io/2018/02/24/%E8%BF%81%E7%A7%BB%E5%AD%A6%E4%B9%A0InceptionV3/ 上文记录了如何从一个别人训练好的模型, 切入我们自己的图片, 改为我们自己的模型.
本来以为移植到手机很简单, 但是不简单的是我的模型本身就是迁移学习别人的模型,有很多莫名其面的坑, 在CoreML经历了N个坑后,1点14分我搞掂了.

项目源码和转换源码已经上传到git.
https://github.com/Paulswith/machineLearningIntro/tree/master/classification_101

转化为mlmodel

说说转换为mlmodel的工具有两个:

接着往下看:

是否是graph-pb?

如果你跟我一样, 训练的模型, 从tensorflow的代码保存下来的, 调用的:

  1. 1
  1. saver.save(sess, MODEL_SAVEPATH, global_step=50)

它并不会保存出一个pb文件, 其中的.meta也需要其他方式转换似乎也可以, 我没有尝试过.
用这个方法, 你需要在上方代码的下面加两行,就可以继续:

  1. 1
  2. 2
  1. if i %SAVE_EPOCH == 0:
  2. tf.train.write_graph(sess.graph, MODEL_SAVE_DIR, 'model.pbtxt')

pdtxt固化为pd

操作参考链接https://www.jianshu.com/p/091415b114e2
我是直接使用的bezel, 编译tensorflow源码后, 直接使用, 其中参数跟着填, 需要注意的是output_node:

导入化图

导入图和查看图的节点信息:
如果你的图不属于pb文件, 那么就会在导入图的时候报错的.

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  1. with open(TF_MODEL_FILE, 'rb') as f:
  2. serialized = f.read()
  3. tf.reset_default_graph()
  4. original_gdef = tf.GraphDef()
  5. original_gdef.ParseFromString(serialized)
  6. with tf.Graph().as_default() as g:
  7. tf.import_graph_def(original_gdef, name='')
  8. ops = g.get_operations()
  9. try:
  10. for i in range(10000):
  11. print('op id {} : op name: {}, op type: "{}"'.format(str(i),ops[i].name, ops[i].type))
  12. except:
  13. print("全部节点已打印完毕.")
  14. pass

预处理节点

其实这一步个人不是很清楚很知道它做了什么,但确是不得不做的. 最后的大小看着也不像是”减包”
需要注意两点:

  • input_node_names: 这里填写的节点从上方的代码可以打印看得到的, 实际在训练模型的时候, 我们直接喂图片的节点是在”import/DecodeJpeg/contents”, 而这里”必须是import/Mul”.
  • output_node_names: 因为模型是剪切拼接的, 这跟tensorflow直接调用是一样的节点.
    最后生成一个完整的pb文件.
    1. 1
    2. 2
    3. 3
    4. 4
    5. 5
    6. 6
    7. 7
    8. 8
    9. 9
    10. 10
    11. 11
    12. 12
    1. input_node_names = ['import/Mul', 'BottleneckInputPlaceholder'] # 本来以为是import/DecodeJpeg/contents, 实际上是Mul(tfcoreml-git上说的)
    2. output_node_names = ['import/pool_3/_reshape','final_train_ops/softMax_last'] # 想要保存的节点 , 'final_train_ops/softMax_last'
    3. gdef = strip_unused_lib.strip_unused(
    4. input_graph_def = original_gdef,
    5. input_node_names = input_node_names,
    6. output_node_names = output_node_names,
    7. placeholder_type_enum = dtypes.float32.as_datatype_enum)
    8. with gfile.GFile(FROZEN_MODEL_FILE, "wb") as f:
    9. f.write(gdef.SerializeToString())

开始转换

  • input_tensor_shapes: 是placeholder 和 input节点, 方括号的第一个参数是batch大小,代表一张一张的喂给它. 字典里面这两个, 对应生成后的InceptionV3_input的两个属性.

  • output_tensor_names: 训练后得到的节点, 对应生成后的InceptionV3_output的两个属性

    1. 1
    2. 2
    3. 3
    4. 4
    5. 5
    6. 6
    7. 7
    8. 8
    9. 9
    10. 10
    11. 11
    12. 12
    13. 13
    14. 14
    15. 15
    16. 16
    17. 17
    1. input_tensor_shapes = {
    2. "import/Mul:0":[1,299,299,3], # batch size is 1
    3. "BottleneckInputPlaceholder:0":[1,2048],
    4. }
    5. output_tensor_names = ['import/pool_3/_reshape:0','final_train_ops/softMax_last:0']
    6. # Call the converter. This may take a while
    7. coreml_model = tfcoreml.convert(
    8. tf_model_path=FROZEN_MODEL_FILE,
    9. mlmodel_path=COREML_MODEL_FILE,
    10. input_name_shape_dict=input_tensor_shapes,
    11. output_feature_names=output_tensor_names,
    12. image_input_names = ['import/Mul:0'],
    13. red_bias = -1,
    14. green_bias = -1,
    15. blue_bias = -1,
    16. image_scale = 2.0/255.0)

上方具体的参数可以在方法看得到,后面四个参数就是我们输入图片时候的均值化, 还有个特殊的参数class_labels, output后的模型可以直接索引到标签, 但是在实践过程中, 我这个本身是迁移别人的学习的模型并起不到作用.

执行完成后生成文件:

ios-code调用

了解模型:

首先, 直接将inceptionV3.mlmodel拖入到工程:
导入头文件, inceptionV3.h, 点开查看:

他们之间的关系是, inception_v3_input导入 -> 启动inception_v3.model训练 -> 得到inception_v3_output 分别提供了一个实例化方法.

开始代码

首先确认他们之间的调用方向 层次, 我直接是参考tensorflow加载的顺序, 只要理解了, 就可以直接调用了:
这是Py 大专栏  机器学习迁移模型到IOSthon的调用方法:

  1. 1
  2. 2
  1. poo3_frist = sess.run(poo3, feed_dict={inpiut_x: image}) # 按照模型的顺序要, 先喂给它图片, 然后图片提取到瓶颈的tensor
  2. result = sess.run(predict, feed_dict={change_input:poo3_frist}) # 瓶颈的tensor再转入input传入, 得到我们最后的predict

如果参照tensorflow加载模型的做法, 我们直接是一张图片, 得到一个run到pool3, 但实际CoreML只给我们生成了一个实例方法:

  1. 1
  1. - (instancetype)initWithBottleneckInputPlaceholder__0:(MLMultiArray *)BottleneckInputPlaceholder__0 import__Mul__0:(CVPixelBufferRef)import__Mul__0;

方法必须要传入一个MLMultiArray, 而且shape必须一致的. 最后我直接调用MLMultiArray的方法,生成一个0值的2048shape

预测部分的, 完整核心代码 均有详细的注释说明

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  1. - (inception_v3 *)tfModel {
  2. if (!_tfModel) {
  3. // 1 加载模型, 本身代码会调用init的时候, 方法会调用initWithContentsOfURL, 找到inception文件进行初始化
  4. _tfModel = [[inception_v3 alloc] init];
  5. }
  6. return _tfModel;
  7. }
  8. - (NSString *)predictWithFoodImage:(UIImage *)foodImage
  9. {
  10. // step1: 标准为size, 转为可传入的参数.
  11. UIImage *img = [foodImage scaleToSize:CGSizeMake(299, 299)]; // 转换为可传参的图片大小
  12. CVPixelBufferRef refImage = [[UIImage new] pixelBufferFromCGImage:img]; // 转换为可传参的类型
  13. // step2.1: 由于一开始是没有BottleneckInputPlaceholder, 直接0值初始一个传入
  14. MLMultiArray *holder = [[MLMultiArray alloc] initWithShape:@[@2048] dataType:MLMultiArrayDataTypeDouble error:nil];
  15. // step2.2: 启动预测, 预测完成后得到import__pool_3___reshape__0
  16. inception_v3Output *output = [self.tfModel predictionFromBottleneckInputPlaceholder__0:holder import__Mul__0:refImage error:nil];
  17. // step3: 从第二步, 完整得到了想要的BottleneckInputPlaceholder, 直接代入, 图片也代入.
  18. inception_v3Output *output1 = [self.tfModel predictionFromBottleneckInputPlaceholder__0:output.import__pool_3___reshape__0 import__Mul__0:refImage error:nil];
  19. // step4: 从final_train_ops__softMax_last__0提取预测结果
  20. MLMultiArray *__final = output1.final_train_ops__softMax_last__0;
  21. return [self poAccu:__final];
  22. }

调用摄像头进行图片获取

通过整合代码层次, 代码调用也封装好了, 方便代用:
从简书上拿到别人写好的调用摄像头拍照https://www.jianshu.com/p/62d69d89fa43, 提取了下代码:
主要逻辑:
拍照后重置大小展示到view, 异步进行模型预测, 回到主线程展示label结果.

  1. 1
  2. 2
  3. 3
  4. 4
  5. 5
  6. 6
  7. 7
  8. 8
  9. 9
  10. 10
  11. 11
  12. 12
  13. 13
  14. 14
  15. 15
  16. 16
  17. 17
  18. 18
  19. 19
  20. 20
  21. 21
  22. 22
  23. 23
  24. 24
  25. 25
  26. 26
  27. 27
  28. 28
  29. 29
  30. 30
  31. 31
  32. 32
  33. 33
  34. 34
  35. 35
  36. 36
  37. 37
  38. 38
  39. 39
  40. 40
  41. 41
  42. 42
  43. 43
  44. 44
  1. //触发事件:拍照
  2. - (void)addCamera
  3. {
  4. UIImagePickerController *picker = [[UIImagePickerController alloc] init];
  5. picker.delegate = self;
  6. picker.allowsEditing = YES; //可编辑
  7. //判断是否可以打开照相机
  8. if ([UIImagePickerController isSourceTypeAvailable:UIImagePickerControllerSourceTypeCamera]) {
  9. //摄像头
  10. picker.sourceType = UIImagePickerControllerSourceTypeCamera;
  11. } else { //否则打开照片库
  12. picker.sourceType = UIImagePickerControllerSourceTypePhotoLibrary;
  13. }
  14. [self presentViewController:picker animated:YES completion:nil];
  15. }
  16. //拍摄完成后要执行的代理方法
  17. - (void)imagePickerController:(UIImagePickerController *)picker didFinishPickingMediaWithInfo:(NSDictionary *)info
  18. {
  19. NSString *mediaType = [info objectForKey:UIImagePickerControllerMediaType];
  20. if ([mediaType isEqualToString:@"public.image"]) {
  21. //得到照片
  22. UIImage *image = [info objectForKey:UIImagePickerControllerOriginalImage];
  23. image = [image scaleToSize:self.imageView.frame.size];
  24. self.imageView.image = image;
  25. // 异步处理, 不要占用主线程:
  26. dispatch_async(dispatch_queue_create(0, 0), ^{
  27. NSString *preString = [self.prediction predictWithFoodImage:image];
  28. dispatch_async(dispatch_get_main_queue(), ^{
  29. self.preLabel.text = preString;
  30. });
  31. });
  32. }
  33. [self dismissViewControllerAnimated:YES completion:nil];
  34. }
  35. //进入拍摄页面点击取消按钮
  36. - (void)imagePickerControllerDidCancel:(UIImagePickerController *)picker
  37. {
  38. [self dismissViewControllerAnimated:YES completion:nil];
  39. }

测试结果

制作gif, 直接看图片: 或者上面链接从我的git上下载源码:

[好困, 该睡觉了...]

机器学习迁移模型到IOS的更多相关文章

  1. ios开发——实用技术篇&Pist转模型详细介绍

    Pist转模型详细介绍 关于Plist转模型在iOS开发中是非常常见的,每开一一个项目或者实现一个功能都要用到它,所以今天就给大家讲讲Plist怎么转成模型数据, 前提:必须有一个Plist文件或者通 ...

  2. 对于iOS开发人工智能意味着什么

    对于iOS开发人工智能意味着什么? 前言 近几年来人工智能的话题那是炙手可热.在国内很多大佬言必谈机器学习和大数据:在美国刚毕业的人工智能 PHD 也是众人追捧,工资直逼 NFL 四分卫.人工智能甚至 ...

  3. 从BSP模型到Apache Hama

    一.什么是BSP模型 概述 BSP(Bulk Synchronous Parallel,整体同步并行计算模型)是一种并行计算模型,由英国计算机科学家Viliant在上世纪80年代提出.Google发布 ...

  4. ios基础篇(十八)——Delegate 、NSNotification 和 KVO用法及其区别

    一.Delegate Delegate本质是一种程序设计模型,iOS中使用Delegate主要用于两个页面之间的数据传递.iphone中常用@protocol和delegate的机制来实现接口的功能. ...

  5. iOS Simulator功能介绍关于Xamarin IOS开发

    iOS Simulator功能介绍关于Xamarin IOS开发 iOS Simulator功能介绍 在图1.38所示的运行效果中,所见到的类似于手机的模型就是iOS Simulator.在没有iPh ...

  6. 利用GBDT模型构造新特征具体方法

    利用GBDT模型构造新特征具体方法 数据挖掘入门与实战  公众号: datadw   实际问题中,可直接用于机器学**模型的特征往往并不多.能否从"混乱"的原始log中挖掘到有用的 ...

  7. 学习笔记TF066:TensorFlow移动端应用,iOS、Android系统实践

    TensorFlow对Android.iOS.树莓派都提供移动端支持. 移动端应用原理.移动端.嵌入式设备应用深度学习方式,一模型运行在云端服务器,向服务器发送请求,接收服务器响应:二在本地运行模型, ...

  8. iOS系统及客户端软件测试的基础介绍

    iOS系统及客户端软件测试的基础介绍 iOS现在的最新版本iOS5是10月12号推出,当前版本是4.3.5 先是硬件部分,采用iOS系统的是iPad,iPhone,iTouch这三种设备,其中iPho ...

  9. ios数据持久化(转)

    文件系统 归档和序列化 数据库 1.文件系统 不管是Mac OS X 还是iOS的文件系统都是建立在UNIX文件系统基础之上的. 1.1 沙盒模型 在iOS中,一个App的读写权限只局限于自己的沙盒目 ...

随机推荐

  1. springboot访问请求404问题

    新手在刚接触springboot的时候,可能会出现访问请求404的情况,代码没问题,但就是404. 疑问:在十分确定代码没问题的时候,可以看下自己的包是不是出问题了? 原因:SpringBoot 注解 ...

  2. 【python】两行代码实现近百年的正反日期查询--20200202

    到2020年了.有个日期也火了,记得上一次还是2011年11月2日.为啥捏,因为日期写成数字形式 正反是一样的. 2020年也有一个这样的日期.20200202:2020年2月2日. 于是乎想写一段代 ...

  3. Oscar的拓扑笔记本

    目录 Euler characteristic Euler定理 引入:绝对值 度量空间 Example: 开集,闭集 Topological space 什么是拓扑 拓扑空间 例子: Exercise ...

  4. mysql统计指定数据库的各表的条数

    mysql统计指定数据库的各表的条数 SELECT table_schema,table_name,table_rows,CREATE_TIME FROM TABLES WHERE TABLE_SCH ...

  5. 2018-1 WebStorm最新版本破解激活方法

    在激活页面选择License Server,输入:http://idea.codebeta.cn,点击Activate即可激活. 如果失效用这个:  http://idea.ibdyr.com

  6. linux压缩管理系统

    Linux压缩管理系统windows        rar       zipLinux       zip        tar.gz       tar.bz2       tar.xz 压缩的好 ...

  7. 查看linux系统安装的服务

    如何查看linux系统安装了哪些服务呢,因不同版本的操作系统可能使用的命令不一样或者有些命令在某些操作系统不可用,现列举一些常用查看命令(基于我的linux版本). 我的操作系统版本如下: 1.ser ...

  8. linux上安装 mysql

    一.linux 上安装 mysql 1.查看mysql是否安装 rpm -qa|grep mysql 2.卸载 mysql yum remove mysql mysql-server mysql-li ...

  9. vue点击复制文本粘贴

    <template>  <ul>      <li> <input type="text" class="inpNone&quo ...

  10. 实现JS脏话筛选替换的几种途径

    一.逐个替换用replace 缺点:筛选的脏话集太少 var oSize = $(this).siblings('.flex-text-wrap').find('.comment-input').va ...