自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明。

import torch
import torchvision
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.utils.model_zoo as model_zoo
import math __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152'] model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
} def conv3x3(in_planes, out_planes, stride=1):
# 3x3 kernel
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) # get BasicBlock which layers < 50(18, 34)
class BasicBlock(nn.Module):
expansion = 1 def __init__(self, in_planes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
self.BN = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, stride) # outplane is not in_planes*self.expansion, is planes
self.stride = stride
self.downsample = downsample def forward(self, x):
residual = x # mark the data before BasicBlock
x = self.conv1(x)
x = self.BN(x)
x = self.relu(x)
x = self.conv2(x)
x = self.BN(x) # BN operation is before relu operation
if self.downsample is not None: # is not None
residual = self.downsample(residual) # resize the channel
x += residual
x = self.relu(x)
return x # get BottleBlock which layers >= 50
class Bottleneck(nn.Module):
expansion = 4 # the factor of the last layer of BottleBlock and the first layer of it def __init__(self, in_planes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.con2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*4)
self.downsample = downsample
self.stride = stride
self.relu = nn.ReLU(inplace=True) def forward(self, x):
residual = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x) x = self.con2(x)
x = self.bn2(x)
x = self.relu(x) x = self.conv3(x)
x = self.bn3(x)
if self.downsample is not None:
residual = self.downsample(residual) x += residual
x = self.relu(x) return x class ResNet(nn.Module): def __init__(self, block, layers, num_classes=100):
self.inplanes = 64 # the original channel
super(ResNet, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 以下构建残差块, 具体参数可以查看resnet参数表
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.average_pool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512*block.expansion, num_classes)
# 对卷积和与BN层初始化,论文中也提到过
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
# 这里是为了结局两个残差块之间可能维度不匹配无法直接相加的问题,相同类型的残差块只需要改变第一个输入的维数就好,后面的输入维数都等于输出维数
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None # 扩维
if stride != 1 or self.inplanes != block.expansion * planes:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, block.expansion*planes,kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(block.expansion*planes)
) layers = []
# 特判第一残差块
layers.append(block(self.inplanes, planes, downsample=downsample)) # outplane is planes not planes*block.expansion
self.inplanes = planes * block.expansion
for i in range(1, num_blocks):
layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.max_pool(x) x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.average_pool(x)
x = x.view(x.size(0), -1) # resize batch-size x H
x = self.fc(x)
return x def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model

[源码解读] ResNet源码解读(pytorch)的更多相关文章

  1. RxJava系列6(从微观角度解读RxJava源码)

    RxJava系列1(简介) RxJava系列2(基本概念及使用介绍) RxJava系列3(转换操作符) RxJava系列4(过滤操作符) RxJava系列5(组合操作符) RxJava系列6(从微观角 ...

  2. 入口开始,解读Vue源码(一)-- 造物创世

    Why? 网上现有的Vue源码解析文章一搜一大批,但是为什么我还要去做这样的事情呢?因为觉得纸上得来终觉浅,绝知此事要躬行. 然后平时的项目也主要是Vue,在使用Vue的过程中,也对其一些约定产生了一 ...

  3. JVM源码分析之SystemGC完全解读

    JVM源码分析之SystemGC完全解读 概述 JVM的GC一般情况下是JVM本身根据一定的条件触发的,不过我们还是可以做一些人为的触发,比如通过jvmti做强制GC,通过System.gc触发,还可 ...

  4. Spring源码-循环依赖源码解读

    Spring源码-循环依赖源码解读 笔者最近无论是看书还是从网上找资料,都没发现对Spring源码是怎么解决循环依赖这一问题的详解,大家都是解释了Spring解决循环依赖的想法(有的解释也不准确,在& ...

  5. Derek解读Bytom源码-持久化存储LevelDB

    作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...

  6. Derek解读Bytom源码-创世区块

    作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...

  7. Redux学习之解读applyMiddleware源码深入middleware工作机制

    随笔前言 在上一周的学习中,我们熟悉了如何通过redux去管理数据,而在这一节中,我们将一起深入到redux的知识中学习. 首先谈一谈为什么要用到middleware 我们知道在一个简单的数据流场景中 ...

  8. SpringMVC源码解读 - RequestMapping注解实现解读 - RequestMappingInfo

    使用@RequestMapping注解时,配置的信息最后都设置到了RequestMappingInfo中. RequestMappingInfo封装了PatternsRequestCondition, ...

  9. SpringMVC源码解读 - RequestMapping注解实现解读 - RequestCondition体系

    一般我们开发时,使用最多的还是@RequestMapping注解方式. @RequestMapping(value = "/", param = "role=guest& ...

随机推荐

  1. Linux系统内核参数优化

    Linux服务器内核参数优化 cat >> /etc/sysctl.conf << EOF # kernel optimization net.ipv4.tcp_fin_tim ...

  2. HDFS基本操作的API

    一.从hdfs下载文件到windows本地: package com.css.hdfs01; import java.io.IOException; import java.net.URI; impo ...

  3. 一次漫长的服务CPU优化过程

    从师父那里接了个服务,每天单机的流量并不大,峰值tips也并不高,但是CPU却高的异常.由于,服务十分重要,这个服务最高时占用了100个docker节点在跑,被逼无奈开始了异常曲折的查因和优化过程. ...

  4. 系列:每日一linux命令(转)

    原文:http://www.cnblogs.com/peida/archive/2012/12/05/2803591.html 一. 文件目录操作命令: 1.每天一个linux命令(1):ls命令 2 ...

  5. Centos配置nginx反向代理8090端口到80端口

    下面,我就来说说怎么反向代理自己的项目到默认80端口. 1)安装nginx:yum install nginx -y 2)启动nginx:service nginx start或者systemctl ...

  6. 登录plsql 报错 the account is locked --用户被锁

    登录数据库服务器,进入oracle用户下: [root@uumsnormal-oracle admin]# su - oracle [oracle@uumsnormal-oracle ~]$ sqlp ...

  7. POJ2653:Pick-up sticks(线段相交)

    题目:http://poj.org/problem?id=2653 题意:题意很简单,就是在地上按顺序撒一对木棒,看最后有多少是被压住的,输出没有被压住的木棒的序号.(有点坑的就是没说清楚木棒怎么算压 ...

  8. 转载SQL_trace 和10046使用

    SQL_TRACE是Oracle提供的用于进行SQL跟踪的手段,是强有力的辅助诊断工具.在日常的数据库问题诊断和解决中,SQL_TRACE是非常常用的方法.本文就SQL_TRACE的使用作简单探讨,并 ...

  9. vue——学习笔记

    1.vue需要在dom加载完成之后实现实例化 eg: window.onload = function(){ new Vue({ el: '#editor', data: { input: '# he ...

  10. session和token的区别

    session的使用方式是客户端cookie里存id,服务端session存用户数据,客户端访问服务端的时候,根据id找用户数据 而token一般翻译成令牌,一般是用于验证表明身份的数据或是别的口令数 ...