Reading the PyTorch YOLOv3 source code

1. Reading test.py

1.1 Argument parsing

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
parser.add_argument('-weights_path', type=str, default='checkpoints/yolov3.pt', help='path to weights file')
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
parser.add_argument('-iou_thres', type=float, default=0.5, help='iou threshold required to qualify as detected')
parser.add_argument('-conf_thres', type=float, default=0.5, help='object confidence threshold')
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
parser.add_argument('-n_cpu', type=int, default=0, help='number of cpu threads to use during batch generation')
parser.add_argument('-img_size', type=int, default=608, help='size of each image dimension')
opt = parser.parse_args()
print(opt)
```
  • batch_size: size of each batch; unlike darknet, there is no subdivision setting
  • cfg: network configuration file
  • data_config_path: the coco.data file, which stores dataset paths and related info
  • weights_path: path to the weights file
  • class_path: class name file (coco.names); note that the order of the classes matters
  • iou_thres: IoU threshold required for a detection to count as correct
  • conf_thres: object confidence threshold
  • nms_thres: IoU threshold for non-maximum suppression
  • n_cpu: number of CPU threads to use during batch generation
  • img_size: input image size

1.2 Parsing the .data file

```python
def parse_data_config(path):
    """Parses the data configuration file"""
    options = dict()
    options['gpus'] = '0,1'
    options['num_workers'] = '10'
    with open(path, 'r') as fp:
        lines = fp.readlines()
    for line in lines:
        line = line.strip()
        if line == '' or line.startswith('#'):
            continue
        key, value = line.split('=')
        options[key.strip()] = value.strip()
    return options

The contents of the .data file are stored in the options dict, so any value can later be retrieved from that object by its key.
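For reference, a coco.data file typically looks like the following (a sketch; the exact paths depend on your setup):

```
classes=80
train=data/coco/trainvalno5k.txt
valid=data/coco/5k.txt
names=data/coco.names
backup=backup/
```

After parsing, options['names'] would yield 'data/coco.names', and so on.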

1.3 Parsing the cfg file

```python
def parse_model_config(path):
    """Parses the yolo-v3 layer configuration file and returns module definitions"""
    with open(path, 'r') as f:
        lines = f.read().split('\n')
    lines = [x for x in lines if x and not x.startswith('#')]
    lines = [x.rstrip().lstrip() for x in lines]  # get rid of fringe whitespaces
    module_defs = []
    for line in lines:
        if line.startswith('['):  # This marks the start of a new block
            module_defs.append({})
            module_defs[-1]['type'] = line[1:-1].rstrip()
            if module_defs[-1]['type'] == 'convolutional':
                module_defs[-1]['batch_normalize'] = 0
        else:
            key, value = line.split("=")
            value = value.strip()
            module_defs[-1][key.rstrip()] = value.strip()
    return module_defs
```

The returned module_defs holds all of the network's layer definitions: a list containing one dict per cfg block.
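As an illustration, the first convolutional block of yolov3.cfg,

```
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
```

is parsed into {'type': 'convolutional', 'batch_normalize': '1', 'filters': '32', 'size': '3', 'stride': '1', 'pad': '1', 'activation': 'leaky'}. Note that all values stay as strings here; type conversion happens later in create_modules.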

1.4 Building the modules from the cfg file

```python
def create_modules(module_defs):
    """
    Constructs module list of layer blocks from module configuration in module_defs
    """
    # Pop the first block ([net]): the network hyperparameters
    hyperparams = module_defs.pop(0)
    output_filters = [int(hyperparams['channels'])]
    module_list = nn.ModuleList()
    for i, module_def in enumerate(module_defs):
        # A sequential container: modules are added in the order they are
        # passed in (an OrderedDict can also be passed)
        modules = nn.Sequential()
        # Build a different module depending on the layer type
        if module_def['type'] == 'convolutional':
            bn = int(module_def['batch_normalize'])
            filters = int(module_def['filters'])
            kernel_size = int(module_def['size'])
            pad = (kernel_size - 1) // 2 if int(module_def['pad']) else 0
            # add_module attaches a child module to the current module;
            # the added module can later be retrieved by its name
            modules.add_module('conv_%d' % i, nn.Conv2d(in_channels=output_filters[-1],
                                                        out_channels=filters,
                                                        kernel_size=kernel_size,
                                                        stride=int(module_def['stride']),
                                                        padding=pad,
                                                        bias=not bn))
            if bn:
                modules.add_module('batch_norm_%d' % i, nn.BatchNorm2d(filters))
            if module_def['activation'] == 'leaky':
                modules.add_module('leaky_%d' % i, nn.LeakyReLU(0.1))
        elif module_def['type'] == 'upsample':
            # PyTorch's upsampling layer
            upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest')
            modules.add_module('upsample_%d' % i, upsample)
        elif module_def['type'] == 'route':
            # Parse the route layer from the yolo cfg file
            # e.g.: layers = -1, 14
            layers = [int(x) for x in module_def['layers'].split(',')]
            # Sum the channel counts of the routed layers. This is consistent
            # with darknet: the forward pass concatenates the referenced feature
            # maps along the channel dimension, so the output channel count is
            # the sum of the inputs' channel counts.
            filters = sum([output_filters[layer_i] for layer_i in layers])
            modules.add_module('route_%d' % i, EmptyLayer())
        elif module_def['type'] == 'shortcut':
            # e.g. from yolov3.cfg:
            #   from=-3
            #   activation=linear  (identity, so no activation module is added)
            filters = output_filters[int(module_def['from'])]
            modules.add_module('shortcut_%d' % i, EmptyLayer())
        elif module_def['type'] == 'yolo':
            anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
            # Extract anchors
            anchors = [float(x) for x in module_def['anchors'].split(',')]
            anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
            anchors = [anchors[i] for i in anchor_idxs]
            num_classes = int(module_def['classes'])
            img_height = int(hyperparams['height'])
            # Define detection layer
            yolo_layer = YOLOLayer(anchors, num_classes, img_height, anchor_idxs)
            modules.add_module('yolo_%d' % i, yolo_layer)
        # Register module list and number of output filters
        module_list.append(modules)
        output_filters.append(filters)
    return hyperparams, module_list
```

This is where the PyTorch-specific parts begin:

  • module_list = nn.ModuleList(): creates a list that holds modules
  • nn.Sequential(): a sequential container; Modules are added to it in the order they are passed in, and an OrderedDict can also be passed
  • add_module(name, module): adds a child module to the current module; the added module can then be retrieved via its name
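The route and shortcut blocks only register a placeholder module; the actual concatenation/addition happens later in Darknet.forward. A minimal sketch of such an EmptyLayer (assuming it is defined as a no-op nn.Module, as in this repo):

```python
import torch.nn as nn

class EmptyLayer(nn.Module):
    """Placeholder for 'route' and 'shortcut' layers; the real work
    (concatenation / addition) is done in Darknet.forward."""
    def __init__(self):
        super(EmptyLayer, self).__init__()
```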

1.5 YOLOLayer

```python
class YOLOLayer(nn.Module):
    def __init__(self, anchors, nC, img_dim, anchor_idxs):
        super(YOLOLayer, self).__init__()
        anchors = [(a_w, a_h) for a_w, a_h in anchors]  # (pixels)
        nA = len(anchors)
        self.anchors = anchors
        self.nA = nA  # number of anchors (3)
        self.nC = nC  # number of classes (80)
        self.bbox_attrs = 5 + nC
        self.img_dim = img_dim  # from hyperparams in cfg file, NOT from parser

        if anchor_idxs[0] == (nA * 2):  # 6
            stride = 32
        elif anchor_idxs[0] == nA:  # 3
            stride = 16
        else:
            stride = 8

        # Build anchor grids
        nG = int(self.img_dim / stride)
        self.grid_x = torch.arange(nG).repeat(nG, 1).view([1, 1, nG, nG]).float()
        self.grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).float()
        self.scaled_anchors = torch.FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in anchors])
        self.anchor_w = self.scaled_anchors[:, 0:1].view((1, nA, 1, 1))
        self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1))

    def forward(self, p, targets=None, requestPrecision=False):
        FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor

        bs = p.shape[0]  # batch size
        nG = p.shape[2]  # number of grid points
        stride = self.img_dim / nG

        if p.is_cuda and not self.grid_x.is_cuda:
            self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
            self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()

        # p.view(12, 255, 13, 13) -> (12, 3, 13, 13, 85)  # (bs, anchors, grid, grid, classes + xywh + conf)
        p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous()  # prediction

        # Get outputs
        x = torch.sigmoid(p[..., 0])  # Center x
        y = torch.sigmoid(p[..., 1])  # Center y

        # Width and height (yolo method)
        w = p[..., 2]  # Width
        h = p[..., 3]  # Height
        width = torch.exp(w.data) * self.anchor_w
        height = torch.exp(h.data) * self.anchor_h

        # Width and height (power method)
        # w = torch.sigmoid(p[..., 2])  # Width
        # h = torch.sigmoid(p[..., 3])  # Height
        # width = ((w.data * 2) ** 2) * self.anchor_w
        # height = ((h.data * 2) ** 2) * self.anchor_h

        # Add offset and scale with anchors (in grid space, i.e. 0-13)
        pred_boxes = FT(bs, self.nA, nG, nG, 4)
        pred_conf = p[..., 4]  # Conf
        pred_cls = p[..., 5:]  # Class

        # Training
        if targets is not None:
            MSELoss = nn.MSELoss(size_average=True)
            BCEWithLogitsLoss = nn.BCEWithLogitsLoss(size_average=True)
            CrossEntropyLoss = nn.CrossEntropyLoss()

            if requestPrecision:
                gx = self.grid_x[:, :, :nG, :nG]
                gy = self.grid_y[:, :, :nG, :nG]
                pred_boxes[..., 0] = x.data + gx - width / 2
                pred_boxes[..., 1] = y.data + gy - height / 2
                pred_boxes[..., 2] = x.data + gx + width / 2
                pred_boxes[..., 3] = y.data + gy + height / 2

            tx, ty, tw, th, mask, tcls, TP, FP, FN, TC = \
                build_targets(pred_boxes, pred_conf, pred_cls, targets, self.scaled_anchors, self.nA, self.nC, nG,
                              requestPrecision)
            tcls = tcls[mask]
            if x.is_cuda:
                tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()

            # Mask outputs to ignore non-existing objects (but keep confidence predictions)
            nT = sum([len(x) for x in targets])  # number of targets
            nM = mask.sum().float()  # number of anchors (assigned to targets)
            nB = len(targets)  # batch size
            k = nM / nB
            if nM > 0:
                lx = k * MSELoss(x[mask], tx[mask])
                ly = k * MSELoss(y[mask], ty[mask])
                lw = k * MSELoss(w[mask], tw[mask])
                lh = k * MSELoss(h[mask], th[mask])
                # lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
                lconf = k * BCEWithLogitsLoss(pred_conf, mask.float())
                lcls = k * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
                # lcls = k * BCEWithLogitsLoss(pred_cls[mask], tcls.float())
            else:
                lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])

            # Add confidence loss for background anchors (noobj)
            # lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float())

            # Sum loss components
            loss = lx + ly + lw + lh + lconf + lcls

            # Sum False Positives from unassigned anchors
            i = torch.sigmoid(pred_conf[~mask]) > 0.9
            if i.sum() > 0:
                FP_classes = torch.argmax(pred_cls[~mask][i], 1)
                FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu()  # extra FPs
            else:
                FPe = torch.zeros(self.nC)

            return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \
                nT, TP, FP, FPe, FN, TC

        else:
            pred_boxes[..., 0] = x.data + self.grid_x
            pred_boxes[..., 1] = y.data + self.grid_y
            pred_boxes[..., 2] = width
            pred_boxes[..., 3] = height

            # If not in training phase return predictions
            output = torch.cat((pred_boxes.view(bs, -1, 4) * stride,
                                torch.sigmoid(pred_conf.view(bs, -1, 1)), pred_cls.view(bs, -1, self.nC)), -1)
            return output.data
```
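The box decoding here is the one from the YOLOv3 paper: with grid-cell offsets $(c_x, c_y)$ (grid_x, grid_y) and anchor priors $(p_w, p_h)$ (anchor_w, anchor_h), the raw outputs $(t_x, t_y, t_w, t_h)$ become

$$b_x = \sigma(t_x) + c_x, \quad b_y = \sigma(t_y) + c_y, \quad b_w = p_w \, e^{t_w}, \quad b_h = p_h \, e^{t_h}$$

All of this is in grid units (0 to nG), which is why the inference branch multiplies pred_boxes by stride to recover pixel coordinates.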

I'll stop here for now and come back to a fuller analysis later.

1.6 Initializing the model

```python
model = Darknet(opt.cfg, opt.img_size)
```

Jumping to the definition:

```python
class Darknet(nn.Module):
    """YOLOv3 object detection model"""

    def __init__(self, config_path, img_size=416):
        super(Darknet, self).__init__()
        self.module_defs = parse_model_config(config_path)
        self.module_defs[0]['height'] = img_size
        self.hyperparams, self.module_list = create_modules(self.module_defs)
        self.img_size = img_size
        self.loss_names = ['loss', 'x', 'y', 'w', 'h', 'conf', 'cls', 'nT', 'TP', 'FP', 'FPe', 'FN', 'TC']

    def forward(self, x, targets=None, requestPrecision=False):
        is_training = targets is not None
        output = []
        self.losses = defaultdict(float)
        layer_outputs = []

        for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
            if module_def['type'] in ['convolutional', 'upsample']:
                x = module(x)
            elif module_def['type'] == 'route':
                layer_i = [int(x) for x in module_def['layers'].split(',')]
                x = torch.cat([layer_outputs[i] for i in layer_i], 1)
            elif module_def['type'] == 'shortcut':
                layer_i = int(module_def['from'])
                x = layer_outputs[-1] + layer_outputs[layer_i]
            elif module_def['type'] == 'yolo':
                # Train phase: get loss
                if is_training:
                    x, *losses = module[0](x, targets, requestPrecision)
                    for name, loss in zip(self.loss_names, losses):
                        self.losses[name] += loss
                # Test phase: get detections
                else:
                    x = module(x)
                output.append(x)
            layer_outputs.append(x)

        if is_training:
            self.losses['nT'] /= 3
            self.losses['TC'] /= 3
            metrics = torch.zeros(4, len(self.losses['FPe']))  # TP, FP, FN, target_count

            ui = np.unique(self.losses['TC'])[1:]
            for i in ui:
                j = self.losses['TC'] == float(i)
                metrics[0, i] = (self.losses['TP'][j] > 0).sum().float()  # TP
                metrics[1, i] = (self.losses['FP'][j] > 0).sum().float()  # FP
                metrics[2, i] = (self.losses['FN'][j] == 3).sum().float()  # FN
            metrics[3] = metrics.sum(0)
            metrics[1] += self.losses['FPe']

            self.losses['TP'] = metrics[0].sum()
            self.losses['FP'] = metrics[1].sum()
            self.losses['FN'] = metrics[2].sum()
            self.losses['TC'] = 0
            self.losses['metrics'] = metrics

        return sum(output) if is_training else torch.cat(output, 1)
```

To recap the attributes, for easier understanding:

  • module_defs: a list of dicts, one per block of the cfg file
  • hyperparams: the hyperparameters the whole network needs (the [net] block)
  • module_list: every layer of the network, loaded into a PyTorch nn.ModuleList()
  • loss_names: it is worth understanding the components of the loss here. Judging from the YOLOLayer code above:
    • loss: the total loss
    • x, y, w, h: the box regression losses
    • conf: the objectness (confidence) loss
    • cls: the classification loss
    • nT: the number of targets
    • TP, FP, FPe, FN, TC: true positives, false positives, extra false positives from unassigned anchors, false negatives, and target classes, as returned by build_targets

The precise semantics of these loss terms are still not fully clear to me; I'll come back and fill this in later.
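As a quick sanity check of the forward pass, a minimal inference-mode sketch (assuming cfg/yolov3.cfg with 80 classes; the output shape follows from the three YOLO heads at strides 32/16/8):

```python
import torch

model = Darknet('cfg/yolov3.cfg', img_size=416)
dummy = torch.zeros(1, 3, 416, 416)   # one blank 416x416 RGB image
detections = model(dummy)             # targets=None, so the inference branch runs
print(detections.shape)               # torch.Size([1, 10647, 85])
# 10647 = 3 * (13*13 + 26*26 + 52*52) anchors; 85 = 4 box + 1 conf + 80 classes
```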

1.7 Loading weights

As is well known, PyTorch YOLOv3 weight files end in .pt, while darknet YOLOv3 weight files end in .weights.

From the code below we can see that this version can load .weights files as well:

```python
# Load weights
if opt.weights_path.endswith('.weights'):  # darknet format
    load_weights(model, opt.weights_path)
elif opt.weights_path.endswith('.pt'):  # pytorch format
    checkpoint = torch.load(opt.weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    del checkpoint
```
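load_weights parses the darknet binary format. The following is not the repo's exact implementation, just a minimal sketch of that format: a short int32 header followed by a flat float32 array, consumed in layer order (for a conv with batch norm: bn bias, bn weight, running mean, running var, then conv weights; otherwise conv bias, then conv weights). The header length (here 5 int32 values) depends on the darknet version:

```python
import numpy as np
import torch

def load_darknet_weights_sketch(model, weights_path):
    """Sketch of loading darknet .weights into the Darknet model above."""
    with open(weights_path, 'rb') as f:
        np.fromfile(f, dtype=np.int32, count=5)     # header: version + images seen
        weights = np.fromfile(f, dtype=np.float32)  # the rest: all parameters, flat

    ptr = 0
    for module_def, module in zip(model.module_defs, model.module_list):
        if module_def['type'] != 'convolutional':
            continue
        conv = module[0]
        if int(module_def.get('batch_normalize', 0)):
            bn = module[1]
            # darknet order: bn bias, bn weight, running_mean, running_var
            for param in (bn.bias, bn.weight, bn.running_mean, bn.running_var):
                n = param.numel()
                param.data.copy_(torch.from_numpy(weights[ptr:ptr + n]).view_as(param))
                ptr += n
        else:
            n = conv.bias.numel()
            conv.bias.data.copy_(torch.from_numpy(weights[ptr:ptr + n]).view_as(conv.bias))
            ptr += n
        n = conv.weight.numel()
        conv.weight.data.copy_(torch.from_numpy(weights[ptr:ptr + n]).view_as(conv.weight))
        ptr += n
```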

1.8 Computing mAP

```python
print('Compute mAP...')

correct = 0
targets = None
outputs, mAPs, TP, confidence, pred_class, target_class = [], [], [], [], [], []
for batch_i, (imgs, targets) in enumerate(dataloader):
    imgs = imgs.to(device)

    with torch.no_grad():
        output = model(imgs)
        output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)

    # Compute average precision for each sample
    for sample_i in range(len(targets)):
        correct = []

        # Get labels for sample where width is not zero (dummies)
        annotations = targets[sample_i]
        # Extract detections
        detections = output[sample_i]

        if detections is None:
            # If there are no detections but there are annotations mask as zero AP
            if annotations.size(0) != 0:
                mAPs.append(0)
            continue

        # Get detections sorted by decreasing confidence scores
        detections = detections[np.argsort(-detections[:, 4])]

        # If no annotations add number of detections as incorrect
        if annotations.size(0) == 0:
            target_cls = []
            # correct.extend([0 for _ in range(len(detections))])
            mAPs.append(0)
            continue
        else:
            target_cls = annotations[:, 0]

            # Extract target boxes as (x1, y1, x2, y2)
            target_boxes = xywh2xyxy(annotations[:, 1:5])
            target_boxes *= opt.img_size

            detected = []
            for *pred_bbox, conf, obj_conf, obj_pred in detections:
                pred_bbox = torch.FloatTensor(pred_bbox).view(1, -1)
                # Compute iou with target boxes
                iou = bbox_iou(pred_bbox, target_boxes)
                # Extract index of largest overlap
                best_i = np.argmax(iou)
                # If overlap exceeds threshold and classification is correct mark as correct
                if iou[best_i] > opt.iou_thres and obj_pred == annotations[best_i, 0] and best_i not in detected:
                    correct.append(1)
                    detected.append(best_i)
                else:
                    correct.append(0)

        # Compute Average Precision (AP) per class
        AP = ap_per_class(tp=correct, conf=detections[:, 4], pred_cls=detections[:, 6], target_cls=target_cls)

        # Compute mean AP for this image
        mAP = AP.mean()

        # Append image mAP to list
        mAPs.append(mAP)

        # Print image mAP and running mean mAP
        print('+ Sample [%d/%d] AP: %.4f (%.4f)' % (len(mAPs), len(dataloader) * opt.batch_size, mAP, np.mean(mAPs)))

print('Mean Average Precision: %.4f' % np.mean(mAPs))
```

More detail to be filled in later.
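For reference, ap_per_class implements the standard per-class average precision. A minimal single-class sketch of that computation (VOC-style area under the precision-recall curve; the function and variable names here are illustrative, not the repo's):

```python
import numpy as np

def average_precision(tp, conf, n_gt):
    """tp: 0/1 per detection (as built above); conf: its confidence;
    n_gt: number of ground-truth boxes of this class."""
    order = np.argsort(-np.asarray(conf))        # sort by decreasing confidence
    tp = np.asarray(tp)[order]
    tpc = np.cumsum(tp)                          # cumulative true positives
    fpc = np.cumsum(1 - tp)                      # cumulative false positives
    recall = tpc / (n_gt + 1e-16)
    precision = tpc / (tpc + fpc)

    # Integrate the area under the precision-recall curve
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    for i in range(mpre.size - 1, 0, -1):        # make precision monotonically decreasing
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
    idx = np.where(mrec[1:] != mrec[:-1])[0]     # points where recall changes
    return np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1])
```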

2. Reading train.py

2.1 Argument parsing

```python
parser = argparse.ArgumentParser()
parser.add_argument('-epochs', type=int, default=68, help='number of epochs')
parser.add_argument('-batch_size', type=int, default=12, help='size of each image batch')
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='data config file path')
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension')
parser.add_argument('-resume', default=False, help='resume training flag')
opt = parser.parse_args()
print(opt)
```
  • epochs: number of training epochs
  • batch_size: batch size
  • data_config_path: location of the .data file
  • cfg: location of the cfg file
  • img_size: input image size (32 * 13 = 416)
  • resume: whether to resume training (True or False)

2.2 Random seed initialization

```python
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
if cuda:
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.benchmark = True
```

2.3 Setting up the optimizer

```python
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=1e-3, momentum=.9, weight_decay=5e-4, nesterov=True)
```

An SGD optimizer with learning rate 1e-3, momentum 0.9, weight decay 5e-4, and Nesterov momentum enabled; only parameters with requires_grad=True are passed in.

2.4 Updating the learning rate

The learning rate to use is picked based on the current epoch:

```python
# Update scheduler (automatic)
# scheduler.step()

# Update scheduler (manual)
if epoch < 54:
    lr = 1e-3
elif epoch < 61:
    lr = 1e-4
else:
    lr = 1e-5
for g in optimizer.param_groups:
    g['lr'] = lr
```

The learning rate can either be updated automatically with a scheduler, or manually as above.
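The commented-out scheduler.step() hints at the automatic alternative. A sketch of an equivalent built-in schedule (same drops at epochs 54 and 61):

```python
from torch.optim.lr_scheduler import MultiStepLR

# lr: 1e-3 -> 1e-4 at epoch 54 -> 1e-5 at epoch 61, matching the manual rule
scheduler = MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)

for epoch in range(opt.epochs):
    # ... train for one epoch ...
    scheduler.step()
```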

2.5 Loss metrics

  • mean_precision:

```python
# Precision: metrics[0] is per-class TP, metrics[1] is per-class FP
precision = metrics[0] / (metrics[0] + metrics[1] + 1e-16)
k = (metrics[0] + metrics[1]) > 0
if k.sum() > 0:
    mean_precision = precision[k].mean()
else:
    mean_precision = 0
```

  • mean_recall:

```python
# Recall: metrics[2] is per-class FN
recall = metrics[0] / (metrics[0] + metrics[2] + 1e-16)
k = (metrics[0] + metrics[2]) > 0
if k.sum() > 0:
    mean_recall = recall[k].mean()
else:
    mean_recall = 0
```

All of these metrics are then written to the results.txt file.

2.6 Checkpointing

The checkpoint stores: epoch, best_loss, model, and optimizer.

latest.pt: the most recent weights.

best.pt: the best weights so far.

```python
# Save latest checkpoint
checkpoint = {'epoch': epoch,
              'best_loss': best_loss,
              'model': model.state_dict(),
              'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoints/latest.pt')

# Save best checkpoint
if best_loss == loss_per_target:
    os.system('cp checkpoints/latest.pt checkpoints/best.pt')

# Save backup checkpoint
if (epoch > 0) & (epoch % 5 == 0):
    os.system('cp checkpoints/latest.pt checkpoints/backup' + str(epoch) + '.pt')
```
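The -resume flag from 2.1 would restore exactly this dict. A minimal sketch (the key names follow the checkpoint saved above):

```python
checkpoint = torch.load('checkpoints/latest.pt', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1   # continue from the next epoch
best_loss = checkpoint['best_loss']
```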

3. Reading detect.py

3.1 Argument parsing

```python
parser = argparse.ArgumentParser()
parser.add_argument('-image_folder', type=str, default='data/samples', help='path to images')
parser.add_argument('-output_folder', type=str, default='output', help='path to outputs')
parser.add_argument('-plot_flag', type=bool, default=True)
parser.add_argument('-txt_out', type=bool, default=False)
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
parser.add_argument('-conf_thres', type=float, default=0.50, help='object confidence threshold')
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
parser.add_argument('-batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each image dimension')
opt = parser.parse_args()
print(opt)
```
  • image_folder: data/samples, folder of images to run detection on
  • output_folder: output, where the results are written
  • plot_flag: True or False; draw the bboxes and save the images
  • txt_out: True or False; whether to also save detection results to a text file
  • cfg: cfg file path
  • class_path: path to the class name file
  • conf_thres, nms_thres: object confidence threshold and NMS IoU threshold
  • batch_size: usually left at the default of 1
  • img_size: image size used when loading images

3.2 Getting the predicted boxes

```python
# Get detections
with torch.no_grad():
    chip = torch.from_numpy(img).unsqueeze(0).to(device)
    pred = model(chip)
    pred = pred[pred[:, :, 4] > opt.conf_thres]

    if len(pred) > 0:
        detections = non_max_suppression(pred.unsqueeze(0), opt.conf_thres, opt.nms_thres)
        img_detections.extend(detections)
        imgs.extend(img_paths)
```

Run the model, keep predictions above the confidence threshold, and apply non-maximum suppression.
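non_max_suppression greedily keeps the highest-confidence box and discards the ones that overlap it too much. A minimal single-class sketch of the idea (illustrative, not the repo's implementation):

```python
import torch

def nms_single_class(boxes, scores, iou_thres=0.45):
    """boxes: (N, 4) tensor as (x1, y1, x2, y2); scores: (N,) tensor."""
    keep = []
    order = scores.argsort(descending=True)  # indices by decreasing confidence
    while order.numel() > 0:
        i = order[0]
        keep.append(int(i))
        if order.numel() == 1:
            break
        rest = boxes[order[1:]]
        # Intersection of the kept box with all remaining boxes
        x1 = torch.max(boxes[i, 0], rest[:, 0])
        y1 = torch.max(boxes[i, 1], rest[:, 1])
        x2 = torch.min(boxes[i, 2], rest[:, 2])
        y2 = torch.min(boxes[i, 3], rest[:, 3])
        inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
        iou = inter / (area_i + area_r - inter + 1e-16)
        order = order[1:][iou <= iou_thres]  # drop boxes that overlap too much
    return keep
```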

3.3 Core: iterating over the images and drawing the predicted boxes

```python
# Iterate through images and save plot of detections
for img_i, (path, detections) in enumerate(zip(imgs, img_detections)):
    print("image %g: '%s'" % (img_i, path))

    if opt.plot_flag:
        img = cv2.imread(path)

    # The amount of padding that was added
    pad_x = max(img.shape[0] - img.shape[1], 0) * (opt.img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * (opt.img_size / max(img.shape))
    # Image height and width after padding is removed
    unpad_h = opt.img_size - pad_y
    unpad_w = opt.img_size - pad_x

    # Draw bounding boxes and labels of detections
    if detections is not None:
        unique_classes = detections[:, -1].cpu().unique()
        bbox_colors = random.sample(color_list, len(unique_classes))

        # write results to .txt file
        results_img_path = os.path.join(opt.output_folder, path.split('/')[-1])
        results_txt_path = results_img_path + '.txt'
        if os.path.isfile(results_txt_path):
            os.remove(results_txt_path)

        for i in unique_classes:
            n = (detections[:, -1].cpu() == i).sum()
            print('%g %ss' % (n, classes[int(i)]))

        for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
            # Rescale coordinates to original dimensions
            box_h = ((y2 - y1) / unpad_h) * img.shape[0]
            box_w = ((x2 - x1) / unpad_w) * img.shape[1]
            y1 = (((y1 - pad_y // 2) / unpad_h) * img.shape[0]).round().item()
            x1 = (((x1 - pad_x // 2) / unpad_w) * img.shape[1]).round().item()
            x2 = (x1 + box_w).round().item()
            y2 = (y1 + box_h).round().item()
            x1, y1, x2, y2 = max(x1, 0), max(y1, 0), max(x2, 0), max(y2, 0)

            # write to file
            if opt.txt_out:
                with open(results_txt_path, 'a') as file:
                    file.write(('%g %g %g %g %g %g \n') % (x1, y1, x2, y2, cls_pred, cls_conf * conf))

            if opt.plot_flag:
                # Add the bbox to the plot
                label = '%s %.2f' % (classes[int(cls_pred)], conf)
                color = bbox_colors[int(np.where(unique_classes == int(cls_pred))[0])]
                plot_one_box([x1, y1, x2, y2], img, label=label, color=color)

    if opt.plot_flag:
        # Save generated image with detections
        cv2.imwrite(results_img_path.replace('.bmp', '.jpg').replace('.tif', '.jpg'), img)
```
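The pad/unpad arithmetic undoes the letterboxing that was applied when the image was loaded. A worked example with a hypothetical 720x1280 (h x w) image and img_size = 416:

```python
# img.shape = (720, 1280, 3); opt.img_size = 416
pad_x = max(720 - 1280, 0) * (416 / 1280)   # 0: the width already fills the square
pad_y = max(1280 - 720, 0) * (416 / 1280)   # 182: total vertical padding added
unpad_h = 416 - pad_y                        # 234 = 720 * 416 / 1280
unpad_w = 416 - pad_x                        # 416
# So a detection's y1 maps back to the original image as
# ((y1 - pad_y / 2) / unpad_h) * 720
```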
