import os
import sys
import random
import math
import re
import time
import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image # Root directory of the project
ROOT_DIR = os.path.abspath("../../") # Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log #%matplotlib inline # Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs") # Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH) iter_num=0

  

Configurations

class ShapesConfig(Config):
"""Configuration for training on the toy shapes dataset.
Derives from the base Config class and overrides values specific
to the toy shapes dataset.
"""
# Give the configuration a recognizable name
NAME = "shapes" # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
# GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
GPU_COUNT = 2
IMAGES_PER_GPU = 1 #这里我用了两个GPU # Number of classes (including background)
NUM_CLASSES = 1 + 1 # background + 1 shapes # Use small images for faster training. Set the limits of the small side
# the large side, and that determines the image shape.
IMAGE_MIN_DIM = 1080
IMAGE_MAX_DIM = 1920 # Use smaller anchors because our image and objects are small
RPN_ANCHOR_SCALES = (8*6, 16*6, 32*6, 64*6, 128*6) # anchor side in pixels # Reduce training ROIs per image because the images are small and have
# few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
TRAIN_ROIS_PER_IMAGE = 32 # Use a small epoch since the data is simple
STEPS_PER_EPOCH = 100 # use small validation steps since the epoch is small
VALIDATION_STEPS = 5 config = ShapesConfig()
config.display()

  Notebook Preference

def get_ax(rows=1, cols=1, size=8):
"""Return a Matplotlib Axes array to be used in
all visualizations in the notebook. Provide a
central point to control graph sizes. Change the default size attribute to control the size
of rendered images
"""
_, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
return ax

  Dataset

class DrugDataset(utils.Dataset):

    #得到该图中有多少个实例(物体)
def get_obj_index(self, image):
n = np.max(image)
return n
#解析labelme中得到的yaml文件,从而得到mask每一层对应的实例标签
def from_yaml_get_class(self,image_id):
info=self.image_info[image_id]
with open(info['yaml_path']) as f:
temp=yaml.load(f.read())
labels=temp['label_names']
del labels[0]
return labels
#重新写draw_mask
def draw_mask(self, num_obj, mask, image):
info = self.image_info[image_id]
for index in range(num_obj):
for i in range(info['width']):
for j in range(info['height']):
at_pixel = image.getpixel((i, j))
if at_pixel == index + 1:
mask[j, i, index] =1
return mask
#重新写load_shapes,里面包含自己的自己的类别(我的是box、column、package、fruit四类)
#并在self.image_info信息中添加了path、mask_path 、yaml_path
def load_shapes(self, count, height, width, img_floder, mask_floder, imglist,dataset_root_path):
"""Generate the requested number of synthetic images.
count: number of images to generate.
height, width: the size of the generated images.
"""
# Add classes
self.add_class("shapes", 1, "box") for i in range(count):
filestr = imglist[i].split(".")[0]
filestr = filestr.split("_")[0]
mask_path = mask_floder + "/" + filestr + ".png"
yaml_path=dataset_root_path+filestr+"rgb_"+"_json/info.yaml"
self.add_image("shapes", image_id=i, path=img_floder + "/"+imglist[i],
width=width, height=height, mask_path=mask_path,yaml_path=yaml_path)
#重写load_mask
def load_mask(self, image_id):
"""Generate instance masks for shapes of the given image ID.
"""
global iter_num
info = self.image_info[image_id]
count = 1 # number of object
img = Image.open(info['mask_path'])
num_obj = self.get_obj_index(img)
mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
mask = self.draw_mask(num_obj, mask, img)
occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
for i in range(count - 2, -1, -1):
mask[:, :, i] = mask[:, :, i] * occlusion
occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
labels=[]
labels=self.from_yaml_get_class(image_id)
labels_form=[]
for i in range(len(labels)):
if labels[i].find("box")!=-1:
#print "box"
labels_form.append("box")
#elif labels[i].find("column")!=-1:
#print "column"
# labels_form.append("column")
#elif labels[i].find("package")!=-1:
#print "package"
# labels_form.append("package")
#elif labels[i].find("fruit")!=-1:
#print "fruit"
# labels_form.append("fruit")
class_ids = np.array([self.class_names.index(s) for s in labels_form])
return mask, class_ids.astype(np.int32)

  基础设置

#基础设置
dataset_root_path="/mnt/disk2/zhouqiang/Mask_RCNN/data/train_01_01/"
img_floder = dataset_root_path+"rgb"
mask_floder = dataset_root_path+"mask"
#yaml_floder = dataset_root_path
imglist = os.listdir(img_floder)
count = len(imglist)
width = 1920
height = 1080 #train与val数据集准备
dataset_train = DrugDataset()
dataset_train.load_shapes(count, 1080, 1920, img_floder, mask_floder, imglist,dataset_root_path)
dataset_train.prepare() dataset_val = DrugDataset()
dataset_val.load_shapes(count, 1080, 1920, img_floder, mask_floder, imglist,dataset_root_path)
dataset_val.prepare()

  Create Model

# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config,
model_dir=MODEL_DIR)

 

# Which weights to start with?
init_with = "coco" # imagenet, coco, or last if init_with == "imagenet":
model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
# Load weights trained on MS COCO, but skip layers that
# are different due to the different number of classes
# See README for instructions to download the COCO weights
model.load_weights(COCO_MODEL_PATH, by_name=True,
exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
"mrcnn_bbox", "mrcnn_mask"])
elif init_with == "last":
# Load the last model you trained and continue training
model.load_weights(model.find_last(), by_name=True)

  

