1. 代码实现

from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T

def compute_accuracy(y_target, y_predict):
    correct_prediction = np.equal(y_predict, y_target)
    accuracy = np.sum(correct_prediction)/len(correct_prediction)
    return accuracy

rng = np.random

N = 400                                   # training sample size
feats = 784                               # number of input variables

# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weights and biases
W = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")

# Construct Theano expression graph
p_1 = T.nnet.sigmoid(T.dot(x, W) + b)   # Logistic Probability that target = 1 (activation function)
prediction = p_1 > 0.5                    # The prediction thresholded
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
# or
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
gW, gb = T.grad(cost, [W, b])             # Compute the gradient of the cost

# Compile
learning_rate = 0.1
train = theano.function(
          inputs=[x, y],
          outputs=[prediction, xent.mean()],
          updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Training
for i in range(500):
    pred, err = train(D[0], D[1])
    if i % 50 == 0:
        print('cost:', err)
        print("accuracy:", compute_accuracy(D[1], predict(D[0])))

print("target values for D:")
print(D[1])
print("prediction on D:")
print(predict(D[0]))

莫烦theano学习自修第八天【分类问题】的更多相关文章

  1. 莫烦theano学习自修第九天【过拟合问题与正规化】

    如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...

  2. 莫烦sklearn学习自修第八天【过拟合问题】

    1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...

  3. 莫烦theano学习自修第十天【保存神经网络及加载神经网络】

    1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...

  4. 莫烦theano学习自修第七天【回归结果可视化】

    1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...

  5. 莫烦theano学习自修第六天【回归】

    1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...

  6. 莫烦theano学习自修第五天【定义神经层】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

  7. 莫烦theano学习自修第四天【激励函数】

    1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...

  8. 莫烦theano学习自修第三天【共享变量】

    1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...

  9. 莫烦theano学习自修第二天【激励函数】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

随机推荐

  1. BAT美团滴滴java面试大纲(带答案版)之四:多线程Lock

    每天学习一点点 编程PDF电子书.视频教程免费下载:http://www.shitanlife.com/code 这是多线程的第二篇. 多线程就像武学中对的吸星大法,理解透了用好了可以得道成仙,俯瞰芸 ...

  2. php 依赖注入的实现

    当A类需要依赖于B类,也就是说需要在A类中实例化B类的对象来使用时候,如果B类中的功能发生改变,也会导致A类中使用B类的地方也要跟着修改,导致A类与B类高耦合.这个时候解决方式是,A类应该去依赖B类的 ...

  3. 【vue】vue +element 搭建项目,vue-cli 如何打包上线

    以自己的项目为例 第一步:手动修改config文件夹中的index.js文件中的build对象,将 assetsPublicPath 中的 “/” ,改为 “你实际的加载路径” 如图: 第二步:执行( ...

  4. object detection[content]

    近些年,随着DL的不断兴起,计算机视觉中的对象检测领域也随着CNN的广泛使用而大放异彩,其中Girshick等人的<R-CNN>是第一篇基于CNN进行对象检测的文献.本文欲通过自己的理解来 ...

  5. nova系列一:虚拟化介绍

    一 什么是虚拟化 虚拟化说白了就是本来是一个完整的资源,切分或者说虚拟成多份,让这多份资源都使用起来,物尽其用,减少了浪费,提高了利用率,省了钱. 虚拟化(Virtualization)技术最早出现在 ...

  6. Vue-校验props传来的值

    对父组件传来的值进行校验. Vue.component('child',{ props:{ content:{ type:String, required:false, default:'li zha ...

  7. Ionic1 环境破坏后程序重新恢复过程

    ionic platform remove android ionic platform add android cordova plugin add cordova-plugin-network-i ...

  8. Keil开发环境如何生成BIN文件

    为什么需要BIN文件呢? 有些烧录器只支持BIN文件. 进行OTA远程升级时,只能使用BIN文件. 使用JLink脚本文件进行一键烧录时,只支持BIN文件. BIN文件要比HEX和AXF文件小的多. ...

  9. 基于Vue.js 2.0 + Vuex打造微信项目

    一.项目简介 基于Vue + Vuex + Vue-router + Webpack 2.0打造微信界面,实现了微信聊天.搜索.点赞.通讯录(快速导航).个人中心.模拟对话.朋友圈.设置等功能. 二. ...

  10. [Luogu4916]魔力环[Burnside引理、组合计数、容斥]

    题意 题目链接 分析 sπo yyb 代码 #include<bits/stdc++.h> using namespace std; typedef long long LL; #defi ...