本篇博客代码来自于《动手学深度学习》pytorch版,也是代码较多,解释较少的一篇。不过好多方法在我以前的博客都有提,所以这次没提。还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂(前提是python语法大概了解),这是我不加很多解释的重要原因。

K折交叉验证实现

def get_k_fold_data(k, i, X, y):
# 返回第i折交叉验证时所需要的训练和验证数据,分开放,X_train为训练数据,X_valid为验证数据
assert k > 1
fold_size = X.shape[0] // k # 双斜杠表示除完后再向下取整
X_train, y_train = None, None
for j in range(k):
idx = slice(j * fold_size, (j + 1) * fold_size) #slice(start,end,step)切片函数
X_part, y_part = X[idx, :], y[idx]
if j == i:
X_valid, y_valid = X_part, y_part
elif X_train is None:
X_train, y_train = X_part, y_part
else:
X_train = torch.cat((X_train, X_part), dim=0) #dim=0增加行数,竖着连接
y_train = torch.cat((y_train, y_part), dim=0)
return X_train, y_train, X_valid, y_valid def k_fold(k, X_train, y_train, num_epochs,learning_rate, weight_decay, batch_size):
train_l_sum, valid_l_sum = 0, 0
for i in range(k):
data = get_k_fold_data(k, i, X_train, y_train) # 获取k折交叉验证的训练和验证数据
net = get_net(X_train.shape[1]) #get_net在这是一个基本的线性回归模型,方法实现见附录1
train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
weight_decay, batch_size) #train方法见后面附录2
train_l_sum += train_ls[-1]
valid_l_sum += valid_ls[-1]
if i == 0:
d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',
range(1, num_epochs + 1), valid_ls,
['train', 'valid']) #画图,且是对y求对数了,x未变。方法实现见附录3
print('fold %d, train rmse %f, valid rmse %f' % (i, train_ls[-1], valid_ls[-1]))
return train_l_sum / k, valid_l_sum / k

*args:表示接受任意长度的参数,然后存放入一个元组中;如def fun(*args) print(args),‘fruit','animal','human'作为参数传进去,输出(‘fruit','animal','human')

**kwargs:表示接受任意长的参数,然后存放入一个字典中;如

def fun(**kwargs):
for key, value in kwargs.items():
print("%s:%s" % (key,value)

fun(a=1,b=2,c=3)会输出 a=1 b=2 c=3

附录1

loss = torch.nn.MSELoss()

def get_net(feature_num):
net = nn.Linear(feature_num, 1)
for param in net.parameters():
nn.init.normal_(param, mean=0, std=0.01)
return net

附录2

def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate,weight_decay, batch_size):
train_ls, test_ls = [], []
dataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True) #TensorDataset和DataLoader的使用请查看我以前的博客 #这里使用了Adam优化算法
optimizer = torch.optim.Adam(params=net.parameters(), lr= learning_rate, weight_decay=weight_decay)
net = net.float()
for epoch in range(num_epochs):
for X, y in train_iter:
l = loss(net(X.float()), y.float())
optimizer.zero_grad()
l.backward()
optimizer.step()
train_ls.append(log_rmse(net, train_features, train_labels))
if test_labels is not None:
test_ls.append(log_rmse(net, test_features, test_labels))
return train_ls, test_ls

 附录3

def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):
set_figsize(figsize)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.semilogy(x_vals, y_vals)
if x2_vals and y2_vals:
plt.semilogy(x2_vals, y2_vals, linestyle=':')
plt.legend(legend)

注:由于最近有其他任务,所以此博客写的匆忙,等我有时间后会丰富,也可能加详细解释。

小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)的更多相关文章

  1. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  2. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  3. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  4. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  5. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  6. 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())

    先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...

  7. Python——决策树实战:california房价预测

    Python——决策树实战:california房价预测 编译环境:Anaconda.Jupyter Notebook 首先,导入模块: import pandas as pd import matp ...

  8. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  9. 深度学习之PyTorch实战(1)——基础学习及搭建环境

    最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...

随机推荐

  1. [Codeforces] #603 (Div. 2) A-E题解

    [Codeforces]1263A Sweet Problem [Codeforces]1263B PIN Code [Codeforces]1263C Everyone is a Winner! [ ...

  2. DRF项目之自定义分页器

    在项目中,我们多需要自定义分页器. 代码实现: class PageNum(PageNumberPagination): '''自定义分页器''' # 每页显示个数 page_size = 10 pa ...

  3. Java集合基于JDK1.8的ArrayList源码分析

    本篇分析ArrayList的源码,在分析之前先跟大家谈一谈数组.数组可能是我们最早接触到的数据结构之一,它是在内存中划分出一块连续的地址空间用来进行元素的存储,由于它直接操作内存,所以数组的性能要比集 ...

  4. 无法识别的配置节 system.webServer

    Web.config文件里面加入 <configSections> <section name="system.webServer" type="Sys ...

  5. leetcode1161 Maximum Level Sum of a Binary Tree

    """ BFS遍历题,一遍AC Given the root of a binary tree, the level of its root is 1, the leve ...

  6. 使用WinDbg分析蓝屏dump原因

    大多数人或许都经历过系统蓝屏问题,然而大多数人不清楚该怎么处理蓝屏问题,这里主要对系统蓝屏做一些解释,同时介绍下蓝屏问题分析工具WinDbg分析蓝屏问题的一般步骤. 微软官方对蓝屏的定义是,当系统遇到 ...

  7. SpringBoot-集成通用mapper

    SpringBoot-集成通用mapper SpringBoot-集成通用mapper ​ 我们在SpringBoot中整合了MyBatis,但是大量重复的增删改查还是很头疼的问题,MyBatis也给 ...

  8. Golang的常量定义及使用案例

    Golang的常量定义及使用案例 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.常量的定义 package main import ( "fmt" ) fu ...

  9. webpack4-1.常见配置

    参看:文档地址 视频地址:https://www.bilibili.com/video/av51693431 webpack的作用:代码转换.文件优化.代码分割.模块管理.自动刷新.代码检验.自动发布 ...

  10. 第二阶段scrum-5

    1.整个团队的任务量: 2.任务看板: 会议照片: 产品状态: 注册登陆界面功能正在实装,前端制作完成 数据库配置中