基本的卷积神经网络

from torch import nn

class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
layer1 = nn.Sequential() # 将网络模型进行添加
layer1.add_module('conv1', nn.Conv2d(3, 32, 3, 1, padding=1)) # nn.Conv
layer1.add_module('relu1', nn.ReLU(True))
layer1.add_module('pool1', nn.MaxPool2d(2, 2))
self.layer1 = layer1 layer2 = nn.Sequential()
layer2.add_module('conv2', nn.Conv2d(32, 64, 3, 1, padding=1))
layer2.add_module('relu2', nn.ReLU(True))
layer2.add_module('pool2', nn.MaxPool2d(2, 2))
self.layer2 = layer2 layer3 = nn.Sequential()
layer3.add_module('conv3', nn.Conv2d(64, 128, 3, 1, padding=1))
layer3.add_module('relu3', nn.ReLU(True))
layer3.add_module('pool3', nn.MaxPool2d(2, 2))
self.layer3 = layer3 layer4 = nn.Sequential()
layer4.add_module('fc1', nn.Linear(2048, 512))
layer4.add_module('fc_relu1', nn.ReLU(True))
layer4.add_module('fc2', nn.Linear(512, 64))
layer4.add_module('fc_relu2', nn.ReLU(True))
layer4.add_module('fc3', nn.Linear(64, 10))
self.layer4 = layer4 def forward(self, x):
conv1 = self.layer1(x)
conv2 = self.layer2(conv1)
conv3 = self.layer3(conv2)
fc_input = conv3.view(conv3.size(0), -1)
fc_out = self.layer4(fc_input) return fc_out model = SimpleCNN()
# print(model) # 打印输出网络结构

提取前两层网络结构

new_model = nn.Sequential(*list(model.children())[:2])  # 提取前两层的网络结构, 构造nn.Sequential网络串接, * 表示将里面的内容一个个传进去

提取所有的卷积层网络

conv_model = nn.Sequential()
# 提取所有的卷积层操作
for name, layer in model.named_modules():
if isinstance(layer, nn.Conv2d):
name = name.replace('.', '_')
conv_model.add_module(name, layer)
print(conv_model)

打印卷积层的网络名字

for param in model.named_parameters():
print(param)

对权重参数进行初始化操作

from torch.nn import init
# 对权重参数进行初始化操作
for m in model.modules():
if isinstance(m, nn.Conv2d):
init.normal(m.weight.data)
init.xavier_normal(m.weight.data)
init.kaiming_normal(m.weight.data)
elif isinstance(m, nn.Linear):
m.weight.data.normal_()

pytorch-卷积基本网络结构-提取网络参数-初始化网络参数的更多相关文章

  1. pytorch和tensorflow的爱恨情仇之参数初始化

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch和tensorflow的爱恨情仇之定义可训练的参数 pytorch版本 ...

  2. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  3. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

  4. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  5. pytorch搭建网络,保存参数,恢复参数

    这是看过莫凡python的学习笔记. 搭建网络,两种方式 (1)建立Sequential对象 import torch net = torch.nn.Sequential( torch.nn.Line ...

  6. Pytorch基础(6)----参数初始化

    一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...

  7. pytorch对模型参数初始化

    1.使用apply() 举例说明: Encoder :设计的编码其模型 weights_init(): 用来初始化模型 model.apply():实现初始化 # coding:utf- from t ...

  8. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  9. 从零搭建Pytorch模型教程(四)编写训练过程--参数解析

    ​  前言 训练过程主要是指编写train.py文件,其中包括参数的解析.训练日志的配置.设置随机数种子.classdataset的初始化.网络的初始化.学习率的设置.损失函数的设置.优化方式的设置. ...

随机推荐

  1. ABAP下载的病毒扫描Virus Scan

    当我使用CL_HTTP_ENTITY=>IF_HTTP_ENTITY~GET_DATA从网络下载数据时,遇到异常CX_VSI: 错误原因是数据从网络下载到Netweaver服务器上之后,在服务器 ...

  2. 018.查询练习50题(sql实例)

    CREATE TABLE EMP(EMPNO numeric(5,0) NOT NULL primary key,--雇员的编号ENAME nvarchar(10) not null,--雇员的名字J ...

  3. windows BAT脚本2个服务器间传递文件

    1. 脚本功能: 实现2个服务器间文件的传递,例如从A服务器往B服务器上传文件 2. 实现步骤: 2.1 服务器连结,找到指定路径,读取所需要上传的文件,将文件名称复制到一个文件下 (此处考虑可能需要 ...

  4. golang shell 交叉编译

    #!/usr/bin/env bash set -e uname_s=`uname -s | awk '{print tolower($0)}'` uname_m=`uname -m` timeTag ...

  5. delphi xe5 fastreport4.14 中文很多时换行不正确

    用一般的frxMEMOview 中文换行是瞎换,缺少数据,换成frxrichview 即可, frxrichview 使用注意点 1).Delphi中文很多时换行不正确 2).要在窗体上拖一个frxr ...

  6. Swagger保姆级教学

    Swagger保姆级教学 Swagger 简介 Swagger 是一个规范和完整的框架,用于生成.描述.调用和可视化 RESTful 风格的 Web 服务.总体目标是使客户端和文件系统作为服务器以同样 ...

  7. sql 拼接字符串单条拆分多条

    SELECT * FROM ( SELECT A.WS_ID , B.NEXT_OPERATOR FROM ( SELECT WS_ID , [NEXT_OPERATOR] = CONVERT(XML ...

  8. redis + boost.asio

    redis是一个key-value存储系统.和Memcached类似,它支持存储的value类型相对更多,包括string(字符串).list(链表).set(集合).zset(sorted set ...

  9. 关于微信小程序的本地存储

    微信小程序中会使用wx.setStorage(wx.setStorageSync)来存储数据,问题是:即使小程序被销毁了,本地缓存的数据仍然存在.会造成: 所以要及时清理掉本地缓存的数据.解决思路: ...

  10. Hivesql中的正则

    ================================================================================================= 一般 ...