简单线性回归问题的优化(SGD)R语言
本编博客继续分享简单的机器学习的R语言实现。
今天是关于简单的线性回归方程问题的优化问题

常用方法,我们会考虑随机梯度递降,好处是,我们不需要遍历数据集中的所有元素,这样可以大幅度的减少运算量。
具体的算法参考下面:
首先我们先定义我们需要的参数的Notation


上述算法中,为了避免过拟合,我们采用了L2的正则化,在更新步骤中,我们会发现,这个正则项目,对参数更新的影响
下面是代码部分:
## Load Library
library(ggplot2)
library(reshape2)
library(mvtnorm) ## Function for reading the data
read_data <- function(fname, sc) {
data <- read.csv(file=fname,head=TRUE,sep=",")
nr = dim(data)[1]
nc = dim(data)[2]
x = data[1:nr,1:(nc-1)]
y = data[1:nr,nc]
if (isTRUE(sc)) {
x = scale(x) ## Scale x
y = scale(y) ## Scale y
}
return (list("x" = x, "y" = y))
}
我们定义了一个读取数据的方程,这里,我们会把数据集给scale一下,可以保证进一步提高运算速度
## Matrix Product Function
predict_func <- function(Phi, w){
return(Phi%*%w)
} ## Function to compute the cost function
train_obj_func <- function (Phi, w, label, lambda){
# Cost funtion including the L2 norm regulaztion
return(.5 * mean((predict_func(Phi, w) - label)^2) + .5 * lambda * t(w) %*% w)
} ## Return the errors for each iteration
get_errors <- function(data, label, W) {
n_weights = dim(W)[1]
Phi <- cbind('X0' = 1, data)
errors = matrix(,nrow=n_weights, ncol=2)
for (tau in 1:n_weights) {
errors[tau,1] = tau
errors[tau,2] = train_obj_func(Phi, W[tau,],label, 0) ## Get the errors, set the lambda to 0
}
return(errors)
}
同时,我们定义了计算矩阵乘法,计算目标函数以及求误差的方程。
sgd_train <- function(train_x, train_y, lambda, eta, epsilon, max_epoch) {
## Prepare the traindata
## Attach the 1 for X0
Phi <- as.matrix(cbind('X0'=1, train.data))
## Calculate the max iteration time for the SGD
train_len = dim(train_x)[1]
tau_max = max_epoch * train_len
W <- matrix(,nrow=tau_max, ncol=ncol(Phi))
set.seed(1234)
## Random Generate the start parameter
W[1,] <- runif(ncol(Phi))
tau = 1 # counter
## Create a dateframe to store the value of cost function for each iteration
obj_func_val <-matrix(,nrow=tau_max, ncol=1)
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
while (tau <= tau_max){
# check termination criteria
if (obj_func_val[tau,1]<=epsilon) {break}
# shuffle data:
train_index <- sample(1:train_len, train_len, replace = FALSE)
# loop over each datapoint
for (i in train_index) {
# increment the counter
tau <- tau + 1
if (tau > tau_max) {break}
# make the weight update
y_pred <- predict_func(Phi[i,], W[tau-1,])
W[tau,] <- sgd_update_weight(W[tau-1,], Phi[i,], train_y[i], y_pred, lambda, eta)
# keep track of the objective funtion
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
}
}
# resulting values for the training objective function as well as the weights
return(list('vals'=obj_func_val,'W'=W))
}
# updating the weight vector
sgd_update_weight <- function(W_prev, x, y_true, y_pred, lambda, eta) {
## Computer the Gradient
grad = - (y_true-y_pred) * x
## Here I just combine the regularisation term with prev w
return(W_prev*(1-eta * lambda) - eta * grad)
}
根据上述我们写的计算更新目标函数和参数的方法,讲算法用R实现
下面是实验部分
## Load the train data and train label
train.data <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$x
train.label <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$y
## Load the testdata and test label
test.data <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$x
test.label <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$y # Set MAX EPOCH max_epoch = 18 ## Implement SGD with Ridge regression
options(warn=-1) ## Initilize
## Set the related parameters
epsilon = .001 ## Terimation threshold
eta = .01 ## Learning Rate
lambda= 0.5 ## Regularisation parmater ## Run SGD
## Cost function values
train_res2 = sgd_train(train.data, train.label, lambda, eta, epsilon, max_epoch)
## Calulate the errors
## To be mentioned here, we will only visulisation for the train error to check the converge result
errors2 = get_errors(train.data, train.label, train_res2$W)
接着,我们把SGD的error plot给绘制出来
## Visulastion for SGD
plot(train_res2$val, main="SGD", type="l", col="blue",
xlab="iteration", ylab="training objective function")

