在tensorflow中可以调用keras,有时候让模型的建立更加简单。如下这种是官方写法:

  1. import tensorflow as tf
  2. from keras import backend as K
  3. from keras.layers import Dense
  4. from keras.objectives import categorical_crossentropy
  5. from keras.metrics import categorical_accuracy as accuracy
  6. from tensorflow.examples.tutorials.mnist import input_data
  7. # create a tf session,and register with keras。
  8. sess = tf.Session()
  9. K.set_session(sess)
  10.  
  11. # this place holder is the same with input layer in keras
  12. img = tf.placeholder(tf.float32, shape=(None, 784))
  13. # keras layers can be called on tensorflow tensors
  14. x = Dense(128, activation='relu')(img)
  15. x = Dense(128, activation='relu')(x)
  16. preds = Dense(10, activation='softmax')(x)
  17. # label
  18. labels = tf.placeholder(tf.float32, shape=(None, 10))
  19. # loss function
  20. loss = tf.reduce_mean(categorical_crossentropy(labels, preds))
  21.  
  22. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
  23.  
  24. mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)
  25.  
  26. # initialize all variables
  27. init_op = tf.global_variables_initializer()
  28. sess.run(init_op)
  29.  
  30. with sess.as_default():
  31. for i in range(1000):
  32. batch = mnist_data.train.next_batch(50)
  33. train_step.run(feed_dict={img:batch[0],
  34. labels:batch[1]})
  35.  
  36. acc_value = accuracy(labels, preds)
  37. with sess.as_default():
  38. print(acc_value.eval(feed_dict={img:mnist_data.test.images,
  39. labels:mnist_data.test.labels}))

上述代码中,在训练阶段直接采用了tf的方式,甚至都没有定义keras的model!官网说 最重要的一步就是这里:

  1. K.set_session(sess)

创建一个TensorFlow会话并且注册Keras。这意味着Keras将使用我们注册的会话来初始化它在内部创建的所有变量。 
keras的层和模型都充分兼容tensorflow的各种scope, 例如name scope,device scope和graph scope。

经过测试,下面这种不需要k.set_session()也是可以的。

  1.  
  1. import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
  2.  
  3. # build module
  4.  
  5. img = tf.placeholder(tf.float32, shape=(None, 784))
    labels = tf.placeholder(tf.float32, shape=(None, 10))
  6.  
  7. x = tf.keras.layers.Dense(128, activation='relu')(img)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    prediction = tf.keras.layers.Dense(10, activation='softmax')(x)
  8.  
  9. loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=prediction, labels=labels))
  10.  
  11. train_optim = tf.train.AdamOptimizer().minimize(loss)
    path="/home/vv/PycharmProject/Cnnsvm/MNIST_data"
    mnist_data = input_data.read_data_sets(path, one_hot=True)
  12.  
  13. with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
  14.  
  15. for _ in range(1000):
    batch_x, batch_y = mnist_data.train.next_batch(50)
    sess.run(train_optim, feed_dict={img: batch_x, labels: batch_y})
  16.  
  17. acc_pred = tf.keras.metrics.categorical_accuracy(labels, prediction)
    pred = sess.run(acc_pred, feed_dict={labels: mnist_data.test.labels, img: mnist_data.test.images})
  18.  
  19. print('accuracy: %.3f' % (sum(pred) / len(mnist_data.test.labels)))
    print(pred)

如果在下载导入mnist数据出错,可以在网站上下好,本地导入。

  1. mnist_data = input_data.read_data_sets(path, one_hot=True)
  1. x1 = tf.layers.conv2d(img2,64,2)
  2. x2 = tf.keras.layers.Conv2D(img2,64,2)
  3. x3 = tf.keras.layers.Conv2D(64,2)(img2)

x1和x3卷积效果相同

