# List3-31
# Fit the line y = 2x + 0.2 to noisy samples with per-sample gradient descent.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Fix the numpy RNG so the noisy data (and the random initial slope drawn
# later) are reproducible between runs.
np.random.seed(seed=47)

# 101 evenly spaced sample points covering the closed interval [-1, 1].
trX = np.linspace(-1, 1, 101)

# Targets: points on y = 2x + 0.2 perturbed by Gaussian noise (std 0.4).
trY = (2 * trX
       + np.random.randn(*trX.shape) * 0.4
       + 0.2)

# Show the noisy samples together with the ground-truth line.
plt.figure()
plt.scatter(trX, trY)          # noisy training samples
plt.plot(trX, 0.2 + 2 * trX)   # true line y = 2x + 0.2
plt.show()
#定义模型函数
def model(x, w, b):
    """Evaluate the linear model ``x * w + b``.

    Args:
        x: input value(s).
        w: slope of the line.
        b: intercept of the line.

    Returns:
        The TensorFlow tensor ``x * w + b``.
    """
    scaled = tf.multiply(x, w)
    return scaled + b
def loss_fun(x, y, w, b):
    """Return the mean squared error of ``model(x, w, b)`` against ``y``.

    Args:
        x: input value(s).
        y: target value(s).
        w: slope of the line.
        b: intercept of the line.

    Returns:
        Scalar tensor ``mean((model(x, w, b) - y) ** 2)``.
    """
    return tf.reduce_mean(tf.square(model(x, w, b) - y))
def grad(x, y, w, b):
    """Return the gradients of ``loss_fun`` with respect to ``[w, b]``.

    Args:
        x: input value(s).
        y: target value(s).
        w: slope variable.
        b: intercept variable.

    Returns:
        A two-element list ``[d_loss/d_w, d_loss/d_b]``.
    """
    # Record the forward pass so the tape can differentiate it afterwards.
    with tf.GradientTape() as recorder:
        objective = loss_fun(x, y, w, b)
    # A non-persistent tape may be queried once; do it after recording ends.
    return recorder.gradient(objective, [w, b])
# Trainable slope (w) and intercept (b) of the fitted line.
# NOTE: dtype must be passed by keyword. The second positional parameter of
# tf.Variable is `trainable`, so `tf.Variable(v, tf.float32)` silently set
# trainable=tf.float32 instead of fixing the dtype.
w = tf.Variable(np.random.randn(), dtype=tf.float32)  # random initial slope
b = tf.Variable(0.0, dtype=tf.float32)                # zero initial intercept

# Number of passes over the data and the SGD step size.
train_epochs = 100
learning_rate = 0.01

loss = []           # loss value recorded at every training step
count = 0           # global step counter, not reset between epochs
display_count = 10  # print the loss once every `display_count` steps (samples)
# Train for `train_epochs` epochs with per-sample stochastic gradient descent:
# every (x, y) pair triggers one update of w and b.
for epoch in range(train_epochs):
    for xs, ys in zip(trX, trY):
        # Loss at the current parameters (before this step's update). Store a
        # plain float so the history does not hold one EagerTensor per step.
        loss_ = loss_fun(xs, ys, w, b)
        loss.append(float(loss_))

        # Gradients of the loss with respect to [w, b].
        delta_w, delta_b = grad(xs, ys, w, b)

        # Plain SGD update: parameter -= learning_rate * gradient.
        w.assign_sub(delta_w * learning_rate)
        b.assign_sub(delta_b * learning_rate)

        count += 1
        if count % display_count == 0:
            print('train epoch : ', '%02d' % (epoch + 1), 'step:%03d' % (count),
                  'loss= ', '{:.9f}'.format(float(loss_)))

    # Overlay the current fitted line after each epoch to visualise convergence.
    plt.plot(trX, b.numpy() + w.numpy() * trX)

# Display all per-epoch lines in one figure once training has finished.
plt.show()
# Read the learned parameters back out of TensorFlow once.
final_w = w.numpy()
final_b = b.numpy()

print("w = {}".format(final_w))  # expected to end up close to 2
print("b = {}".format(final_b))  # expected to end up close to 0.2

# Final fit drawn over the noisy samples.
plt.scatter(trX, trY)
plt.plot(trX, final_b + trX * final_w)
plt.show()