In [19]:
 
 
 
 
 
  1. import tensorflow as tf
  1. import numpy as np
  1. # #简单的数据形网络
  1. # #定义输入参数
  1. # X=tf.constant(value=[[0.7,0.5]])
  1. # W1=tf.Variable(tf.truncated_normal(shape=[2,3],mean=0,stddev=2))
  1. # W2=tf.Variable(tf.truncated_normal(shape=[3,1],mean=0,stddev=2))
  1. # #定义前向传播过程
  1. # a=tf.matmul(X,W1)
  1. # y=tf.matmul(a,W2)
  1. # #变量初始化
  1. # init=tf.global_variables_initializer()
  1. # with tf.Session() as sess:
  1. # sess.run(init)
  1. # print("Y",sess.run(y))
  1. #multiply是对应元素相乘,matul是矩阵相乘
  1. # 定义网络变量
  1. #s输入常亮
  1. x_input=tf.placeholder(dtype=tf.float32,shape=[1,2],name="x_input")
  1. #权重变量
  1. W1=tf.Variable(tf.random_normal(shape=[2,3],mean=0,stddev=2))
  1. W2=tf.Variable(tf.random_normal(shape=[3,1],mean=0,stddev=2))
  1. #定义网络运算图
  1. a=tf.matmul(x_input,W1)
  1. Y=tf.matmul(a,W2)
  1. #初始化全局变量
  1. init=tf.global_variables_initializer()
  1. with tf.Session() as sess:
  1. sess.run(init)
  1. result=sess.run(Y,feed_dict={x_input:[[0.7,0.5]]})
  1. print("result", result)
 
 
 
  1. result [[-4.5023837]]
  2.  
  3. #在pycharm运行程序
  1. import tensorflow as tf
    import numpy as np
  2.  
  3. BATCH_SIZE = 8 # 一次输入网络的数据,称为batch。一次不能喂太多数据
    SEED = 23455 # 产生统一的随机数
  4.  
  5. rdm = np.random.RandomState(SEED)
    X = rdm.rand(32, 2)
  6.  
  7. i=0
    Y=np.zeros((32,), dtype=np.int)
    Y_=np.transpose([Y])
    for (x0, x1) in X:
    y=int(x0 + x1 < 1)
    Y_[[i]]=y
    i=i+1
  8.  
  9. print("X:\n", X)
    print("Y_:\n",Y_)
  10.  
  11. # 1定义神经网络的输入、参数和输出,定义前向传播过程。
    x = tf.placeholder(tf.float32, shape=(None, 2))
    y_ = tf.placeholder(tf.float32, shape=(None, 1))
  12.  
  13. w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
    w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))
  14.  
  15. a = tf.matmul(x, w1)
    y = tf.matmul(a, w2)
  16.  
  17. # 2定义损失函数及反向传播方法。
    loss = tf.reduce_mean(tf.square(y - y_))
    train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss) # 三种优化方法选择一个就可以
  18.  
  19. with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    # 输出目前(未经训练)的参数取值。
    print("w1:\n", sess.run(w1))
    print("w2:\n", sess.run(w2))
    print("\n")
    STEPS = 3000
    for i in range(STEPS): #0-2999
    start = (i * BATCH_SIZE) % 32
    end = start + BATCH_SIZE
    sess.run(train_step, feed_dict={x: X[start:end], y_: Y_[start:end]})
    if i % 500 == 0:
    total_loss = sess.run(loss, feed_dict={x: X, y_: Y_})
    print("After %d training step(s), loss on all data is %g"%(i,total_loss))
    print("\n")
    print("w1:\n", sess.run(w1))
    print("w2:\n", sess.run(w2))
  20.  
  21. #比较完整的一个网络

import tensorflow as tf
import numpy as np
batch_size=8
seed=23455

#制造一些假数据
rng=np.random.RandomState(seed)
X=rng.rand(32,2)
print(X)

Y=np.zeros(shape=(32,1),dtype=np.int)

Y=[[np.int(x0+x1<1)]for (x0,x1) in X]

print(Y)

#定义网络
x_input=tf.placeholder(shape=[None,2],dtype=np.float,name="input")
y_output=tf.placeholder(shape=[None,1],dtype=np.float,name="output")

#定义变量
W1=tf.Variable(tf.random_normal(shape=[2,3],stddev=1,seed=1))
W2=tf.Variable(tf.random_normal(shape=[3,1],stddev=1,seed=1))

#定义静态网络函数
a=tf.matmul(x_input,W1)
y=tf.matmul(a,W2)

#定义损失函数
loss=tf.reduce_mean(tf.square(y-y_output))
train_step=tf.train.GradientDescentOptimizer(0.001).minimize(loss)
#初始化变量
init=tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(init)
print("W1:\n",W1)
#开始批量载入数据
for i in range(2000):
data_start=(i*batch_size)%32
data_end=data_start+batch_size
#开始训练数据
sess.run(train_step,feed_dict={x_input:X[data_start:data_end],y_output:Y[data_start:data_end]})
#每隔一段就打印出损失值
if ((i)%500==0):
Loss=sess.run(loss,feed_dict={x_input:X,y_output:Y})
print("loss",Loss)
print("w1",sess.run(W1))
print("W2",sess.run(W2))

  1.  

