官方 GitHub 上已经有了 PyTorch 基础模型的实现(torchvision 的 models 目录),链接:https://github.com/pytorch/vision/tree/main/torchvision/models

但是其中一些模型,尤其是resnet,都是用函数生成的各个层,自己看起来是真的难受!

所以自己按照 Caffe 的样子,把每一层都显式地写出来,实现了一个 PyTorch 的 ResNet18 模型。当然,它和 1000 分类的标准模型不同,做了一些修改:输入是 48*48 的 3 通道图片,第一层改为 3*3、步长为 1 的卷积,并且去掉了最大池化层,最后输出 7 类。

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResNet18Model(nn.Module):
    """ResNet-18 style classifier with every layer written out explicitly.

    The network is a 3x3 stem convolution followed by eight two-convolution
    residual blocks (two each at 64, 128, 256 and 512 channels), global
    average pooling, dropout and a 7-way linear classifier.  The first block
    of the 128/256/512 stages downsamples with a stride-2 convolution and
    uses a 1x1 stride-2 convolution + BatchNorm on the shortcut path.

    Layers are deliberately stored as individually named attributes (rather
    than generated in a loop) so the architecture can be read top to bottom.
    """

    def __init__(self):
        super().__init__()

        # BatchNorm layers, one per convolution on the main path.
        self.bn64_0 = nn.BatchNorm2d(64)
        self.bn64_1 = nn.BatchNorm2d(64)
        self.bn64_2 = nn.BatchNorm2d(64)
        self.bn64_3 = nn.BatchNorm2d(64)
        self.bn64_4 = nn.BatchNorm2d(64)

        self.bn128_0 = nn.BatchNorm2d(128)
        self.bn128_1 = nn.BatchNorm2d(128)
        self.bn128_2 = nn.BatchNorm2d(128)
        self.bn128_3 = nn.BatchNorm2d(128)

        self.bn256_0 = nn.BatchNorm2d(256)
        self.bn256_1 = nn.BatchNorm2d(256)
        self.bn256_2 = nn.BatchNorm2d(256)
        self.bn256_3 = nn.BatchNorm2d(256)

        self.bn512_0 = nn.BatchNorm2d(512)
        self.bn512_1 = nn.BatchNorm2d(512)
        self.bn512_2 = nn.BatchNorm2d(512)
        self.bn512_3 = nn.BatchNorm2d(512)

        # Identity shortcuts: an empty Sequential returns its input unchanged.
        self.shortcut_straight_0 = nn.Sequential()
        self.shortcut_straight_1 = nn.Sequential()
        self.shortcut_straight_2 = nn.Sequential()
        self.shortcut_straight_3 = nn.Sequential()
        self.shortcut_straight_4 = nn.Sequential()

        # Projection shortcuts for the blocks that change channel count and
        # halve the spatial size (1x1 convolution with stride 2, then BN).
        self.shortcut_conv_bn_64_128_0 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(128),
        )
        self.shortcut_conv_bn_128_256_0 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(256),
        )
        self.shortcut_conv_bn_256_512_0 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(512),
        )

        # Stem convolution: 3 input channels -> 64 feature maps.
        self.conv_w3_h3_in3_out64_s1_p1_0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # 64-channel stage.
        self.conv_w3_h3_in64_out64_s1_p1_0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in64_out64_s1_p1_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in64_out64_s1_p1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in64_out64_s1_p1_3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # 128-channel stage (first convolution downsamples with stride 2).
        self.conv_w3_h3_in64_out128_s2_p1_0 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)

        self.conv_w3_h3_in128_out128_s1_p1_0 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in128_out128_s1_p1_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in128_out128_s1_p1_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)

        # 256-channel stage (first convolution downsamples with stride 2).
        self.conv_w3_h3_in128_out256_s2_p1_0 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)

        self.conv_w3_h3_in256_out256_s1_p1_0 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in256_out256_s1_p1_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in256_out256_s1_p1_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)

        # 512-channel stage (first convolution downsamples with stride 2).
        self.conv_w3_h3_in256_out512_s2_p1_0 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)

        self.conv_w3_h3_in512_out512_s1_p1_0 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in512_out512_s1_p1_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_w3_h3_in512_out512_s1_p1_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)

        # Classification head.
        self.avg_pool_0 = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_512_7_0 = nn.Linear(512, 7)
        self.dropout_0 = nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input batch with 3 channels, laid out as (N, 3, H, W).  The
                shape comments below assume H = W = 48.

        Returns:
            A tensor of shape (N, 7) holding the raw output of the final
            linear layer (no softmax is applied).
        """
        # Stem.
        t = self.conv_w3_h3_in3_out64_s1_p1_0(x)  # 48x48x64
        t = self.bn64_0(t)
        y1 = F.relu(t)

        # 64-channel stage, block 1.
        t = self.conv_w3_h3_in64_out64_s1_p1_0(y1)  # 48x48x64
        t = self.bn64_1(t)
        y2 = F.relu(t)

        t = self.conv_w3_h3_in64_out64_s1_p1_1(y2)  # 48x48x64
        t = self.bn64_2(t)
        t += self.shortcut_straight_0(y1)
        y3 = F.relu(t)

        # 64-channel stage, block 2.
        t = self.conv_w3_h3_in64_out64_s1_p1_2(y3)  # 48x48x64
        t = self.bn64_3(t)
        y4 = F.relu(t)

        t = self.conv_w3_h3_in64_out64_s1_p1_3(y4)  # 48x48x64
        t = self.bn64_4(t)
        t += self.shortcut_straight_1(y3)
        y5 = F.relu(t)

        # 128-channel stage, block 1 (downsampling, projection shortcut).
        t = self.conv_w3_h3_in64_out128_s2_p1_0(y5)  # 24x24x128
        t = self.bn128_0(t)
        y6 = F.relu(t)

        t = self.conv_w3_h3_in128_out128_s1_p1_0(y6)  # 24x24x128
        t = self.bn128_1(t)
        t += self.shortcut_conv_bn_64_128_0(y5)
        y7 = F.relu(t)

        # 128-channel stage, block 2.
        t = self.conv_w3_h3_in128_out128_s1_p1_1(y7)  # 24x24x128
        t = self.bn128_2(t)
        y8 = F.relu(t)

        t = self.conv_w3_h3_in128_out128_s1_p1_2(y8)  # 24x24x128
        t = self.bn128_3(t)
        t += self.shortcut_straight_2(y7)
        y9 = F.relu(t)

        # 256-channel stage, block 1 (downsampling, projection shortcut).
        t = self.conv_w3_h3_in128_out256_s2_p1_0(y9)  # 12x12x256
        t = self.bn256_0(t)
        y10 = F.relu(t)

        t = self.conv_w3_h3_in256_out256_s1_p1_0(y10)  # 12x12x256
        t = self.bn256_1(t)
        t += self.shortcut_conv_bn_128_256_0(y9)
        y11 = F.relu(t)

        # 256-channel stage, block 2.
        t = self.conv_w3_h3_in256_out256_s1_p1_1(y11)  # 12x12x256
        t = self.bn256_2(t)
        y12 = F.relu(t)

        t = self.conv_w3_h3_in256_out256_s1_p1_2(y12)  # 12x12x256
        t = self.bn256_3(t)
        t += self.shortcut_straight_3(y11)
        y13 = F.relu(t)

        # 512-channel stage, block 1 (downsampling, projection shortcut).
        t = self.conv_w3_h3_in256_out512_s2_p1_0(y13)  # 6x6x512
        t = self.bn512_0(t)
        y14 = F.relu(t)

        t = self.conv_w3_h3_in512_out512_s1_p1_0(y14)  # 6x6x512
        t = self.bn512_1(t)
        t += self.shortcut_conv_bn_256_512_0(y13)
        y15 = F.relu(t)

        # 512-channel stage, block 2.
        t = self.conv_w3_h3_in512_out512_s1_p1_1(y15)  # 6x6x512
        t = self.bn512_2(t)
        y16 = F.relu(t)

        t = self.conv_w3_h3_in512_out512_s1_p1_2(y16)  # 6x6x512
        t = self.bn512_3(t)
        t += self.shortcut_straight_4(y15)
        y17 = F.relu(t)

        # Head: global average pool, flatten to (N, 512), dropout, classify.
        out = self.avg_pool_0(y17)  # 1x1x512
        out = out.view(out.size(0), -1)
        out = self.dropout_0(out)
        out = self.fc_512_7_0(out)

        return out


if __name__ == '__main__':
    # Smoke test: push one random 48x48 RGB image through the network and
    # print the output and its shape.
    net = ResNet18Model()
    # print(net)

    net_in = torch.rand(1, 3, 48, 48)
    net_out = net(net_in)
    print(net_out)
    print(net_out.size())

  

pytorch resnet实现的更多相关文章

  1. PyTorch ResNet 使用与源码解析

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py 这篇文章首先会简 ...

  2. [源码解读] ResNet源码解读(pytorch)

    自己读完pytorch封装的源码后,自己又重新写了一遍(模仿其书写格式), 一些问题在代码中说明. import torch import torchvision import argparse i ...

  3. 解读 pytorch对resnet的官方实现

    地址:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 贴代码 import torch.nn as ...

  4. 【深度学习】基于Pytorch的ResNet实现

    目录 1. ResNet理论 2. pytorch实现 2.1 基础卷积 2.2 模块 2.3 使用ResNet模块进行迁移学习 1. ResNet理论 论文:https://arxiv.org/pd ...

  5. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  6. Pytorch构建ResNet

    学了几天Pytorch,大致明白代码在干什么了,贴一下.. import torch from torch.utils.data import DataLoader from torchvision ...

  7. 陈云pytorch学习笔记_用50行代码搭建ResNet

    import torch as t import torch.nn as nn import torch.nn.functional as F from torchvision import mode ...

  8. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  9. 【pytorch】改造resnet为全卷积神经网络以适应不同大小的输入

    为什么resnet的输入是一定的? 因为resnet最后有一个全连接层.正是因为这个全连接层导致了输入的图像的大小必须是固定的. 输入为固定的大小有什么局限性? 原始的resnet在imagenet数 ...

随机推荐

  1. 12.su 命令与sudo 服务

     1.su 命令:解决切换用户身份的需求,使得当前用户在不退出登录的情况下,顺畅地切换到其他用户. 比如从root 管理员切换至普通用户: [root@Centos test]# id uid=0(r ...

  2. TCP/IP__IP寻址及ARP解析

    ARP解析过程中MAC地址以及IP地址的变化情况 1.两主机要通信传送数据时,就要把应用数据封装成IP包,然后再交给下一层数据链路层继续封装成帧:之后根据MAC地址才能把数据从一台主机,准确无误的传送 ...

  3. 记angular和asp.net使用grpc进行通信

    AspNetCore配置grpc服务端 新建一个Demo项目: GrpcStartup, 目录结构如下图: GrpcStartup.GrpcServices需要安装下面的依赖 <PackageR ...

  4. Codeforces Round #682 (Div. 2)【ABCD】

    比赛链接:https://codeforces.com/contest/1438 A. Specific Tastes of Andre 题意 构造一个任意连续子数组元素之和为子数组长度倍数的数组. ...

  5. F - To Add Which?

    Description There is an integer sequence with N integers. You can use 1 unit of cost to increase any ...

  6. Codeforces Round #673 (Div. 2) B. Two Arrays (贪心)

    题意:给你一组数\(a\)和一个数\(T\),将这组数分为两组\(c\)和\(d\),定义\(f(x)\)为数组\(x\)中任意两个不同元素的和为\(T\)的个数,问为了使\(min(f(c)+f(d ...

  7. WPF 无法对元素“TextBox”设置 Name 特性值“TB2”

    错误信息 无法对元素"TextBox"设置 Name 特性值"TB2"."TextBox"在元素"UserControl1" ...

  8. 通过k8s部署dubbo微服务并接入ELK架构

    需要这样一套日志收集.分析的系统: 收集 -- 能够采集多种来源的日志数据 (流式日志收集器) 传输 -- 能够稳定的把日志数据传输到中央系统 (消息队列) 存储 -- 可以将日志以结构化数据的形式存 ...

  9. 爬虫——urllib.request包

    一.引用包 import urllib.request 二.常用方法 (1)urllib.request.urlretrieve(网址,本地文件存储地址):直接下载网页到本地 urllib.reque ...

  10. C++ part3

    函数和const references: C++中const用于函数重载 有些情况可以重载,有些不行,具体看↑. 隐式类型转换 references: nowcoder 对于内置类型,低精度的变量给高 ...