参考:https://github.com/milesial/Pytorch-UNet

实现的是二值汽车图像语义分割,包括 dense CRF 后处理.

使用python3,我的环境是python3.6

1.使用

1> 预测

1)查看所有的可用选项:

  1. python predict.py -h

返回:

  1. (deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py -h
  2. usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
  3. [--output INPUT [INPUT ...]] [--cpu] [--viz] [--no-save]
  4. [--no-crf] [--mask-threshold MASK_THRESHOLD] [--scale SCALE]
  5.  
  6. optional arguments:
  7. -h, --help show this help message and exit
  8. --model FILE, -m FILE
  9. Specify the file in which is stored the model (default
  10. : 'MODEL.pth')  #指明使用的训练好的模型文件,默认使用MODEL.pth
  11. --input INPUT [INPUT ...], -i INPUT [INPUT ...] #指明要进行预测的图像文件,必须要有的值
  12. filenames of input images
  13. --output INPUT [INPUT ...], -o INPUT [INPUT ...] #指明预测后生成的图像文件的名字
  14. filenames of ouput images
  15. --cpu, -c Do not use the cuda version of the net #指明使用CPU,默认为false,即默认使用GPU
  16. --viz, -v Visualize the images as they are processed #当图像被处理时,将其可视化,默认为false,即不可以可视化
  17. --no-save, -n Do not save the output masks #不存储得到的预测图像到某图像文件中,和--viz结合使用,即可对预测结果可视化,但是不存储结果,默认为false,即会保存结果
  18. --no-crf, -r Do not use dense CRF postprocessing #指明不使用CRF对输出进行后处理,默认为false,即使用CRF
  19. --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
  20. Minimum probability value to consider a mask pixel #最小化考虑掩模像素为白色的概率值,默认为0.5
  21. white
  22. --scale SCALE, -s SCALE
  23. Scale factor for the input images #输入图像的比例因子,默认为0.5

2)预测单一图片image.jpg并存储结果到output.jpg的命令

  1. python predict.py -i image.jpg -o output.jpg

测试一下:

  1. (deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py --cpu --viz -i image.jpg -o output.jpg
  2. Loading model MODEL.pth
  3. Using CPU version of the net, this may be very slow
  4. Model loaded !
  5.  
  6. Predicting image image.jpg ...
  7. /anaconda3/envs/deeplearning/lib/python3./site-packages/torch/nn/modules/upsampling.py:: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  8. warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
  9. /anaconda3/envs/deeplearning/lib/python3./site-packages/torch/nn/functional.py:: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  10. warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
  11. Visualizing results for image image.jpg, close to continue ...

返回可视化图片为:

关闭该可视化图片命令就会运行结束:

  1. Mask saved to output.jpg
  2. (deeplearning) userdeMBP:Pytorch-UNet-master user$

并且在当前文件夹中生成名为output.jpg的文件,该图为:

3)预测多张图片并显示,预测结果不存储:

  1. python predict.py -i image1.jpg image2.jpg --viz --no-save

测试:

先得到的是image1.jpg的可视化结果:

  1. (deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py -i image1.jpg image2.jpg --viz --no-save --cpu
  2. Loading model MODEL.pth
  3. Using CPU version of the net, this may be very slow
  4. Model loaded !
  5.  
  6. Predicting image image1.jpg ...
  7. /anaconda3/envs/deeplearning/lib/python3./site-packages/torch/nn/modules/upsampling.py:: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  8. warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))
  9. /anaconda3/envs/deeplearning/lib/python3./site-packages/torch/nn/functional.py:: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  10. warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
  11. Visualizing results for image image1.jpg, close to continue ...

图为:

关闭这个后就会接着生成image2.jpg的可视化结果:

  1. Predicting image image2.jpg ...
  2. Visualizing results for image image2.jpg, close to continue ...

返回图为:

这时候关闭该可视化服务就会结束了,并且没有在本地保存生成的图片

4)如果你的计算机只有CPU,即CPU-only版本,使用选项--cpu指定

5)你可以指定你使用的训练好的模型文件,使用--mode MODEL.pth

6)如果使用上面的命令选项--no-crf:

  1. (deeplearning) userdeMBP:Pytorch-UNet-master user$ python predict.py -i image1.jpg image2.jpg --viz --no-save --cpu --no-crf

返回的结果是:

还有:

可见crf后处理后,可以将一些不符合事实的判断结果给剔除,使得结果更加精确

2〉训练

  1. python train.py -h

首先需要安装模块pydensecrf,实现CRF条件随机场的模块:

  1. pip install pydensecrf
  2. 但是出错:
    pydensecrf/densecrf/include/Eigen/Core:22:10: fatal error: 'complex' file not found
  3. #include <complex>
  4. ^~~~~~~~~
  5. warning and error generated.
  6. error: command 'gcc' failed with exit status
  7.  
  8. ----------------------------------------
  9. Failed building wheel for pydensecrf
  10. Running setup.py clean for pydensecrf
  11. Failed to build pydensecrf