tensorflow和keras混用的更多相关文章

  1. 『计算机视觉』Mask-RCNN_推断网络其二:基于ReNet101的FPN共享网络暨TensorFlow和Keras交互简介

    零.参考资料 有关FPN的介绍见『计算机视觉』FPN特征金字塔网络. 网络构架部分代码见Mask_RCNN/mrcnn/model.py中class MaskRCNN的build方法的"in ...

  2. 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现

    在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...

  3. windows安装TensorFlow和Keras遇到的问题及其解决方法

    安装TensorFlow在Windows上,真是让我心力交瘁,想死的心都有了,在Windows上做开发真的让人发狂. 首先说一下我的经历,本来也就是起初,网上说python3.7不支持TensorFl ...

  4. Ubuntu18.04 安装TensorFlow 和 Keras

    TensorFlow和Keras是当前两款主流的深度学习框架,Keras被采纳为TensorFlow的高级API,平时做深度学习任务,可以使用Keras作为深度学习框架,并用TensorFlow作为后 ...

  5. Anaconda 安装 tensorflow 和 keras

    说明:此操作是在 Anaconda Prompt 窗口完成的 CPU版 tensorflow 的安装. 1.用 conda 创建虚拟环境 tensorflow python=3.6 conda cre ...

  6. 成功解决 AttributeError: module 'tensorflow.python.keras.backend' has no attribute 'get_graph'

    在导入keras包时出现这个问题,是因为安装的tensorflow版本和keras版本不匹配,只需卸载keras,重新安装自己tensorflow对应的版本就OK了.可以在这个网址查看tensorfl ...

  7. tensorflow和keras的安装

    1 卸载tensorflow方法,在终端输入:  把protobuf删除了才能卸载干净. sudo pip uninstall protobuf sudo pip uninstall tensorfl ...

  8. win10+anaconda安装tensorflow和keras遇到的坑小结

    win10下利用anaconda安装tensorflow和keras的教程都大同小异(针对CPU版本,我的gpu是1050TI的MAX-Q,不知为啥一直没安装成功),下面简单说下步骤. 一 Anaco ...

  9. Anaconda安装tensorflow和keras(gpu版,超详细)

    本人配置:window10+GTX 1650+tensorflow-gpu 1.14+keras-gpu 2.2.5+python 3.6,亲测可行 一.Anaconda安装 直接到清华镜像网站下载( ...

随机推荐

  1. Java+Selenium环境搭建

    初学者---简单的selenium环境搭建: 1. 安装JAVA环境 2.下载eclipse 3.下载firefox (不要最高版本,容易出现selenium不兼容问题) 4. 下载selenium需 ...

  2. GDB查看堆栈局部变量

    GDB查看堆栈局部变量 “参数从右到左入栈”,“局部变量在栈上分配空间”,听的耳朵都起茧子了.最近做项目涉及C和汇编互相调用,写代码的时候才发现没真正弄明白.自己写了个最简单的函数,用gdb跟踪了调用 ...

  3. 小程序 showModal content换行

    wx.showModal({ title: '提示', content: '1.该拼团仅支持到指定取货地址自提\r\n2.拼团支付价格为拼团原价,当到达指定阶梯,拼团差价将在3个工作日内退回您的微信账 ...

  4. python+requests+excel+unittest+ddt接口自动化数据驱动并生成html报告(二)

    可以参考 python+requests接口自动化完整项目设计源码(一)https://www.cnblogs.com/111testing/p/9612671.html 原文地址https://ww ...

  5. 关于php

    public private protected 修饰词 public: 公有类型 在子类中可以通过self::var调用public方法或属性,parent::method调用父类方法 在实例中可以 ...

  6. docker中i的作用

    #docker container createKeep STDIN open even if not attached #docker container startAttach container ...

  7. 知识在与温故、总结-再读CLR

    序 CLR,通用语言运行时,每个.Net 程序猿,都会第一时间接触到.记得2008年,第一次学习Jeffrey Richter的CLR Via C#,读的懵懵懂懂,大抵因为编码太少,理解的只是概念和皮 ...

  8. function(){}、var fun=function(){}和function fun(){}的区别

    一.基本定义 1.函数声明:使用function声明函数,并指定函数名. function fun() { // ...... } 2.函数表达式:使用function声明函数,但未指定函数名,将匿名 ...

  9. Python之循环

    目标 程序的三大流程 while 循环基本使用 break 和 continue while 循环嵌套 一 程序的三大流程 在程序开发中,一共有三种流程方式: 顺序 —— 从上向下,顺序执行代码 分支 ...

  10. easyUI 创建详情页dialog

    使用easyui dialog先下载jQuery easyui 的压缩包  下载地址http://www.jeasyui.com/download/v155.php 解压后放在项目WebContent ...