这里我们的方程比较简单,可以看到,目标函数很快就收敛了。
简单线性回归问题的优化(SGD)R语言的更多相关文章
- 【数据分析】线性回归与逻辑回归(R语言实现)
文章来源:公众号-智能化IT系统. 回归模型有多种,一般在数据分析中用的比较常用的有线性回归和逻辑回归.其描述的是一组因变量和自变量之间的关系,通过特定的方程来模拟.这么做的目的也是为了预测,但有时也 ...
- 一个简单文本分类任务-EM算法-R语言
一.问题介绍 概率分布模型中,有时只含有可观测变量,如单硬币投掷模型,对于每个测试样例,硬币最终是正面还是反面是可以观测的.而有时还含有不可观测变量,如三硬币投掷模型.问题这样描述,首先投掷硬币A,如 ...
- R语言
什么是R语言编程? R语言是一种用于统计分析和为此目的创建图形的编程语言.不是数据类型,它具有用于计算的数据对象.它用于数据挖掘,回归分析,概率估计等领域,使用其中可用的许多软件包. R语言中的不同数 ...
- R语言:用简单的文本处理方法优化我们的读书体验
博客总目录:http://www.cnblogs.com/weibaar/p/4507801.html 前言 延续之前的用R语言读琅琊榜小说,继续讲一下利用R语言做一些简单的文本处理.分词的事情.其实 ...
- R语言-简单线性回归图-方法
目标:利用R语言统计描绘50组实验对比结果 第一步:导入.csv文件 X <- read.table("D:abc11.csv",header = TRUE, sep = & ...
- R 语言中的简单线性回归
... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...
- R语言解读一元线性回归模型
转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...
- R语言解读多元线性回归模型
转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...
- 多元线性回归公式推导及R语言实现
多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...
随机推荐
- 装箱问题(NOIP2001&水题测试2017082401)
题目链接:装箱问题 这题经典的01背包. 动规. 设计状态f[n][V]表示前n个物体放在V中的最大体积是多少. 所以代码如下: #include<bits/stdc++.h> using ...
- 组合数问题(NOIP2016)
题目链接:组合数问题 这道题可以算当年第二简单的. 这里要用到两个技巧: 用杨辉三角递推计算组合数 运用前缀和 有了这两点,这道题就出来了. 我们先运用杨辉三角推出题目范围内所能用到的所有组合数,然后 ...
- kbmmw 的HTTPSmartService中的跨域访问
有同学在使用kbmmw 与extjs 结合的时候,涉及到了跨域访问,这个在 kbmmw 里面已经完全解决. extjs 在访问跨域的时候,首先会使用OPIONS 调用,服务端要根据浏览器请求的 he ...
- 使用delphi 10.2 开发linux 上的Daemon
delphi 10.2 支持linux, 而且官方只是支持命令行编程,目地就是做linux 服务器端的开发. 既然是做linux服务器端的开发,那么普通的命令行运行程序,然后等待开一个黑窗口的方式就 ...
- 制作centos sshd 镜像
[root@b5926410fe60 /]# yum install passwd openssl openssh-server -y 启动sshd: # /usr/sbin/sshd -D 这时报以 ...
- Jetty 9的使用
参考来源:https://www.cnblogs.com/empireghost/p/3522834.html
- iOS后台唤醒实战:微信收款到账语音提醒技术总结
1.前言 微信为了解决小商户老板们在频繁交易中不方便核对.确认到账的功能痛点,产品MM提出了新版本需要支持收款到账语音提醒功能.本文借此总结了iOS平台上的APP后台唤醒和语音合成.播放等一系列技术开 ...
- 2018.10.27 loj#2292. 「THUSC 2016」成绩单(区间dp)
传送门 g[i][j][k][l]g[i][j][k][l]g[i][j][k][l]表示将区间l,rl,rl,r变成最小值等于kkk,最大值等于lll时的花费的最优值. f[i][j]f[i][j] ...
- 如何将本地代码通过git上传到码云
ps:同部署到GitHub上一样 http://www.cnblogs.com/pcx105/p/7777932.html
- Gitolite 权限控制
官网 http://gitolite.com/gitolite/index.html 安装配置 http://gitolite.com/gitolite/install/ 傻瓜安装教程 http:// ...