解决办法,参考https://github.com/lucasb-eyer/pydensecrf:

先安装cython,需要0.22以上的版本:

  1. (deeplearning) userdeMBP:Pytorch-UNet-master user$ pip install -U cython
  2. Installing collected packages: cython
  3. Successfully installed cython-0.29.

然后从git安装最新版本:

  1. pip install git+https://github.com/lucasb-eyer/pydensecrf.git

但还是没有成功

后面找到了新的方法,使用conda来安装就成功了:

  1. userdeMacBook-Pro:~ user$ conda install -n deeplearning -c conda-forge pydensecrf

-c指明从conda-forge下载模块

conda-forge是可以安装软件包的附加渠道,使用该conda-forge频道取代defaults

因为直接安装conda install -n deeplearning pydensecrf找不到该模块

这时候运行python train.py -h可见支持的选项的信息:

  1. (deeplearning) userdeMBP:Pytorch-UNet-master user$ python train.py -h
  2. Usage: train.py [options]
  3.  
  4. Options:
  5. -h, --help show this help message and exit
  6. -e EPOCHS, --epochs=EPOCHS
  7. number of epochs #指明迭代的次数
  8. -b BATCHSIZE, --batch-size=BATCHSIZE
  9. batch size #图像批处理的大小
  10. -l LR, --learning-rate=LR
  11. learning rate #使用的学习率
  12. -g, --gpu use cuda #使用GPU进行训练
  13. -c LOAD, --load=LOAD load file model #下载预训练的文件,在该基础上进行训练
  14. -s SCALE, --scale=SCALE
  15. downscaling factor of the images #图像的缩小因子

3>代码分析

1》unet定义网络

