最近在看DARTS的代码,有一个operations.py的文件,里面是对各类点与点之间操作的方法。

OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: \
Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}

首先定义10个操作,依次解释:

  • class PoolBN(nn.Module):
    """
    AvgPool or MaxPool - BN
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
    """
    Args:
    pool_type: 'max' or 'avg'
    """
    super().__init__()
    if pool_type.lower() == 'max':
    self.pool = nn.MaxPool2d(kernel_size, stride, padding)
    elif pool_type.lower() == 'avg':
    self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
    else:
    raise ValueError() self.bn = nn.BatchNorm2d(C, affine=affine) def forward(self, x):
    out = self.pool(x)
    out = self.bn(out)
    return out

    这是池化函数,有最大池化和平均池化方法,count_include_pad=False表示不把填充的0计算进去

  • class Identity(nn.Module):
    def __init__(self):
    super().__init__() def forward(self, x):
    return x

    这个表示skip conncet

  • class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise(stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
    super().__init__()
    self.relu = nn.ReLU()
    self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine) def forward(self, x):
    x = self.relu(x)
    out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
    out = self.bn(out)
    return out

    这个表示将特征图大小变为原来的一半

  • class DilConv(nn.Module):
    """ (Dilated) depthwise separable conv
    ReLU - (Dilated) depthwise separable - Pointwise - BN If dilation == 2, 3x3 conv => 5x5 receptive field
    5x5 conv => 9x9 receptive field
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
    bias=False),
    nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,groups=C_in,表示把输入特种图分成C_in(输入通道数)那么多组,然后加C_out(输出通道数)1*1的卷积,这样可以对每个通道单独提取特征,同时降低了参数量和计算量。

  • class SepConv(nn.Module):
    """ Depthwise separable conv
    DilConv(dilation=1) * 2
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
    DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    深度可分离卷积,由两个上面的深度分组卷积组成

  • class FacConv(nn.Module):
    """ Factorized conv
    ReLU - Conv(Kx1) - Conv(1xK) - BN
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
    super().__init__()
    self.net = nn.Sequential(
    nn.ReLU(),
    nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
    nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
    nn.BatchNorm2d(C_out, affine=affine)
    ) def forward(self, x):
    return self.net(x)

    这个表示长方形的卷积,增加了一点特征图的长和宽

  • class Zero(nn.Module):
    def __init__(self, stride):
    super().__init__()
    self.stride = stride def forward(self, x):
    if self.stride == 1:
    return x * 0. # re-sizing by stride
    return x[:, :, ::self.stride, ::self.stride] * 0.

    这个表示把特种图的输出变为全是0,但特征图的大小会根据stride而改变

DARTS代码分析(Pytorch)的更多相关文章

  1. Android代码分析工具lint学习

    1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...

  2. pmd静态代码分析

    在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...

  3. [Asp.net 5] DependencyInjection项目代码分析-目录

    微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...

  4. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)

    Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...

  5. 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)

    构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...

  6. STM32启动代码分析 IAR 比较好

    stm32启动代码分析 (2012-06-12 09:43:31) 转载▼     最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...

  7. 常用 Java 静态代码分析工具的分析与比较

    常用 Java 静态代码分析工具的分析与比较 简介: 本文首先介绍了静态代码分析的基 本概念及主要技术,随后分别介绍了现有 4 种主流 Java 静态代码分析工具 (Checkstyle,FindBu ...

  8. SonarQube-5.6.3 代码分析平台搭建使用

    python代码分析 官网主页: http://docs.sonarqube.org/display/PLUG/Python+Plugin Windows下安装使用: 快速使用: 1.下载jdk ht ...

  9. angular代码分析之异常日志设计

    angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...

随机推荐

  1. 运行别人的Vue项目

    步骤一:先 安装 cnpm cmd命令下 输入  npm install -g cnpm --registry=http://registry.npm.taobao.org (由于npm有些资源被屏蔽 ...

  2. mysql中source提高导入数据速率的方法

    示例: 第一步: 第二步: 使用 source 导入你所需要导入的文件 第三步: 在导入的数据停止后,输入  commit; 这样数据就算是导入完成了.

  3. Python-multiprocessing-Process模块

    获取当前执行该文件的进程ID import os # 获取当前执行该文件的进程ID print("Process (%s) start..." % os.getpid()) mul ...

  4. matplotlib(一):散点图

    import numpy as np import matplotlib.pyplot as plt #产生测试数据 # x,y为数组 N = 50 x = np.random.rand(N) y=n ...

  5. 项目部署中,tomcat报java.lang.OutOfMemoryError: PermGen space

    原因: PermGen space的全称是Permanent Generation space,是指内存的永久保存区域,这块内存主要是被JVM存放Class和Meta信息的,Class在被Loader ...

  6. 进程间通信之管道--pipe和fifo使用

    匿名管道pipe 函数原型: #include <unistd.h> int pipe(int fildes[2]); 参数说明 fildes是我们传入的数组,也是一个传出参数.filde ...

  7. P1598 垂直柱状图

    输入格式: 四行字符,由大写字母组成,每行不超过100个字符 输出格式: 由若干行组成,前几行由空格和星号组成,最后一行则是由空格和字母组成的.在任何一行末尾不要打印不需要的多余空格.不要打印任何空行 ...

  8. 微信小程序_(表单组件)checkbox与label

    微信小程序组件checkbox官方文档 传送门 微信小程序组件label官方文档 传送门 Learn 一.checkbox组件 二.label组件与checkbox组件共用 一.checkbox组件 ...

  9. koa 基础(三)路由的另一种写法

    1.配置路由 app.js // 引入模块 const Koa = require('koa'); const router = require('koa-router')(); /*引入是实例化路由 ...

  10. FPGA实战操作(2) -- PCIe总线(例程设计分析)

    1.框架总览 平台:vivado 2016.4 FPGA:A7 在实际应用中,我们几乎不可能自己去编写接口协议,所以在IP核的例程上进行修改来适用于项目是个不错的选择. 通过vivado 中有关PCI ...