# -*- coding:utf-8 -*-

import os
import numpy as np
import torch
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
import os
from PIL import Image
os.environ ['KMP_DUPLICATE_LIB_OK'] ='True'
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
fmap_block = list()
grad_block = list()
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) torch.manual_seed(1) # 设置随机种子
rmb_label = {"1": 0, "100": 1} # 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1 output_dir = os.path.join(BASE_DIR, "..", "..", "Result", "backward_hook_cam") fmap_block = list()
input_block = list()
# ============================ step 1/5 数据 ============================ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
train_dir = os.path.join(split_dir, "train")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225] def backward_hook(module, grad_in, grad_out):
grad_block.append(grad_out[0].detach()) def farward_hook(module, input, output):
fmap_block.append(output) def show_cam_on_image(img, mask, out_dir):
heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
cam = heatmap + np.float32(img)
cam = cam / np.max(cam)
path_cam_img = os.path.join(out_dir, "cam1.jpg")
path_raw_img = os.path.join(out_dir, "raw1.jpg")
if not os.path.exists(out_dir):
os.makedirs(out_dir)
print(cam)
cv2.imwrite(path_cam_img, np.uint8(255 * cam))
cv2.imwrite(path_raw_img, np.uint8(255 * img)) def comp_class_vec(ouput_vec, index=None):
"""
计算类向量
:param ouput_vec: tensor
:param index: int,指定类别
:return: tensor
"""
if not index:
index = np.argmax(ouput_vec.cpu().data.numpy())
else:
index = np.array(index)
index = index[np.newaxis, np.newaxis]
index = torch.from_numpy(index)
one_hot = torch.zeros(1, 2).scatter_(1, index, 1)
one_hot.requires_grad = True
class_vec = torch.sum(one_hot * outputx) # one_hot = 11.8605
return class_vec def gen_cam(feature_map, grads):
"""
依据梯度和特征图,生成cam
:param feature_map: np.array, in [C, H, W]
:param grads: np.array, in [C, H, W]
:return: np.array, [H, W]
"""
cam = np.zeros(feature_map.shape[1:], dtype=np.float32) # cam shape (H, W) weights = np.mean(grads, axis=(1, 2)) # for i, w in enumerate(weights):
cam += w * feature_map[i, :, :] cam = np.maximum(cam, 0)
cam = cv2.resize(cam, (64, 64))
cam -= np.min(cam)
cam /= np.max(cam) return cam train_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.RandomCrop(64, padding=4),
transforms.RandomGrayscale(p=0.8),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) valid_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) # 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform) # 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # ============================ step 2/5 模型 ============================ net = LeNet(classes=2)
net.initialize_weights() # ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss() # 选择损失函数 # ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略 # ============================ step 5/5 训练 ============================
train_curve = list() iter_count = 0 for epoch in range(MAX_EPOCH):
fmap_dict = dict() loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
iter_count += 1
# forward
inputs, labels = data
outputs = net(inputs)
# backward optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward() # update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0. scheduler.step() # 更新学习率
img = cv2.imread('100.jpg', 1) # H*W*C
x = Image.open('100.jpg').convert('RGB')
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
valid_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
x = valid_transform(x)
x.unsqueeze_(0) net.conv2.register_forward_hook(farward_hook)
net.conv2.register_backward_hook(backward_hook)
outputx = net(x)
net.zero_grad()
class_loss = comp_class_vec(outputx)
class_loss.backward()
grads_val = grad_block[0].cpu().data.numpy().squeeze()
fmap = fmap_block[0].cpu().data.numpy().squeeze()
cam = gen_cam(fmap, grads_val)
img_show = np.float32(cv2.resize(img, (64, 64))) / 255
show_cam_on_image(img_show, cam, output_dir)