unet/unet_parts.py

  1. # sub-parts of the U-Net model
  2.  
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6.  
  7. #实现左边的横向卷积
  8. class double_conv(nn.Module):
  9. '''(conv => BN => ReLU) * 2'''
  10. def __init__(self, in_ch, out_ch):
  11. super(double_conv, self).__init__()
  12. self.conv = nn.Sequential(
  13. #以第一层为例进行讲解
  14. #输入通道数in_ch,输出通道数out_ch,卷积核设为kernal_size *,padding为1,stride为1,dilation=
  15. #所以图中H*W能从572* 变为 *,计算为570 = (( + *padding - dilation*(kernal_size-) -) / stride ) +
  16. nn.Conv2d(in_ch, out_ch, , padding=),
  17. nn.BatchNorm2d(out_ch), #进行批标准化,在训练时,该层计算每次输入的均值与方差,并进行移动平均
  18. nn.ReLU(inplace=True), #激活函数
  19. nn.Conv2d(out_ch, out_ch, , padding=), #再进行一次卷积,从570*570变为 *
  20. nn.BatchNorm2d(out_ch),
  21. nn.ReLU(inplace=True)
  22. )
  23.  
  24. def forward(self, x):
  25. x = self.conv(x)
  26. return x
  27.  
  28. #实现左边第一行的卷积
  29. class inconv(nn.Module):#
  30. def __init__(self, in_ch, out_ch):
  31. super(inconv, self).__init__()
  32. self.conv = double_conv(in_ch, out_ch) # 输入通道数in_ch为3, 输出通道数out_ch为64
  33.  
  34. def forward(self, x):
  35. x = self.conv(x)
  36. return x
  37.  
  38. #实现左边的向下池化操作,并完成另一层的卷积
  39. class down(nn.Module):
  40. def __init__(self, in_ch, out_ch):
  41. super(down, self).__init__()
  42. self.mpconv = nn.Sequential(
  43. nn.MaxPool2d(),
  44. double_conv(in_ch, out_ch)
  45. )
  46.  
  47. def forward(self, x):
  48. x = self.mpconv(x)
  49. return x
  50.  
  51. #实现右边的向上的采样操作,并完成该层相应的卷积操作
  52. class up(nn.Module):
  53. def __init__(self, in_ch, out_ch, bilinear=True):
  54. super(up, self).__init__()
  55.  
  56. # would be a nice idea if the upsampling could be learned too,
  57. # but my machine do not have enough memory to handle all those weights
  58. if bilinear:#声明使用的上采样方法为bilinear——双线性插值,默认使用这个值,计算方法为 floor(H*scale_factor),所以由28*28变为56*
  59. self.up = nn.Upsample(scale_factor=, mode='bilinear', align_corners=True)
  60. else: #否则就使用转置卷积来实现上采样,计算式子为 (Height-)*stride - *padding -kernal_size +output_padding
  61. self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
  62.  
  63. self.conv = double_conv(in_ch, out_ch)
  64.  
  65. def forward(self, x1, x2): #x2是左边特征提取传来的值
  66. #第一次上采样返回56*,但是还没结束
  67. x1 = self.up(x1)
  68.  
  69. # input is CHW, []是batch_size, []是通道数,更改了下,与源码不同
  70. diffY = x1.size()[] - x2.size()[] #得到图像x2与x1的H的差值,-=-
  71. diffX = x1.size()[] - x2.size()[] #得到图像x2与x1的W差值,-=-
  72.  
  73. #用第一次上采样为例,即当上采样后的结果大小与右边的特征的结果大小不同时,通过填充来使x2的大小与x1相同
  74. #对图像进行填充(-,-,-,-),左右上下都缩小4,所以最后使得64*64变为56*
  75. x2 = F.pad(x2, (diffX // 2, diffX - diffX//2,
  76. diffY // 2, diffY - diffY//2))
  77.  
  78. # for padding issues, see
  79. # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
  80. # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
  81.  
  82. #将最后上采样得到的值x1和左边特征提取的值进行拼接,dim=1即在通道数上进行拼接,由512变为1024
  83. x = torch.cat([x2, x1], dim=)
  84. x = self.conv(x)
  85. return x
  86.  
  87. #实现右边的最高层的最右边的卷积
  88. class outconv(nn.Module):
  89. def __init__(self, in_ch, out_ch):
  90. super(outconv, self).__init__()
  91. self.conv = nn.Conv2d(in_ch, out_ch, )
  92.  
  93. def forward(self, x):
  94. x = self.conv(x)
  95. return x

unet/unetmodel.py

  1. # full assembly of the sub-parts to form the complete net
  2.  
  3. import torch.nn.functional as F
  4.  
  5. from .unet_parts import *
  6.  
  7. class UNet(nn.Module):
  8. def __init__(self, n_channels, n_classes): #图片的通道数,1为灰度图像,3为彩色图像
  9. super(UNet, self).__init__()
  10. self.inc = inconv(n_channels, ) #假设输入通道数n_channels为3,输出通道数为64
  11. self.down1 = down(, )
  12. self.down2 = down(, )
  13. self.down3 = down(, )
  14. self.down4 = down(, )
  15. self.up1 = up(, )
  16. self.up2 = up(, )
  17. self.up3 = up(, )
  18. self.up4 = up(, )
  19. self.outc = outconv(, n_classes)
  20.  
  21. def forward(self, x):
  22. x1 = self.inc(x)
  23. x2 = self.down1(x1)
  24. x3 = self.down2(x2)
  25. x4 = self.down3(x3)
  26. x5 = self.down4(x4)
  27. x = self.up1(x5, x4)
  28. x = self.up2(x, x3)
  29. x = self.up3(x, x2)
  30. x = self.up4(x, x1)
  31. x = self.outc(x)
  32. return F.sigmoid(x) #进行二分类

2》utils

实现dense CRF的代码utils/crf.py:

详细可见pydensecrf的使用

  1. #coding:utf-
  2. import numpy as np
  3. import pydensecrf.densecrf as dcrf
  4.  
  5. def dense_crf(img, output_probs): #img为输入的图像,output_probs是经过网络预测后得到的结果
  6. h = output_probs.shape[] #高度
  7. w = output_probs.shape[] #宽度
  8.  
  9. output_probs = np.expand_dims(output_probs, )
  10. output_probs = np.append( - output_probs, output_probs, axis=)
  11.  
  12. d = dcrf.DenseCRF2D(w, h, ) #NLABELS=2两类标注,车和不是车
  13. U = -np.log(output_probs) #得到一元势
  14. U = U.reshape((, -)) #NLABELS=2两类标注
  15. U = np.ascontiguousarray(U) #返回一个地址连续的数组
  16. img = np.ascontiguousarray(img)
  17.  
  18. d.setUnaryEnergy(U) #设置一元势
  19.  
  20. d.addPairwiseGaussian(sxy=, compat=) #设置二元势中高斯情况的值
  21. d.addPairwiseBilateral(sxy=, srgb=, rgbim=img, compat=)#设置二元势众双边情况的值
  22.  
  23. Q = d.inference() #迭代5次推理
  24. Q = np.argmax(np.array(Q), axis=).reshape((h, w)) #得列中最大值的索引结果
  25.  
  26. return Q

utils/utils.py

  1. import random
  2. import numpy as np
  3.  
  4. #将图像分成左右两块
  5. def get_square(img, pos):
  6. """Extract a left or a right square from ndarray shape : (H, W, C))"""
  7. h = img.shape[]
  8. if pos == :
  9. return img[:, :h]
  10. else:
  11. return img[:, -h:]
  12.  
  13. def split_img_into_squares(img):
  14. return get_square(img, ), get_square(img, )
  15.  
  16. #对图像进行转置,将(H, W, C)变为(C, H, W)
  17. def hwc_to_chw(img):
  18. return np.transpose(img, axes=[, , ])
  19.  
  20. def resize_and_crop(pilimg, scale=0.5, final_height=None):
  21. w = pilimg.size[] #得到图片的宽
  22. h = pilimg.size[]#得到图片的高
  23. #默认scale为0.,即将高和宽都缩小一半
  24. newW = int(w * scale)
  25. newH = int(h * scale)
  26.  
  27. #如果没有指明希望得到的最终高度
  28. if not final_height:
  29. diff =
  30. else:
  31. diff = newH - final_height
  32. #重新设定图片的大小
  33. img = pilimg.resize((newW, newH))
  34. #crop((left,upper,right,lower))函数,从图像中提取出某个矩形大小的图像。它接收一个四元素的元组作为参数,各元素为(left, upper, right, lower),坐标系统的原点(, )是左上角
  35. #如果没有设置final_height,其实就是取整个图片
  36. #如果设置了final_height,就是取一个上下切掉diff // 2,最后高度为final_height的图片
  37. img = img.crop((, diff // 2, newW, newH - diff // 2))
  38. return np.array(img, dtype=np.float32)
  39.  
  40. def batch(iterable, batch_size):
  41. """批量处理列表"""
  42. b = []
  43. for i, t in enumerate(iterable):
  44. b.append(t)
  45. if (i + ) % batch_size == :
  46. yield b
  47. b = []
  48.  
  49. if len(b) > :
  50. yield b
  51.  
  52. #然后将数据分为训练集和验证集两份
  53. def split_train_val(dataset, val_percent=0.05):
  54. dataset = list(dataset)
  55. length = len(dataset) #得到数据集大小
  56. n = int(length * val_percent) #验证集的数量
  57. random.shuffle(dataset) #将数据打乱
  58. return {'train': dataset[:-n], 'val': dataset[-n:]}
  59.  
  60. #对像素值进行归一化,由[,]变为[,]
  61. def normalize(x):
  62. return x /
  63.  
  64. #将两个图片合并起来
  65. def merge_masks(img1, img2, full_w):
  66. h = img1.shape[]
  67.  
  68. new = np.zeros((h, full_w), np.float32)
  69. new[:, :full_w // 2 + 1] = img1[:, :full_w // 2 + 1]
  70. new[:, full_w // 2 + 1:] = img2[:, -(full_w // 2 - 1):]
  71.  
  72. return new
  73.  
  74. # credits to https://stackoverflow.com/users/6076729/manuel-lagunas
  75. def rle_encode(mask_image):
  76. pixels = mask_image.flatten()
  77. # We avoid issues with '' at the start or end (at the corners of
  78. # the original image) by setting those pixels to '' explicitly.
  79. # We do not expect these to be non-zero for an accurate mask,
  80. # so this should not harm the score.
  81. pixels[] =
  82. pixels[-] =
  83. runs = np.where(pixels[:] != pixels[:-])[] +
  84. runs[::] = runs[::] - runs[:-:]
  85. return runs

utils/data_vis.py实现结果的可视化:

  1. import matplotlib.pyplot as plt
  2.  
  3. def plot_img_and_mask(img, mask):
  4. fig = plt.figure()
  5. a = fig.add_subplot(, , ) #先是打印输入的图片
  6. a.set_title('Input image')
  7. plt.imshow(img)
  8.  
  9. b = fig.add_subplot(, , ) #然后打印预测得到的结果图片
  10. b.set_title('Output mask')
  11. plt.imshow(mask)
  12. plt.show()

utils/load.py

  1. #
  2. # load.py : utils on generators / lists of ids to transform from strings to
  3. # cropped images and masks
  4.  
  5. import os
  6.  
  7. import numpy as np
  8. from PIL import Image
  9.  
  10. from .utils import resize_and_crop, get_square, normalize, hwc_to_chw
  11.  
  12. def get_ids(dir):
  13. """返回目录中的id列表"""
  14. return (f[:-] for f in os.listdir(dir)) #图片名字的后4位为数字,能作为图片id
  15.  
  16. def split_ids(ids, n=):
  17. """将每个id拆分为n个,为每个id创建n个元组(id, k)"""
  18. #等价于for id in ids:
  19. # for i in range(n):
  20. # (id, i)
  21. #得到元祖列表[(id1,),(id1,),(id2,),(id2,),...,(idn,),(idn,)]
  22. #这样的作用是后面会通过后面的0,1作为utils.py中get_square函数的pos参数,pos=0的取左边的部分,pos=1的取右边的部分
  23. return ((id, i) for id in ids for i in range(n))
  24.  
  25. def to_cropped_imgs(ids, dir, suffix, scale):
  26. """从元组列表中返回经过剪裁的正确img"""
  27. for id, pos in ids:
  28. im = resize_and_crop(Image.open(dir + id + suffix), scale=scale) #重新设置图片大小为原来的scale倍
  29. yield get_square(im, pos) #然后根据pos选择图片的左边或右边
  30.  
  31. def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
  32. """返回所有组(img, mask)"""
  33.  
  34. imgs = to_cropped_imgs(ids, dir_img, '.jpg', scale)
  35.  
  36. # need to transform from HWC to CHW
  37. imgs_switched = map(hwc_to_chw, imgs) #对图像进行转置,将(H, W, C)变为(C, H, W)
  38. imgs_normalized = map(normalize, imgs_switched) #对像素值进行归一化,由[,]变为[,]
  39.  
  40. masks = to_cropped_imgs(ids, dir_mask, '_mask.gif', scale) #对图像的结果也进行相同的处理
  41.  
  42. return zip(imgs_normalized, masks) #并将两个结果打包在一起
  43.  
  44. def get_full_img_and_mask(id, dir_img, dir_mask):
  45. im = Image.open(dir_img + id + '.jpg')
  46. mask = Image.open(dir_mask + id + '_mask.gif')
  47. return np.array(im), np.array(mask)

3》预测

predict.py使用训练好的U-net网络对图像进行预测,使用dense CRF进行后处理:

  1. #coding:utf-
  2. import argparse
  3. import os
  4.  
  5. import numpy as np
  6. import torch
  7. import torch.nn.functional as F
  8.  
  9. from PIL import Image
  10.  
  11. from unet import UNet
  12. from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf
  13. from utils import plot_img_and_mask
  14.  
  15. from torchvision import transforms
  16.  
  17. def predict_img(net,
  18. full_img,
  19. scale_factor=0.5,
  20. out_threshold=0.5,
  21. use_dense_crf=True,
  22. use_gpu=False):
  23.  
  24. net.eval() #进入网络的验证模式,这时网络已经训练好了
  25. img_height = full_img.size[] #得到图片的高
  26. img_width = full_img.size[] #得到图片的宽
  27.  
  28. img = resize_and_crop(full_img, scale=scale_factor) #在utils文件夹的utils.py中定义的函数,重新定义图像大小并进行切割,然后将图像转为数组np.array
  29. img = normalize(img) #对像素值进行归一化,由[,]变为[,]
  30.  
  31. left_square, right_square = split_img_into_squares(img)#将图像分成左右两块,来分别进行判断
  32.  
  33. left_square = hwc_to_chw(left_square) #对图像进行转置,将(H, W, C)变为(C, H, W),便于后面计算
  34. right_square = hwc_to_chw(right_square)
  35.  
  36. X_left = torch.from_numpy(left_square).unsqueeze() #将(C, H, W)变为(, C, H, W),因为网络中的输入格式第一个还有一个batch_size的值
  37. X_right = torch.from_numpy(right_square).unsqueeze()
  38.  
  39. if use_gpu:
  40. X_left = X_left.cuda()
  41. X_right = X_right.cuda()
  42.  
  43. with torch.no_grad(): #不计算梯度
  44. output_left = net(X_left)
  45. output_right = net(X_right)
  46.  
  47. left_probs = output_left.squeeze()
  48. right_probs = output_right.squeeze()
  49.  
  50. tf = transforms.Compose(
  51. [
  52. transforms.ToPILImage(), #重新变成图片
  53. transforms.Resize(img_height), #恢复原来的大小
  54. transforms.ToTensor() #然后再变成Tensor格式
  55. ]
  56. )
  57.  
  58. left_probs = tf(left_probs.cpu())
  59. right_probs = tf(right_probs.cpu())
  60.  
  61. left_mask_np = left_probs.squeeze().cpu().numpy()
  62. right_mask_np = right_probs.squeeze().cpu().numpy()
  63.  
  64. full_mask = merge_masks(left_mask_np, right_mask_np, img_width)#将左右两个拆分后的图片合并起来
  65.  
  66. #对得到的结果根据设置决定是否进行CRF处理
  67. if use_dense_crf:
  68. full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask)
  69.  
  70. return full_mask > out_threshold
  71.  
  72. def get_args():
  73. parser = argparse.ArgumentParser()
  74. parser.add_argument('--model', '-m', default='MODEL.pth', #指明使用的训练好的模型文件,默认使用MODEL.pth
  75. metavar='FILE',
  76. help="Specify the file in which is stored the model"
  77. " (default : 'MODEL.pth')")
  78. parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', #指明要进行预测的图像文件
  79. help='filenames of input images', required=True)
  80.  
  81. parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', #指明预测后生成的图像文件的名字
  82. help='filenames of ouput images')
  83. parser.add_argument('--cpu', '-c', action='store_true', #指明使用CPU
  84. help="Do not use the cuda version of the net",
  85. default=False)
  86. parser.add_argument('--viz', '-v', action='store_true',
  87. help="Visualize the images as they are processed", #当图像被处理时,将其可视化
  88. default=False)
  89. parser.add_argument('--no-save', '-n', action='store_true', #不存储得到的预测图像到某图像文件中,和--viz结合使用,即可对预测结果可视化,但是不存储结果
  90. help="Do not save the output masks",
  91. default=False)
  92. parser.add_argument('--no-crf', '-r', action='store_true', #指明不使用CRF对输出进行后处理
  93. help="Do not use dense CRF postprocessing",
  94. default=False)
  95. parser.add_argument('--mask-threshold', '-t', type=float,
  96. help="Minimum probability value to consider a mask pixel white", #最小概率值考虑掩模像素为白色
  97. default=0.5)
  98. parser.add_argument('--scale', '-s', type=float,
  99. help="Scale factor for the input images", #输入图像的比例因子
  100. default=0.5)
  101.  
  102. return parser.parse_args()
  103.  
  104. def get_output_filenames(args):#从输入的选项args值中得到输出文件名
  105. in_files = args.input
  106. out_files = []
  107.  
  108. if not args.output: #如果在选项中没有指定输出的图片文件的名字,那么就会根据输入图片文件名,在其后面添加'_OUT'后缀来作为输出图片文件名
  109. for f in in_files:
  110. pathsplit = os.path.splitext(f) #将文件名和扩展名分开,pathsplit[]是文件名,pathsplit[]是扩展名
  111. out_files.append("{}_OUT{}".format(pathsplit[], pathsplit[])) #得到输出图片文件名
  112. elif len(in_files) != len(args.output): #如果设置了output名,查看input和output的数量是否相同,即如果input是两张图,那么设置的output也必须是两个,否则报错
  113. print("Error : Input files and output files are not of the same length")
  114. raise SystemExit()
  115. else:
  116. out_files = args.output
  117.  
  118. return out_files
  119.  
  120. def mask_to_image(mask):
  121. return Image.fromarray((mask * ).astype(np.uint8)) #从数组array转成Image
  122.  
  123. if __name__ == "__main__":
  124. args = get_args() #得到输入的选项设置的值
  125. in_files = args.input #得到输入的图像文件
  126. out_files = get_output_filenames(args) #从输入的选项args值中得到输出文件名
  127.  
  128. net = UNet(n_channels=, n_classes=) #定义使用的model为UNet,调用在UNet文件夹下定义的unet_model.py,定义图像的通道为3,即彩色图像,判断类型设为1种
  129.  
  130. print("Loading model {}".format(args.model)) #指定使用的训练好的model
  131.  
  132. if not args.cpu: #指明使用GPU
  133. print("Using CUDA version of the net, prepare your GPU !")
  134. net.cuda()
  135. net.load_state_dict(torch.load(args.model))
  136. else: #否则使用CPU
  137. net.cpu()
  138. net.load_state_dict(torch.load(args.model, map_location='cpu'))
  139. print("Using CPU version of the net, this may be very slow")
  140.  
  141. print("Model loaded !")
  142.  
  143. for i, fn in enumerate(in_files): #对图片进行预测
  144. print("\nPredicting image {} ...".format(fn))
  145.  
  146. img = Image.open(fn)
  147. if img.size[] < img.size[]: #(W, H, C)
  148. print("Error: image height larger than the width")
  149.  
  150. mask = predict_img(net=net,
  151. full_img=img,
  152. scale_factor=args.scale,
  153. out_threshold=args.mask_threshold,
  154. use_dense_crf= not args.no_crf,
  155. use_gpu=not args.cpu)
  156.  
  157. if args.viz: #可视化输入的图片和生成的预测图片
  158. print("Visualizing results for image {}, close to continue ...".format(fn))
  159. plot_img_and_mask(img, mask)
  160.  
  161. if not args.no_save:#设置为False,则保存
  162. out_fn = out_files[i]
  163. result = mask_to_image(mask) #从数组array转成Image
  164. result.save(out_files[i]) #然后保存
  165.  
  166. print("Mask saved to {}".format(out_files[i]))

4》训练

  1. import sys
  2. import os
  3. from optparse import OptionParser
  4. import numpy as np
  5.  
  6. import torch
  7. import torch.backends.cudnn as cudnn
  8. import torch.nn as nn
  9. from torch import optim
  10.  
  11. from eval import eval_net
  12. from unet import UNet
  13. from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch
  14.  
  15. def train_net(net,
  16. epochs=,
  17. batch_size=,
  18. lr=0.1,
  19. val_percent=0.05,
  20. save_cp=True,
  21. gpu=False,
  22. img_scale=0.5):
  23.  
  24. dir_img = 'data/train/' #训练图像文件夹
  25. dir_mask = 'data/train_masks/' #图像的结果文件夹
  26. dir_checkpoint = 'checkpoints/' #训练好的网络保存文件夹
  27.  
  28. ids = get_ids(dir_img)#图片名字的后4位为数字,能作为图片id
  29.  
  30. #得到元祖列表为[(id1,),(id1,),(id2,),(id2,),...,(idn,),(idn,)]
  31. #这样的作用是后面重新设置生成器时会通过后面的0,1作为utils.py中get_square函数的pos参数,pos=0的取左边的部分,pos=1的取右边的部分
  32. #这样图片的数量就会变成2倍
  33. ids = split_ids(ids)
  34.  
  35. iddataset = split_train_val(ids, val_percent) #将数据分为训练集和验证集两份
  36.  
  37. print('''
  38. Starting training:
  39. Epochs: {}
  40. Batch size: {}
  41. Learning rate: {}
  42. Training size: {}
  43. Validation size: {}
  44. Checkpoints: {}
  45. CUDA: {}
  46. '''.format(epochs, batch_size, lr, len(iddataset['train']),
  47. len(iddataset['val']), str(save_cp), str(gpu)))
  48.  
  49. N_train = len(iddataset['train']) #训练集长度
  50.  
  51. optimizer = optim.SGD(net.parameters(), #定义优化器
  52. lr=lr,
  53. momentum=0.9,
  54. weight_decay=0.0005)
  55.  
  56. criterion = nn.BCELoss()#损失函数
  57.  
  58. for epoch in range(epochs): #开始训练
  59. print('Starting epoch {}/{}.'.format(epoch + , epochs))
  60. net.train() #设置为训练模式
  61.  
  62. # reset the generators重新设置生成器
  63. # 对输入图片dir_img和结果图片dir_mask进行相同的图片处理,即缩小、裁剪、转置、归一化后,将两个结合在一起,返回(imgs_normalized, masks)
  64. train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
  65. val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)
  66.  
  67. epoch_loss =
  68.  
  69. for i, b in enumerate(batch(train, batch_size)):
  70. imgs = np.array([i[] for i in b]).astype(np.float32) #得到输入图像数据
  71. true_masks = np.array([i[] for i in b]) #得到图像结果数据
  72.  
  73. imgs = torch.from_numpy(imgs)
  74. true_masks = torch.from_numpy(true_masks)
  75.  
  76. if gpu:
  77. imgs = imgs.cuda()
  78. true_masks = true_masks.cuda()
  79.  
  80. masks_pred = net(imgs) #图像输入的网络后得到结果masks_pred,结果为灰度图像
  81. masks_probs_flat = masks_pred.view(-) #将结果压扁
  82.  
  83. true_masks_flat = true_masks.view(-)
  84.  
  85. loss = criterion(masks_probs_flat, true_masks_flat) #对两个结果计算损失
  86. epoch_loss += loss.item()
  87.  
  88. print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))
  89.  
  90. optimizer.zero_grad()
  91. loss.backward()
  92. optimizer.step()
  93.  
  94. print('Epoch finished ! Loss: {}'.format(epoch_loss / i)) #一次迭代后得到的平均损失
  95.  
  96. if :
  97. val_dice = eval_net(net, val, gpu)
  98. print('Validation Dice Coeff: {}'.format(val_dice))
  99.  
  100. if save_cp:
  101. torch.save(net.state_dict(),
  102. dir_checkpoint + 'CP{}.pth'.format(epoch + ))
  103. print('Checkpoint {} saved !'.format(epoch + ))
  104.  
  105. def get_args():
  106. parser = OptionParser()
  107. parser.add_option('-e', '--epochs', dest='epochs', default=, type='int', #设置迭代数
  108. help='number of epochs')
  109. parser.add_option('-b', '--batch-size', dest='batchsize', default=, #设置训练批处理数
  110. type='int', help='batch size')
  111. parser.add_option('-l', '--learning-rate', dest='lr', default=0.1, #设置学习率
  112. type='float', help='learning rate')
  113. parser.add_option('-g', '--gpu', action='store_true', dest='gpu', #是否使用GPU,默认是不使用
  114. default=False, help='use cuda')
  115. parser.add_option('-c', '--load', dest='load', #下载之前预训练好的模型
  116. default=False, help='load file model')
  117. parser.add_option('-s', '--scale', dest='scale', type='float', #图像的缩小因子,用来重新设置图片大小
  118. default=0.5, help='downscaling factor of the images')
  119.  
  120. (options, args) = parser.parse_args()
  121. return options
  122.  
  123. if __name__ == '__main__':
  124. args = get_args() #得到设置的所有参数信息
  125.  
  126. net = UNet(n_channels=, n_classes=)
  127.  
  128. if args.load: #是否加载预先训练好的模型
  129. net.load_state_dict(torch.load(args.load))
  130. print('Model loaded from {}'.format(args.load))
  131.  
  132. if args.gpu: #是否使用GPU,设置为True,则使用
  133. net.cuda()
  134. # cudnn.benchmark = True # faster convolutions, but more memory
  135.  
  136. try: #开始训练
  137. train_net(net=net,
  138. epochs=args.epochs,
  139. batch_size=args.batchsize,
  140. lr=args.lr,
  141. gpu=args.gpu,
  142. img_scale=args.scale)
  143. except KeyboardInterrupt: #如果键盘输入ctrl+c停止,则会将结果保存在INTERRUPTED.pth中
  144. torch.save(net.state_dict(), 'INTERRUPTED.pth')
  145. print('Saved interrupt')
  146. try:
  147. sys.exit()
  148. except SystemExit:
  149. os._exit()

