CS20Chapter3
waiting P54 shuffle data
03_Lecture note_Linear and Logistic Regression
学习点1:
python的地址输入是要不能用正斜杠\的,要用 / 来做地址分段。 比如:
# 打开一个文件
f = open("/tmp/foo.txt", "w") f.write( "Python 是一个非常好的语言。\n是的,的确非常好!!\n" ) # 关闭打开的文件
f.close()
Birth rate - life expectancy code:
""" Solution for simple linear regression example using tf.data
Created by Chip Huyen (chiphuyen@cs.stanford.edu)
CS20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Lecture 03
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']=''
import time import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf import utils DATA_FILE = 'data/birth_life_2010.txt' # Step 1: read in the data
data, n_samples = utils.read_birth_life_data(DATA_FILE) # Step 2: create Dataset and iterator
dataset = tf.data.Dataset.from_tensor_slices((data[:,0], data[:,1])) iterator = dataset.make_initializable_iterator()
X, Y = iterator.get_next() # Step 3: create weight and bias, initialized to 0
w = tf.get_variable('weights', initializer=tf.constant(0.0))
b = tf.get_variable('bias', initializer=tf.constant(0.0)) # Step 4: build model to predict Y
Y_predicted = X * w + b # Step 5: use the square error as the loss function
loss = tf.square(Y - Y_predicted, name='loss')
# loss = utils.huber_loss(Y, Y_predicted) # Step 6: using gradient descent with learning rate of 0.001 to minimize loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(loss) start = time.time() //开始时,记录一次时间
with tf.Session() as sess:
# Step 7: initialize the necessary variables, in this case, w and b
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('./graphs/linear_reg', sess.graph) # Step 8: train the model for 100 epochs
for i in range(100):
sess.run(iterator.initializer) # initialize the iterator
total_loss = 0
try:
while True:
_, l = sess.run([optimizer, loss])
total_loss += l
except tf.errors.OutOfRangeError:
pass print('Epoch {0}: {1}'.format(i, total_loss/n_samples)) # close the writer when you're done using it
writer.close() # Step 9: output the values of w and b
w_out, b_out = sess.run([w, b])
print('w: %f, b: %f' %(w_out, b_out))
print('Took: %f seconds' %(time.time() - start)) # plot the results
plt.plot(data[:,0], data[:,1], 'bo', label='Real data')
plt.plot(data[:,0], data[:,0] * w_out + b_out, 'r', label='Predicted data with squared error')
# plt.plot(data[:,0], data[:,0] * (-5.883589) + 85.124306, 'g', label='Predicted data with Huber loss')
plt.legend()
plt.show()
CS20Chapter3的更多相关文章
随机推荐
- hdu 1026 Ignatius and the Princess I 搜索,输出路径
Ignatius and the Princess I Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (J ...
- js中contains()方法的了解
今天第一次碰到了contains()方法,处于好奇了解了一下:发现在某些场合还是挺有用的. contains(),js原生方法,用于判断DOM元素的包含关系: 需要注意的是:它以HTMLElement ...
- Linux(Ubuntu)下MySQL的安装
1)首先检查系统中是否已经安装了MySQL 在终端里面输入 sudo netstat -tap | grep mysql 若没有反映,没有显示已安装结果,则没有安装.若如下显示,则表示已经安装 2)如 ...
- CSS属性之relative
0.相对定位relative特点 相对定位relative元素总是会占据位置,所占据的位置是在relative元素没有设置left/top/right/bottom属性时的位置: 相对定位relati ...
- 关于 PHPMailer 邮件发送类的使用心得(含多文件上传)
This is important for send mail PHPMailer 核心文件 class.phpmailer.php class.phpmaileroauth.php class.ph ...
- is_array判断是否为数组
if(is_array($arr)){ echo "是数组"; }else{ echo "不是数组"; }
- toast, 警告窗
//浮动提示框 1秒后消失 toast(msg, isError, sec) { var div = $('#toast'); div.html(msg); div.css({visibility: ...
- Windows 10 Framework 3.5 _x64 离线安装包 最新安装版
原文:http://www.jb51.net/softs/325481.html Windows 10 Framework 3.5 离线安装包,适用于 Win10 和 Server 2016 离线安装 ...
- jquery 仿windows10菜单效果下载
http://www.kuitao8.com/20150923/4079.shtml jquery 仿windows10菜单效果下载
- Volley框架实现Http的get和post请求
一: volley简介: Google I/O 2013上,Volley发布了.Volley是Android平台上的网络通信库,能使网络通信更快,更简单,更健壮.这是Volley名称的由来: a bu ...