PyTorch 实战-张量
Numpy 是一个非常好的框架,但是不能用 GPU 来进行数据运算。
Numpy is a great framework, but it cannot utilize GPUs to accelerate its numerical computations. For modern deep neural networks, GPUs often provide speedups of 50x
or greater, so unfortunately numpy won’t be enough for modern deep learning.
Here we introduce the most fundamental PyTorch concept: the Tensor. A PyTorch Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional array, and PyTorch provides many functions for operating on these Tensors. Like
numpy arrays, PyTorch Tensors do not know anything about deep learning or computational graphs or gradients; they are a generic tool for scientific computing.
However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it to a new datatype.
Here we use PyTorch Tensors to fit a two-layer network to random data. Like the numpy example above we need to manually implement the forward and backward passes through the network:
# -*- coding: utf-8 -*-
import torch
dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)
# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)
learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.mm(w1)
h_relu = h.clamp(min=0)
y_pred = h_relu.mm(w2)
# Compute and print loss
loss = (y_pred - y).pow(2).sum()
print(t, loss)
# Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.t().mm(grad_y_pred)
grad_h_relu = grad_y_pred.mm(w2.t())
grad_h = grad_h_relu.clone()
grad_h[h < 0] = 0
grad_w1 = x.t().mm(grad_h)
# Update weights using gradient descent
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2
更多教程:http://www.tensorflownews.com/
PyTorch 实战-张量的更多相关文章
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- 参考《深度学习之PyTorch实战计算机视觉》PDF
计算机视觉.自然语言处理和语音识别是目前深度学习领域很热门的三大应用方向. 计算机视觉学习,推荐阅读<深度学习之PyTorch实战计算机视觉>.学到人工智能的基础概念及Python 编程技 ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- pytorch实战(一)hw1——李宏毅老师作业1
任务描述:利用前9小时数据,预测第10小时的pm2.5的数值,回归任务 kaggle地址:https://www.kaggle.com/c/ml2020spring-hw1 训练集为: 12个月*20 ...
- pytorch之张量的理解
张量==容器 张量是现代机器学习的基础,他的核心是一个容器,多数情况下,它包含数字,因此可以将它看成一个数字的水桶. 张量有很多中形式,首先让我们来看最基本的形式.从0维到5维的形式 0维张量/标量: ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- pytorch实战(7)-----卷积神经网络
一.卷积: 卷积在 pytorch 中有两种方式: [实际使用中基本都使用 nn.Conv2d() 这种形式] 一种是 torch.nn.Conv2d(), 一种是 torch.nn.function ...
- Pytorch实战(3)----分类
一.分类任务: 将以下两类分开. 创建数据代码: # make fake data n_data = torch.ones(100, 2) x0 = torch.normal(2*n_data, 1) ...
随机推荐
- C++与引用2
*/ * Copyright (c) 2016,烟台大学计算机与控制工程学院 * All rights reserved. * 文件名:text.cpp * 作者:常轩 * 微信公众号:Worldhe ...
- plsql-工具安装部署及使用配置
参考文档链接:https://blog.csdn.net/li66934791/article/details/83856225 简介: PL/SQL Developer是一个集成开发环境,专门开发面 ...
- #2020.1.26笔记——springdatajpa
2020.1.26笔记--springdatajpa 使用jpa的步骤: 1. 导入maven坐标 <?xml version="1.0" encoding="UT ...
- iOS 使用系统的UITabBarController 修改展示的图片大小
1. 设置TabBarItem图片的大小 1 - (void)configurationAppTabBarAndNavigationBar { // 选中的item普通状态图片的大小 UIImage ...
- C#开发BIMFACE系列36 服务端API之:回调机制
系列目录 [已更新最新开发文章,点击查看详细] 在<C# 开发 BIMFACE 系列文章>中介绍了模型转换.模型对比接口.这2个功能接口比较特殊,发起请求后,逻辑处理是在BIMFA ...
- java反序列化-ysoserial-调试分析总结篇(7)
前言: CommonsCollections7外层也是一条新的构造链,外层由hashtable的readObject进入,这条构造链挺有意思,因为用到了hash碰撞 yso构造分析: 首先构造进行rc ...
- 使用AQS自定义重入锁
一.创建MyLock import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.AbstractQueuedSyn ...
- linux yum安装MySQL5.6
1.新开的云服务器,需要检测系统是否自带安装mysql # yum list installed | grep mysql 2.如果发现有系统自带mysql,果断这么干 # yum -y remove ...
- 使用ZXingObjC扫描二维码横竖屏对应
/** 根据屏幕的方向设置扫描的方向 * @author maguang * @param parameter * @return result */ - (void)showaCapture { C ...
- oracle 10g 搭建备库以及一次DG GAP的处理情况
1.主庫全庫備份rman target/rman> backup database format '/backup/fullbak/fullbak_%U';2.用scp傳到備庫,最好是rman目 ...