# Fine tune all layers
# Passing layers="all" trains all layers. You can also
# pass a regular expression to select which layers to
# train by name pattern.
model.train(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE / 10,
epochs=50,
layers="all")

  

 

MaskRCNN 奔跑自己的数据的更多相关文章

  1. labelme2coco问题:TypeError: Object of type 'int64' is not JSON serializable

    最近在做MaskRCNN 在自己的数据(labelme)转为COCOjson格式遇到问题:TypeError: Object of type 'int64' is not JSON serializa ...

  2. Android动画之仿美团加载数据等待时,小人奔跑进度动画对话框(附顺丰快递员奔跑效果)

    Android动画之仿美团加载数据等待时,小人奔跑进度动画对话框(附顺丰快递员奔跑效果) 首句依然是那句老话,你懂得! finddreams :(http://blog.csdn.net/finddr ...

  3. 奔跑的歌颂 diskgenius 找回了20G数据

    2.0同学家的电脑不慎重装系统,结果默认重新分区.其他倒没什么数据,就是几千张记录孩子成长的照片最为珍贵.为了找回数据,用U盘启动,使用Diskgenius全部找回,在此奔歌一下.

  4. mask-rcnn代码解读(四):rpn_feature_maps数据的处理

    此处模拟 rpn_feature_maps数据的处理,最终得到rpn_class_logits, rpn_class, rpn_bbox. 代码如下: import numpy as np'''层与层 ...

  5. MVC5 网站开发之三 数据存储层功能实现

    数据存储层在项目Ninesky.DataLibrary中实现,整个项目只有一个类Repository.   目录 奔跑吧,代码小哥! MVC5网站开发之一 总体概述 MVC5 网站开发之二 创建项目 ...

  6. C#工业物联网和集成系统解决方案的技术路线(数据源、数据采集、数据上传与接收、ActiveMQ、Mongodb、WebApi、手机App)

    目       录 工业物联网和集成系统解决方案的技术路线... 1 前言... 1 第一章           系统架构... 3 1.1           硬件构架图... 3 1.2      ...

  7. jquery表格动态增删改及取数据绑定数据完整方案

    一 前言 上一篇Jquery遮罩插件,想罩哪就罩哪! 结尾的预告终于来了. 近期参与了一个针对内部员工个人信息收集的系统,其中有一个需求是在填写各个相关信息时,需要能动态的增加行当时公司有自己的解决方 ...

  8. 奔跑的xiaodao

    http://acm.hrbust.edu.cn/index.php?m=ProblemSet&a=showProblem&problem_id=2086 很明显的一个二分题目.因为要 ...

  9. 奔跑在Docker上的Spark

    转自:马踏飞燕--奔跑在Docker上的Spark 目录 为什么要在Docker上搭建Spark集群 网络拓扑 Docker安装及配置 ssh安装及配置 基础环境安装 Zookeeper安装及配置 H ...

随机推荐

  1. php-fpm 参数调优

    php-fpm 进程池优化方法 php-fpm进程池开启进程有两种方式,一种是static,直接开启指定数量的php-fpm进程,不再增加或者减少:另一种则是dynamic,开始时开启一定数量的php ...

  2. 【洛谷】P4139 上帝与集合的正确用法

    题目描述 根据一些书上的记载,上帝的一次失败的创世经历是这样的:  第一天,上帝创造了一个世界的基本元素,称做“元”.  第二天,上帝创造了一个新的元素,称作“α”.“α”被定义为“元”构成的集合.容 ...

  3. ubuntu之路——day11.6 多任务学习

    在迁移学习transfer learning中,你的步骤是串行的sequential process 在多任务学习multi-task learning中,你试图让单个神经网络同时做几件事情,然后这里 ...

  4. selenium之 下拉选择框Select

    今天总结下selenium的下拉选择框.我们通常会遇到两种下拉框,一种使用的是html的标签select,另一种是使用input标签做的假下拉框. 后者我们通常的处理方式与其他的元素类似,点击或使用J ...

  5. 常用的maven仓库地址

    maven 仓库地址: 共有的仓库http://repo1.maven.org/maven2/http://repository.jboss.com/maven2/http://repository. ...

  6. 关于 array of const

    之前应该参考一下: 关于开放数组参数 //这是在 System 单元定义的一组标识数据类型的常量: vtInteger    = ; vtBoolean    = ; vtChar      = ; ...

  7. unity疯狂牧场完整项目源码 - Frenzy Farming time management game kit V1.0

    You will love this game kit! Have you ever wondered what it would be like to run your own farm? Look ...

  8. strace命令 二

    让我们看一台高负载服务器的 top 结果: top 技巧:运行 top 时,按「1」打开 CPU 列表,按「shift+p」以 CPU 排序. 在本例中大家很容易发现 CPU 主要是被若干个 PHP ...

  9. linux 查看gpu信息

  10. Linux 操作系统 & High Tech

    分享10大白帽黑客专用的 Linux 操作系统 - 51CTO.COMhttp://os.51cto.com/art/201905/597156.htm Ubuntu 创始人谈论为什么 Linux 在 ...