Logistic regression中regularization失败的解决方法探索(文末附解决后code)
在matlab中做Regularized logistic regression
原理:
我的代码:
- function [J, grad] = costFunctionReg(theta, X, y, lambda)
- %COSTFUNCTIONREG Compute cost and gradient for logistic regression with regularization
- % J = COSTFUNCTIONREG(theta, X, y, lambda) computes the cost of using
- % theta as the parameter for regularized logistic regression and the
- % gradient of the cost w.r.t. to the parameters.
- % Initialize some useful values
- m = length(y); % number of training examples
- % You need to return the following variables correctly
- J = 0;
- grad = zeros(size(theta));
- % ====================== YOUR CODE HERE ======================
- % Instructions: Compute the cost of a particular choice of theta.
- % You should set J to the cost.
- % Compute the partial derivatives and set grad to the partial
- % derivatives of the cost w.r.t. each parameter in theta
- h = sigmoid(X*theta);
- theta2=[0;theta(2:end)];
- J_partial = sum((-y).*log(h)+(y-1).*log(1-h))./m;
- J_regularization= (lambda/(2*m)).*sum(theta2.^2);
- J = J_partial+J_regularization;
- grad_partial = sum((h-y).*X)/m;
- grad_regularization = lambda.*theta2./m;
- grad = grad_partial+grad_regularization;
- % =============================================================
- end
运行结果:
标黄的与下面的预期对比发现不同
尝试删去
.rtcContent { padding: 30px }
.lineNode { font-size: 10pt; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-style: normal; font-weight: normal }
部分结果符合预期,部分不符合
尝试大佬代码
- %Hypotheses
- hx = sigmoid(X * theta);
- %%The cost without regularization
- J_partial = (-y' * log(hx) - (1 - y)' * log(1 - hx)) ./ m;
- %%Regularization Cost Added
- J_regularization = (lambda/(2*m)) * sum(theta(2:end).^2);
- %%Cost when we add regularization
- J = J_partial + J_regularization;
- %Grad without regularization
- grad_partial = (1/m) * (X' * (hx -y));
- %%Grad Cost Added
- grad_regularization = (lambda/m) .* theta(2:end);
- grad_regularization = [0; grad_regularization];
- grad = grad_partial + grad_regularization;
完全成功!?我不李姐……
观察大佬代码发现,我和大佬的区别在于:
最开始的theta向量和计算J(theta)和grad时候使用sum的数目
故尝试修改和大佬数目一样多的sum
- h = sigmoid(X*theta);
- theta2=[0;theta(2:end)];
- J_partial = (-y).*log(h)+(y-1).*log(1-h)./m;
- J_regularization= (lambda/(2*m)).*sum(theta2.^2);
- J = J_partial+J_regularization;
- grad_partial = (h-y).*X/m;
- grad_regularization = lambda.*theta2./m;
- grad = grad_partial+grad_regularization;
结果:incompatible不兼容
文档对该错误的解释如下
事已至此,只好向大佬更近一步!
- h = sigmoid(X*theta);
- J_partial = (-y).*log(h)+(y-1).*log(1-h)./m;
- J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
- J = J_partial+J_regularization;
- grad_partial = (h-y).*X/m;
- grad_regularization = lambda.*theta(2:end)./m;
- grad_regularization2=[0;grad_regularization];
- grad = grad_partial+grad_regularization2;
为什么还是不兼容?
到底哪里出了问题?
最后,尝试离大佬更近一步,把grad_partial里的(h-y).*X/m变成了(1/m) * (X' * (h -y))
- h = sigmoid(X*theta);
- J_partial = (1/m).*((-y).*log(h)+(y-1).*log(1-h));
- J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
- J = J_partial+J_regularization;
- grad_partial = (1/m) * (X' * (h -y));
- grad_regularization = (lambda/m).*theta(2:end);
- grad_regularization = [0; grad_regularization];
- grad = grad_partial+ grad_regularization;
舒服了!
但,等等,上面怎么那么多行,数值还不对?看来不能完全靠大佬,还得自己改!!!
- h = sigmoid(X*theta);
- J_partial = (1/m).*sum((-y).*log(h)+(y-1).*log(1-h));
- J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
- J = J_partial+J_regularization;
- grad_partial = (1/m) * (X' * (h -y));
- grad_regularization = (lambda/m).*theta(2:end);
- grad_regularization = [0; grad_regularization];
- grad = grad_partial+ grad_regularization;
最终,得到了满意的答案
以及
总结一下出现的问题
01不兼容,就像上面说明的那样,行列不匹配
(解决方法:查看有无sum、是值还是array,把系数往前放,修改两数相乘的顺序)
02加入grad_regularization后,grad(1,5)的后四项都出现了问题(很神奇地值相等),
一旦去掉又与正确值有小范围差距(缺少grad_regularization导致的)
说明grad_regularization存在问题
而如果一开始就将theta变为第一行元素是0的矩阵,很容易出现不兼容的问题
大佬的代码提示我们特殊情况可以分出来特殊处理,也就是:
在计算J(θ)不使用矩阵,而是用除0外、后面的θ直接产出需要的值
在计算grad时,由于输出也是矩阵,所以可以创建一个含0和其他θ的矩阵
这样既可以避免不兼容,也可以得出正确的结果
最终的部分code如下
- h = sigmoid(X*theta);
- J_partial = (1/m).*sum((-y).*log(h)+(y-1).*log(1-h));
- J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
- J = J_partial+J_regularization;
- grad_partial = (1/m) * (X' * (h -y));
- grad_regularization = (lambda/m).*theta(2:end);
- grad_regularization = [0; grad_regularization];
- grad = grad_partial+ grad_regularization;
Logistic regression中regularization失败的解决方法探索(文末附解决后code)的更多相关文章
- Machine Learning - 第3周(Logistic Regression、Regularization)
Logistic regression is a method for classifying data into discrete outcomes. For example, we might u ...
- logistic regression中的cost function选择
一般的线性回归使用的cost function为: 但由于logistic function: 本身非凸函数(convex function), 如果直接使用线性回归的cost function的话, ...
- Windows 共享无线上网 无法启动ICS服务解决方法(WIN7 ICS服务启动后停止)
Windows 共享无线上网 无法启动ICS服务解决方法(WIN7 ICS服务启动后停止) ICS 即Internet Connection Sharing,internet连接共享,可以使局域网上其 ...
- 斯坦福机器学习视频笔记 Week3 逻辑回归与正则化 Logistic Regression and Regularization
我们将讨论逻辑回归. 逻辑回归是一种将数据分类为离散结果的方法. 例如,我们可以使用逻辑回归将电子邮件分类为垃圾邮件或非垃圾邮件. 在本模块中,我们介绍分类的概念,逻辑回归的损失函数(cost fun ...
- Andrew Ng Machine Learning 专题【Logistic Regression & Regularization】
此文是斯坦福大学,机器学习界 superstar - Andrew Ng 所开设的 Coursera 课程:Machine Learning 的课程笔记. 力求简洁,仅代表本人观点,不足之处希望大家探 ...
- week3编程作业: Logistic Regression中一些难点的解读
%% ============ Part : Compute Cost and Gradient ============ % In this part of the exercise, you wi ...
- 在IE浏览器中执行OpenFlashChart的reload方法时无法刷新的解决方法
由于项目需求,需要在网页上利用图表展示相关数据的统计信息,采用了OpenFlashChart技术.OpenFlashChart是一款开源的以Flash和Javascript为技术基础的免费图表,用它能 ...
- (蓝牙)网络编程中,使用InputStream read方法读取数据阻塞的解决方法
问题如题,这个问题困扰了我好几天,今天终于解决了,感谢[1]. 首先,我要做的是android手机和电脑进行蓝牙通信,android发一句话,电脑端程序至少就要做到接受到那句话.android端发送信 ...
- blocked because of many connection errors; unblock with 'mysqladmin flush-hosts;MySQL在远程访问时非常慢的解决方法;MySql链接慢的解决方法
一:服务器异常:Host 'xx.xxx.xx.xxx' is blocked because of many connection errors; unblock with 'mysqladmin ...
随机推荐
- Apache+tomcat实现应用服务器集群
Ngnix/Apache比较 Nginx:Nginx是一款轻量级的Web 服务器/反向代理服务器及电子邮件(IMAP/POP3)代理服务器,在BSD-like 协议下发行.其特点是占有内存少,并发能力 ...
- 99%的人都搞错了的java方法区存储内容,通过可视化工具HSDB和代码示例一次就弄明白了
https://zhuanlan.zhihu.com/p/269134063 番茄番茄我是西瓜 那是我日夜思念深深爱着的人啊~ 已关注 6 人赞同了该文章 前言 本篇是java内存区域管理系列教 ...
- Sting -- byte[]互转
1.String -->byte[] String str = "中国"; byte[] bys = str.getBytes(); Arrays.toString(bys) ...
- 请说说你对Struts2的拦截器的理解?
Struts2拦截器是在访问某个Action或Action的某个方法,字段之前或之后实施拦截,并且Struts2拦截器是可插拔的,拦截器是AOP的一种实现. 拦截器栈(Interceptor Stac ...
- SpringMVC常用的注解有哪些?
@RequestMapping:用于处理请求 url 映射的注解,可用于类或方法上.用于类上,则表示类中的所有响应请求的方法都是以该地址作为父路径. @RequestBody:注解实现接收http请求 ...
- 二、mycat15种分片规则
一.分片枚举 通过在配置文件中配置可能的枚举 id,自己配置分片,本规则适用于特定的场景,比如有些业务需要按照省份或区县来做保存,而全国省份区县固定的,这类业务使用本条规则,配置如下: <tab ...
- java对象的克隆以及深拷贝与浅拷贝
一.为什么要使用克隆 在实际编程过程中,我们常常要遇到这种情况:有一个对象A,在某一时刻A中已经包含了一些有效值,此时可能 会需要一个和A完全相同新对象B,并且此后对B任何改动都不会影响到A中的值,也 ...
- java-反射-注解-json-xml
反射: 框架设计的灵魂 框架:半成品软件.可以再框架的基础上进行软件开发,简化代码 定义:将类的各个组成部分封装为其他对象,这就是反射机制 好处: 可以再程序运行过程中,操作这些对象 可以解耦,提高程 ...
- 学习Apache(四)
介绍 Apache HTTP 服务器被设计为一个功能强大,并且灵活的 web 服务器, 可以在很多平台与环境中工作.不同平台和不同的环境往往需要不同 的特性,或可能以不同的方式实现相同的特性最有效率. ...
- 全方位讲解 Nebula Graph 索引原理和使用
本文首发于 Nebula Graph Community 公众号 index not found?找不到索引?为什么我要创建 Nebula Graph 索引?什么时候要用到 Nebula Graph ...