自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明。

import torch
import torchvision
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.utils.model_zoo as model_zoo
import math __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152'] model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
} def conv3x3(in_planes, out_planes, stride=1):
# 3x3 kernel
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) # get BasicBlock which layers < 50(18, 34)
class BasicBlock(nn.Module):
expansion = 1 def __init__(self, in_planes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
self.BN = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, stride) # outplane is not in_planes*self.expansion, is planes
self.stride = stride
self.downsample = downsample def forward(self, x):
residual = x # mark the data before BasicBlock
x = self.conv1(x)
x = self.BN(x)
x = self.relu(x)
x = self.conv2(x)
x = self.BN(x) # BN operation is before relu operation
if self.downsample is not None: # is not None
residual = self.downsample(residual) # resize the channel
x += residual
x = self.relu(x)
return x # get BottleBlock which layers >= 50
class Bottleneck(nn.Module):
expansion = 4 # the factor of the last layer of BottleBlock and the first layer of it def __init__(self, in_planes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.con2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*4)
self.downsample = downsample
self.stride = stride
self.relu = nn.ReLU(inplace=True) def forward(self, x):
residual = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x) x = self.con2(x)
x = self.bn2(x)
x = self.relu(x) x = self.conv3(x)
x = self.bn3(x)
if self.downsample is not None:
residual = self.downsample(residual) x += residual
x = self.relu(x) return x class ResNet(nn.Module): def __init__(self, block, layers, num_classes=100):
self.inplanes = 64 # the original channel
super(ResNet, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 以下构建残差块, 具体参数可以查看resnet参数表
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.average_pool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512*block.expansion, num_classes)
# 对卷积和与BN层初始化,论文中也提到过
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
# 这里是为了结局两个残差块之间可能维度不匹配无法直接相加的问题,相同类型的残差块只需要改变第一个输入的维数就好,后面的输入维数都等于输出维数
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None # 扩维
if stride != 1 or self.inplanes != block.expansion * planes:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, block.expansion*planes,kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(block.expansion*planes)
) layers = []
# 特判第一残差块
layers.append(block(self.inplanes, planes, downsample=downsample)) # outplane is planes not planes*block.expansion
self.inplanes = planes * block.expansion
for i in range(1, num_blocks):
layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.max_pool(x) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.average_pool(x)
x = x.view(x.size(0), -1) # resize batch-size x H
x = self.fc(x)
return x def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model

[源码解读] ResNet源码解读(pytorch)的更多相关文章

  1. RxJava系列6(从微观角度解读RxJava源码)

    RxJava系列1(简介) RxJava系列2(基本概念及使用介绍) RxJava系列3(转换操作符) RxJava系列4(过滤操作符) RxJava系列5(组合操作符) RxJava系列6(从微观角 ...

  2. 入口开始,解读Vue源码(一)-- 造物创世

    Why? 网上现有的Vue源码解析文章一搜一大批,但是为什么我还要去做这样的事情呢?因为觉得纸上得来终觉浅,绝知此事要躬行. 然后平时的项目也主要是Vue,在使用Vue的过程中,也对其一些约定产生了一 ...

  3. JVM源码分析之SystemGC完全解读

    JVM源码分析之SystemGC完全解读 概述 JVM的GC一般情况下是JVM本身根据一定的条件触发的,不过我们还是可以做一些人为的触发,比如通过jvmti做强制GC,通过System.gc触发,还可 ...

  4. Spring源码-循环依赖源码解读

    Spring源码-循环依赖源码解读 笔者最近无论是看书还是从网上找资料,都没发现对Spring源码是怎么解决循环依赖这一问题的详解,大家都是解释了Spring解决循环依赖的想法(有的解释也不准确,在& ...

  5. Derek解读Bytom源码-持久化存储LevelDB

    作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...

  6. Derek解读Bytom源码-创世区块

    作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...

  7. Redux学习之解读applyMiddleware源码深入middleware工作机制

    随笔前言 在上一周的学习中,我们熟悉了如何通过redux去管理数据,而在这一节中,我们将一起深入到redux的知识中学习. 首先谈一谈为什么要用到middleware 我们知道在一个简单的数据流场景中 ...

  8. SpringMVC源码解读 - RequestMapping注解实现解读 - RequestMappingInfo

    使用@RequestMapping注解时,配置的信息最后都设置到了RequestMappingInfo中. RequestMappingInfo封装了PatternsRequestCondition, ...

  9. SpringMVC源码解读 - RequestMapping注解实现解读 - RequestCondition体系

    一般我们开发时,使用最多的还是@RequestMapping注解方式. @RequestMapping(value = "/", param = "role=guest& ...

随机推荐

  1. 解决mysql不能通过'/tmp/mysql.sock 连接的问题

    解决方法:php标准配置正是通过'/tmp/mysql.sock',但一些mysql安装方法将mysql.sock放在/var/lib/mysql.sock或者其他地方,你可以通过修改/etc/my. ...

  2. 【npm start 启动失败】ubuntu 将node和npm同时更新到最新的稳定版本

    https://blog.csdn.net/u010277553/article/details/80938829 npm start 启动失败,报错如下 错误提示 make sure you hav ...

  3. 棋盘格 测量 相机近似精度 (像素精度&物理精度)

    像素精度计算 像素精度——一像素对应多少毫米——距离不同像素精度也不同 将棋盘格与相机CCD平面大致平行摆放,通过[每个点处的近似像素精度=相邻两个角点之间的实际距离(棋盘格尺寸已知)/ 棋盘格上检出 ...

  4. windows中根据进程PID查找进程对象过程深入分析

    这里windows和Linxu系列的PID 管理方式有所不同,windows中进程的PID和句柄没有本质区别,根据句柄索引对象和根据PID或者TID查找进程或者线程的步骤也是一样的.   句柄是针对进 ...

  5. Spark的集群管理器

    上篇文章谈到Driver节点和Executor节点,但是如果想要运行Driver节点和Executor节点,就不能不说spark的集群管理器.spark的集群管理器大致有三种,一种是自带的standa ...

  6. 奇异值与主成分分析(PCA)

    主成分分析在上一节里面也讲了一些,这里主要谈谈如何用SVD去解PCA的问题.PCA的问题其实是一个基的变换,使得变换后的数据有着最大的方差.方差的大小描述的是一个变量的信息量,我们在讲一个东西的稳定性 ...

  7. isEnable() 和 isDisplayed() 和 isSelected()

    isEnable().isDisplayed()和isSelected() 1.以上三个为布尔类型的函数 2.isEnable用于存储input.select等元素的可编辑状态,可以编辑返回true, ...

  8. td中不包含汉字的字符串不换行,包含汉字的能换行的问题原因及解决方法

    今天项目中遇到一个问题,一长串的字符串如:003403FF0014E54016030CC655BC3242,但是如:中国河北省石家庄市裕华区槐安路雅清街交口 这样的就可以换行. 原因是:英文字母之间如 ...

  9. 《mysql必知必会》读书笔记--触发器及管理事务处理

    触发器 触发器是MySQL响应DELETE,INSERT,UPDATE而自动执行的一条MySQL语句,其他语句不支持触发器. 创建触发器时,需要4个条件: 唯一的触发器名 触发器关联的表 触发器应该响 ...

  10. Spring-1-H Number Sequence(HDU 5014)解题报告及测试数据

    Number Sequence Time Limit:2000MS     Memory Limit:65536KB     64bit IO Format:%I64d & %I64u Pro ...