[个人总结]利用grad-cam实现人民币分类的更多相关文章

  1. 机器学习实战 - 读书笔记(07) - 利用AdaBoost元算法提高分类性能

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习笔记,这次是第7章 - 利用AdaBoost元算法提高分类性能. 核心思想 在使用某个特定的算法是, ...

  2. 【转载】 机器学习实战 - 读书笔记(07) - 利用AdaBoost元算法提高分类性能

    原文地址: https://www.cnblogs.com/steven-yang/p/5686473.html ------------------------------------------- ...

  3. NLP(二十二)利用ALBERT实现文本二分类

      在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...

  4. 利用RNN进行中文文本分类(数据集是复旦中文语料)

    利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) 1.训练词向量 数据预处理参考利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) ,现在我们有了分词 ...

  5. 利用CNN进行中文文本分类(数据集是复旦中文语料)

    利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) 利用RNN进行中文文本分类(数据集是复旦中文语料) 上一节我们利用了RNN(GRU)对中文文本进行了分类,本节我们将继续使用 ...

  6. 利用AdaBoost元算法提高分类性能

    当做重要决定时,大家可能都会吸取多个专家而不只是一个人的意见.机器学习处理问题时又何尝不是如此?这就是元算法背后的思路.元算法是对其他算法进行组合的一种方式. 自举汇聚法(bootstrap aggr ...

  7. 【Python与机器学习】:利用Keras进行多类分类

    多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...

  8. 利用Spark-mllab进行聚类,分类,回归分析的代码实现(python)

    Spark作为一种开源集群计算环境,具有分布式的快速数据处理能力.而Spark中的Mllib定义了各种各样用于机器学习的数据结构以及算法.Python具有Spark的API.需要注意的是,Spark中 ...

  9. 利用logistic回归解决多分类问题

    利用logistic回归解决手写数字识别问题,数据集私聊. from scipy.io import loadmat import numpy as np import pandas as pd im ...

随机推荐

  1. Broken robot CodeForces - 24D (三对角矩阵简化高斯消元+概率dp)

    题意: 有一个N行M列的矩阵,机器人最初位于第i行和第j列.然后,机器人可以在每一步都转到另一个单元.目的是转到最底部(第N个)行.机器人可以停留在当前单元格处,向左移动,向右移动或移动到当前位置下方 ...

  2. 2020 ICPC Asia Taipei-Hsinchu Regional Problem H Optimization for UltraNet (二分,最小生成树,dsu计数)

    题意:给你一张图,要你去边,使其成为一个边数为\(n-1\)的树,同时要求树的最小边权最大,如果最小边权最大的情况有多种,那么要求总边权最小.求生成树后的所有简单路径上的最小边权和. 题解:刚开始想写 ...

  3. hdu5135 Little Zu Chongzhi's Triangles

    Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 512000/512000 K (Java/Others) Total Submissi ...

  4. Linux 查看系统日志 ,查看服务日志

    journalctl 查看系统日志参数 -f 表示日志跟中-u 指定的是 unit 指定要查看的服务日志,如果不指定的话会显示所有服务的日志 journalctl -f -u 要查看的服务日志 jou ...

  5. 自己动手实现springboot运行时执行java源码(运行时编译、加载、注册bean、调用)

    看来断点.单步调试还不够硬核,根本没多少人看,这次再来个硬核的.依然是由于apaas平台越来越流行了,如果apaas平台选择了java语言作为平台内的业务代码,那么不仅仅面临着IDE外的断点.单步调试 ...

  6. Linux入门详解

    Linux基础知识 Linux&Unix 说起Linux,就不得不提Unix操作系统. Unix系统号称世界上最稳定的系统,就连苹果公司也从中获取灵感开发出了移动端大名鼎鼎的IOS. Unix ...

  7. ajax全局

    $.ajaxSetup({ complete: function (xhr) { xhr.promise().done(function (json) { if (json.errorNo == &q ...

  8. vue中怎么动态生成form表单

    form-create 是一个可以通过 JSON 生成具有动态渲染.数据收集.验证和提交功能的表单生成组件.支持3个UI框架,并且支持生成任何 Vue 组件.内置20种常用表单组件和自定义组件,再复杂 ...

  9. Mybatis-02 CRUD

    Mybatis-02 CRUD CRUD 先来简单回顾一下之前的准备步骤: 创建一个数据库,并加入数据 创建一个Maven项目 导入对应的依赖 创建Pojo类和Dao类 写出Mybatis工具类 配置 ...

  10. React tutorial

    https://www.algolia.com Build Unique Search ExperiencesHosted Search API that delivers instant and r ...