使用tensorflow的线性回归的例子(八)

发布于:2025-07-09 ⋅ 阅读:(15) ⋅ 点赞:(0)

L1与L2损失

这个脚本展示如何用TensorFlow求解线性回归。

在算法的收敛性中,理解损失函数的影响是很重要的。这里我们展示L2损失函数是如何影响线性回归的收敛性的。我们使用iris数据集,但是我们将改变损失函数和学习速率来看收敛性的改变。

L2-Loss

这里我们展示用L2损失的线性回归。

线性最小二乘的L2损失函数为

其中 N 是数据点数, yi 是第i个实际y值, ^yi^ 是第i个预测y值。

def loss2(x, y,w,b):

    # Declare loss functions

    loss_l2 = tf.reduce_mean(tf.square(y - model(x,w,b)))

    return loss_l2

def grad2(x,y,w,b):

    with tf.GradientTape() as tape:

        loss_2 = loss2(x,y,w,b)

return tape.gradient(loss_2,[w,b])

batch_size = 25

learning_rate = 0.4 # Will not converge with learning rate at 0.4

iterations = 50

# Create variables for linear regression

w2 = tf.Variable(tf.random.normal(shape=[1,1]),tf.float32)

b2 = tf.Variable(tf.random.normal(shape=[1,1]),tf.float32)

optimizer = tf.optimizers.Adam(learning_rate)

# Training loop

loss_vec_l1=[]

loss_vec_l2=[]

for i in range(5000):

    rand_index = np.random.choice(len(x_vals), size=batch_size)

    rand_x = np.transpose([x_vals[rand_index]])

    rand_y = np.transpose([y_vals[rand_index]])

    x=tf.cast(rand_x,tf.float32)

    y=tf.cast(rand_y,tf.float32)

    grads2=grad2(x,y,w2,b2)

    optimizer.apply_gradients(zip(grads2,[w2,b2]))

    #sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

    temp_loss2 = loss2(x, y,w2,b2).numpy()

    #sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})

    loss_vec_l2.append(temp_loss2)

    if (i+1)%25==0:

        print('Step #' + str(i+1) + ' A = ' + str(w2.numpy()) + ' b = ' + str(b2.numpy()))

        print('Loss = ' + str(temp_loss2))

# Get the optimal coefficients

[slope] = w2.numpy()

[y_intercept] = b2.numpy()

# Get best fit line

best_fit2 = []

for i in x_vals:

  best_fit2.append(slope*i+y_intercept)

# Plot the result

plt.plot(x_vals, y_vals, 'o', label='Data Points')

plt.plot(x_vals, best_fit2, 'r-', label='Best fit line', linewidth=3)

plt.legend(loc='upper left')

plt.title('Sepal Length vs Petal Width')

plt.xlabel('Petal Width')

plt.ylabel('Sepal Length')

plt.show()

# Plot loss over time

plt.plot(loss_vec_l2, 'k-')

plt.title('L2 Loss per Generation')

plt.xlabel('Generation')

plt.ylabel('L2 Loss')

plt.show()

# Plot loss over time

plt.plot(loss_vec_l1, 'k-', label='L1 Loss')

plt.plot(loss_vec_l2, 'r--', label='L2 Loss')

plt.title('L1 and L2 Loss per Generation')

plt.xlabel('Generation')

plt.ylabel('L1 Loss')

plt.legend(loc='upper right')

plt.show()


网站公告

今日签到

点亮在社区的每一天
去签到