#torch
import torch
import torch.nn as nn
import torch.nn.functional as F class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# If the size is a square you can only specify a single number
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the batch dimension
num_features = 1
for s in size:
num_features *= s
return num_features net = Net()
print(net) params = list(net.parameters())
print(len(params))
print(params[0].size()) # conv1's .weigh input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

vgg


#从keras.model中导入model模块,为函数api搭建网络做准备
from tensorflow.keras import Model
from tensorflow.keras.layers import Flatten,Dense,Dropout,MaxPooling2D,Conv2D,BatchNormalization,Input,ZeroPadding2D,Concatenate
from tensorflow.keras import *
from tensorflow.keras import regularizers #正则化
from tensorflow.keras.optimizers import RMSprop #优化选择器
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.python.keras.utils import np_utils #数据处理
(X_train,Y_train),(X_test,Y_test)=mnist.load_data()
X_test1=X_test
Y_test1=Y_test
X_train=X_train.reshape(-1,28,28,1).astype("float32")/255.0
X_test=X_test.reshape(-1,28,28,1).astype("float32")/255.0
Y_train=np_utils.to_categorical(Y_train,10)
Y_test=np_utils.to_categorical(Y_test,10)
print(X_train.shape)
print(Y_train.shape)
print(X_train.shape) def vgg16():
x_input = Input((28, 28, 1)) # 输入数据形状28*28*1
# Block 1
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(x_input)
x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) # Block 2
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) # Block 3
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) # Block 4
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) # Block 5
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x) #BLOCK 6
x=Flatten()(x)
x=Dense(256,activation="relu")(x)
x=Dropout(0.5)(x)
x = Dense(256, activation="relu")(x)
x = Dropout(0.5)(x)
#搭建最后一层,即输出层
x = Dense(10, activation="softmax")(x)
# 调用MDOEL函数,定义该网络模型的输入层为X_input,输出层为x.即全连接层
model = Model(inputs=x_input, outputs=x)
# 查看网络模型的摘要
model.summary()
return model model=vgg16()
optimizer=RMSprop(lr=1e-4)
model.compile(loss="binary_crossentropy",optimizer=optimizer,metrics=["accuracy"])
#训练加评估模型
n_epoch=4
batch_size=128
def run_model(): #训练模型
training=model.fit(
X_train,
Y_train,
batch_size=batch_size,
epochs=n_epoch,
validation_split=0.25,
verbose=1
)
test=model.evaluate(X_train,Y_train,verbose=1)
return training,test
training,test=run_model()
print("误差:",test[0])
print("准确率:",test[1]) def show_train(training_history,train, validation):
plt.plot(training.history[train],linestyle="-",color="b")
plt.plot(training.history[validation] ,linestyle="--",color="r")
plt.title("training history")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend(["training","validation"],loc="lower right")
plt.show()
show_train(training,"accuracy","val_accuracy") def show_train1(training_history,train, validation):
plt.plot(training.history[train],linestyle="-",color="b")
plt.plot(training.history[validation] ,linestyle="--",color="r")
plt.title("training history")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(["training","validation"],loc="upper right")
plt.show()
show_train1(training,"loss","val_loss") prediction=model.predict(X_test)
def image_show(image):
fig=plt.gcf() #获取当前图像
fig.set_size_inches(2,2) #改变图像大小
plt.imshow(image,cmap="binary") #显示图像
plt.show()
def result(i):
image_show(X_test1[i])
print("真实值:",Y_test1[i])
print("预测值:",np.argmax(prediction[i]))
result(0)
result(1)