Pytorch实现UNet例子学习的更多相关文章

  1. pytorch例子学习-DATA LOADING AND PROCESSING TUTORIAL

    参考:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html DATA LOADING AND PROCESSING TUT ...

  2. 深度学习框架PyTorch一书的学习-第五章-常用工具模块

    https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...

  3. 深度学习框架PyTorch一书的学习-第一/二章

    参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 pytorch的设计遵循tensor- ...

  4. PyTorch如何构建深度学习模型?

    简介 每过一段时间,就会有一个深度学习库被开发,这些深度学习库往往可以改变深度学习领域的景观.Pytorch就是这样一个库. 在过去的一段时间里,我研究了Pytorch,我惊叹于它的操作简易.Pyto ...

  5. 数百个 HTML5 例子学习 HT 图形组件 – 3D建模篇

    http://www.hightopo.com/demo/pipeline/index.html <数百个 HTML5 例子学习 HT 图形组件 – WebGL 3D 篇>里提到 HT 很 ...

  6. 数百个 HTML5 例子学习 HT 图形组件 – 3D 建模篇

    http://www.hightopo.com/demo/pipeline/index.html <数百个 HTML5 例子学习 HT 图形组件 – WebGL 3D 篇>里提到 HT 很 ...

  7. 数百个 HTML5 例子学习 HT 图形组件 – WebGL 3D 篇

    <数百个 HTML5 例子学习 HT 图形组件 – 拓扑图篇>一文让读者了解了 HT的 2D 拓扑图组件使用,本文将对 HT 的 3D 功能做个综合性的介绍,以便初学者可快速上手使用 HT ...

  8. 数百个 HTML5 例子学习 HT 图形组件 – 拓扑图篇

    HT 是啥:Everything you need to create cutting-edge 2D and 3D visualization. 这口号是当年心目中的产品方向,接着就朝这个方向慢慢打 ...

  9. HTML5 例子学习 HT 图形组件

    HTML5 例子学习 HT 图形组件 HT 是啥:Everything you need to create cutting-edge 2D and 3D visualization. 这口号是当年心 ...