tensorflow--建立一个简单的小网络的更多相关文章

  1. Hyperledger Fabric 建立一个简单网络

    Building you first network 网络结构: 2个Orgnizations(每个Org包含2个peer节点)+1个solo ordering service 打开fabric-sa ...

  2. 3.2 Lucene实战:一个简单的小程序

    在讲解Lucene索引和检索的原理之前,我们先来实战Lucene:一个简单的小程序! 一.索引小程序 首先,new一个java project,名字叫做LuceneIndex. 然后,在project ...

  3. 输出多行字符的一个简单JAVA小程序

    public class JAVA { public static void main(String[] args) { System.out.println("-------------- ...

  4. python -----一个简单的小程序(监控电脑内存,cpu,硬盘)

    一个简单的小程序 用函数实现!~~ 实现: cpu 使用率大于百分之50 时  ,  C 盘容量不足5 G 时, 内存 低于2G 时. 出现以上其中一种情况,发送自动报警邮件! 主要运用 到了两个 模 ...

  5. idea破解版安装、配置jdk以及建立一个简单的maven工程

    idea破解版安装.配置jdk,配置jdk环境变量以及建立一个简单的maven工程 一.idea破解版以及配置文件下载 下载网址:https://pan.baidu.com/s/1yojA51X1RU ...

  6. 通过myclipse建立一个简单的Hibernate项目(PS:在单元测试中实现数据的向表的插入)

    Hibernate的主要功能及用法: Ⅰ.Hibernate封装了JDBC,使Java程序员能够以面向对象的思想对数据库进行操作 Ⅱ.Hibernate可以应用于EJB的J2EE架构,完成数据的持久化 ...

  7. Python3的tkinter写一个简单的小程序

    一.这个学期开始学习python,但是看了python2和python3,最后还是选择了python3 本着熟悉python的原因,并且也想做一些小程序来增加自己对python的熟练度.所以写了一个简 ...

  8. 利用NET HUNTER建立一个自动文件下载的网络接入点

    免责声明:本文旨在分享技术进行安全学习,禁止非法利用. 本文中我将完整的阐述如何通过建立一个非常邪恶的网络接入点来使得用户进行自动文件下载.整个过程中我将使用 Nexus 9 来运行Kali NetH ...

  9. Django 学习笔记之六 建立一个简单的博客应用程序

    最近在学习django时建立了一个简单的博客应用程序,现在把简单的步骤说一下.本人的用的版本是python 2.7.3和django 1.10.3,Windows10系统 1.首先通过命令建立项目和a ...

随机推荐

  1. Maven 使用Nexus搭建Maven私服

    Maven学习 (四) 使用Nexus搭建Maven私服 为什么要搭建nexus私服,原因很简单,有些公司都不提供外网给项目组人员,因此就不能使用maven访问远程的仓库地址,所以很有必要在局域网里找 ...

  2. IdentityServer4专题之六:Resource Owner Password Credentials

    实现代码: (1)IdentityServer4授权服务器代码: public static class Config {  public static IEnumerable<Identity ...

  3. SciPy 线性代数

    章节 SciPy 介绍 SciPy 安装 SciPy 基础功能 SciPy 特殊函数 SciPy k均值聚类 SciPy 常量 SciPy fftpack(傅里叶变换) SciPy 积分 SciPy ...

  4. SPFA和堆优化的Dijk

    朴素dijkstra时间复杂度$O(n^{2})$,通过使用堆来优化松弛过程可以使时间复杂度降到O((m+n)logn):dijkstra不能用于有负权边的情况,此时应使用SPFA,两者写法相似. 朴 ...

  5. vSphere 计算vMotion的迁移原理

    1. 计算vMotion 的应用场景 1). 计划内停机维护 2). 提高资源的利用率 2. 计算vMotion 需求: 1).共享存储 vMotion需要解决的核心问题就是:将VMs的内存从源ESX ...

  6. vCenter 导入Windows Server 2003/XP自定义规范失败

    1.vcsa 切换到/etc/vmware-vpx/sysprep目录下,会有很多个目录,根据Windows Server 2003的版本,64位找到 svr2003-64 这个目录,32位找到svr ...

  7. 题目:给定一数组 例如:a = [1,2,3,5,2,1] 现用户提供一个数字 请返回用户所提供的数字的所有下标

    def test(ary): ds = {} for i in range(len(ary)): if ds.get(ary[i]): ds[ary[i]].append(i) else: ds[ar ...

  8. app1----攻防世界

    啥也不说把题目下载下来,在模拟器里运行一下 输入正确的key就是flag 继续下一步分析,可以使用Androidkiller分析,我喜欢使用jeb这里我就使用jeb进行分析 找到MainActivit ...

  9. DBlink查看,创建于删除

    1.查看dblink select owner,object_name from dba_objects where object_type='DATABASE LINK'; 或者 select * ...

  10. PHP代码审计之入门实战

    系统介绍 CMS名称:新秀企业网站系统PHP版 官网:www.sinsiu.com 版本:这里国光用的1.0 正式版 (官网最新的版本有毒,网站安装的时候居然默认使用远程数据库???迷之操作 那站长的 ...