torch& tensorflow的更多相关文章

  1. Tutorial: Implementation of Siamese Network on Caffe, Torch, Tensorflow

    Tutorial: Implementation of Siamese Network with Caffe, Theano, PyTorch, Tensorflow  Updated on 2018 ...

  2. torch 入门

    torch 入门1.安装环境我的环境mac book pro 集成显卡 Intel Iris不能用 cunn 模块,因为显卡不支持 CUDA2.安装步骤: 官方文档 (1).git clone htt ...

  3. 学习Data Science/Deep Learning的一些材料

    原文发布于我的微信公众号: GeekArtT. 从CFA到如今的Data Science/Deep Learning的学习已经有一年的时间了.期间经历了自我的兴趣.擅长事务的探索和试验,有放弃了的项目 ...

  4. pytorch使用不完全文档

    1. 利用tensorboard看loss: tensorflow和pytorch环境是好的的话,链接中的logger.py拉到自己的工程里,train.py里添加相应代码,直接能用. 关于环境,小小 ...

  5. CS231n 2016 通关 第一章-内容介绍

    第一节视频的主要内容: Fei-Fei Li 女神对Computer Vision的整体介绍.包括了发展历史中的重要事件,其中最为重要的是1959年测试猫视觉神经的实验. In 1959 Harvar ...

  6. 深度学习框架caffe/CNTK/Tensorflow/Theano/Torch的对比

    在单GPU下,所有这些工具集都调用cuDNN,因此只要外层的计算或者内存分配差异不大其性能表现都差不多. Caffe: 1)主流工业级深度学习工具,具有出色的卷积神经网络实现.在计算机视觉领域Caff ...

  7. Torch,Tensorflow使用: Ubuntu14.04(x64)+ CUDA8.0 安装 Torch和Tensorflow

    系统配置: Ubuntu14.04(x64) CUDA8.0 cudnn-8.0-linux-x64-v5.1.tgz(Tensorflow依赖) Anaconda 1. Torch安装 Torch是 ...

  8. 一图看懂深度学习框架对比----Caffe Torch Theano TensorFlow

      Caffe Torch Theano TensorFlow Language C++, Python Lua Python Python Pretrained Yes ++ Yes ++ Yes ...

  9. tensorflow,torch tips

    apply weightDecay,L2 REGULARIZATION_LOSSES weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIAB ...

  10. 关于类型为numpy,TensorFlow.tensor,torch.tensor的shape变化以及相互转化

    https://blog.csdn.net/zz2230633069/article/details/82669546 2018年09月12日 22:56:50 一只tobey 阅读数:727   1 ...

随机推荐

  1. 基于C++的OpenGL 14 之模型加载

    1. 引言 本文基于C++语言,描述OpenGL的模型加载 前置知识可参考: 基于C++的OpenGL 13 之Mesh - 当时明月在曾照彩云归 - 博客园 (cnblogs.com) 笔者这里不过 ...

  2. c#反射优化

    https://www.cnblogs.com/xinaixia/p/5777886.html https://www.cnblogs.com/xinaixia/p/5777961.html

  3. 【转载】python解决文本乱码问题及文本二进制读取后的处理

    转自:https://blog.csdn.net/u011316258/article/details/50450079 python解决文本乱码问题及文本二进制读取后的处理 吲哚乙酸 当文本中含有很 ...

  4. [GKCTF2021]random

    [GKCTF2021]random 本题出现了MT19937伪随机数生成算法. 目录 [GKCTF2021]random 题目 分析 MT19937算法 步骤 代码实现 解法1 解法2 总结 题目 t ...

  5. 概率生成函数(PGF)简记

    基本搬运自<浅谈生成函数在掷骰子问题上的应用>. 对于定义在非负整数上的离散随机变量 \(X\),级数 \(F(z) = \sum\limits_{i\ge 0} \operatornam ...

  6. soursetree 关于https:git remote: Unauthorized和username和password修改

    一.sourcetree推送代码提交不上提示https:git remote: Unauthorized由于没有权限,需要登陆正确的账号以及密码即可以提交 二.SourceTree这是一个无效源路径/ ...

  7. 向mysql插入数据报错 pymysql.err.DataError: (1406, "Data too long for column 'class' at row 1") 解决方案

    这个问题一开始更换数据类型或者数据类型的大小,发现还是不行.后面通过网上查询了一条神奇的sql语句分分钟钟的解决了 问题原因明明是: 字段的长度不够存放数据 解决方案: 在mysql命令行输入如下:S ...

  8. 在docker中导入python的包时ImportError: libgthread-2.0.so.0: cannot open shared object file: No such file or directory

    问题: ImportError: libGL.so.1: cannot open shared object file: No such file or directory ImportError: ...

  9. CodeGym自学笔记07——入门Java书籍

    入门Java书籍 Head First Java Java:The Complete Reference,作者:Herbert Schildt   这本书对初学者也很有好处.与前一本书的主要区别在于素 ...

  10. window下快速启动mysql,bat脚本

    cls @echo off:设置窗口字体颜色color 0a :设置窗口标题TITLE MySQL管理程序 call :checkAdmin goto menu:菜单:menuclsecho. ech ...