随机推荐

  1. .NET 机器学习生态调查

    机器学习是一种允许计算机使用现有数据预测未来行为.结果和趋势的数据科学方法. 使用机器学习,计算机可以在未显式编程的情况下进行学习.机器学习的预测可以使得应用和设备更智能. 在线购物时,机器学习基于历 ...

  2. zsh: command not found: conda的一种解决方法

    通过conda —version来验证conda命令是否可用,若出现下图 则需要修改.zshrc,如下: 第一步: 第二步: 注意,1:/Users/mac/是anaconda的安装路径,须根据自己情 ...

  3. 为什么要重写hashcode和equals方法?初级程序员在面试中很少能说清楚。

    我在面试 Java初级开发的时候,经常会问:你有没有重写过hashcode方法?不少候选人直接说没写过.我就想,或许真的没写过,于是就再通过一个问题确认:你在用HashMap的时候,键(Key)部分, ...

  4. 一个适合.NET Core的代码安全分析工具 - Security Code Scan

    本文主要翻译自Security Code Scan的官方Github文档,结合自己的初步使用简单介绍一下这款工具,大家可以结合自己团队的情况参考使用.此外,对.NET Core开发团队来说,可以参考张 ...

  5. 从零开始学习PYTHON3讲义(九)字典类型和插入排序

    <从零开始PYTHON3>第九讲 第六讲.上一讲我们都介绍了列表类型.列表类型是编程中最常用的一种类型,但也有挺明显的缺陷,比如: data = [5,22,34,12,87,67,3,4 ...

  6. 【Android Studio安装部署系列】四十二、Android Studio使用Eclipse中的keystore为App签名

    版权声明:本文为HaiyuKing原创文章,转载请注明出处! 概述 从eclipse迁移到AndroidStudio,要用原Eclipse的签名文件,这样才能保证转到AndroidStudio后更新的 ...

  7. 浅析Servlet执行原理

    在JavaWeb学习研究中,Servlet扮演重要的作用,学好它,是后续JavaWeb学习的良好基础.无论是SSH,还是SSM,微服务JavaWeb技术,都应先学好Servlet,从而达到事半功倍的效 ...

  8. 浅谈Google Chrome浏览器(理论篇)

    注解:各位读者,经博客园工作人员反馈,hosts涉及违规问题,我暂时屏蔽了最新hosts,若已经获取最新hosts的朋友们,注意保密,不要外传.给大家带来麻烦,对此非常抱歉!!! 开篇概述 1.详解g ...

  9. web scraper 抓取分页数据和二级页面内容

    如果是刚接触 web scraper 的,可以看第一篇文章. web scraper 是一款免费的,适用于普通用户(不需要专业 IT 技术的)的爬虫工具,可以方便的通过鼠标和简单配置获取你所想要数据. ...

  10. phpcms V9 二次开发------(获取点击数详解)

    关于phpcms V9的点击数的使用应该有不少数是直接调用网上搜索到的代码,但是对于一些想要深入研究开发的人来说,看到网上的代码后更是不解,本人这几天看了看,了解了一些东西,在这里写出来分享一下,首先 ...