Object Detection Networks: Key Points of a YOLO-v1 Implementation

As an anchor-free object detection method, YOLO-v1 is admittedly old, but a deep understanding of its principles is still well worth having. In my view, implementing an object detection algorithm entirely from scratch is indispensable, since many details simply cannot be appreciated from reading the paper alone. This post therefore walks through how to implement YOLO-v1 in PyTorch.

The post covers the three most important parts, focusing on how the ground truth (GT) is built and how the loss is computed. For a walkthrough of the YOLO-v1 paper itself, see my blog; the complete code is on my GitHub:

  • Network design
  • How to generate the targets the network needs
  • Implementing the YOLO-v1 loss

Network architecture

The YOLO-v1 architecture is relatively simple and can be built directly from the layout given in the paper. Since the dataset used for training here is small and no pretrained weights are available, the convolutions are implemented as depthwise separable convolutions to cut parameters and computation: for example, a standard 3x3 convolution from 256 to 512 channels has 256 * 512 * 9 ≈ 1.18M weights, while the separable version needs only 256 * 9 + 256 * 512 ≈ 0.13M.

import torch.nn as nn
import torch.nn.functional as F


class SeparableConv2D(nn.Module):
    # Depthwise separable convolution: reduces parameter count and computation
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(SeparableConv2D, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.depth_conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.in_channels,
                                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding,
                                    groups=self.in_channels)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.point_conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,
                                    kernel_size=(1, 1), stride=(1, 1))
        self.bn2 = nn.BatchNorm2d(self.out_channels)

    def forward(self, x):
        x = self.depth_conv(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.point_conv(x)
        x = self.bn2(x)
        x = F.relu(x)

        return x


class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class Yolo_v1(nn.Module):
    def __init__(self, class_num, box_num=2):
        # class_num: number of classes
        # box_num: boxes per grid cell, 2 by default in YOLO-v1
        super(Yolo_v1, self).__init__()
        self.C = class_num
        self.box_num = box_num
        self.out_channel = self.box_num * 5 + self.C
        self.conv_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=7//2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv_layer2 = nn.Sequential(
            SeparableConv2D(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=3//2),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv_layer3 = nn.Sequential(
            SeparableConv2D(in_channels=192, out_channels=128, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=3//2),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv_layer4 = nn.Sequential(
            SeparableConv2D(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=512, out_channels=256, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=3//2),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv_layer5 = nn.Sequential(
            SeparableConv2D(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=1//2),
            SeparableConv2D(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=3//2),
        )
        self.conv_layer6 = nn.Sequential(
            SeparableConv2D(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=3//2),
            SeparableConv2D(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=3//2),
        )
        # self.flatten = Flatten()
        # self.fc1 = nn.Sequential(
        #     nn.Linear(in_features=7*7*1024, out_features=4096),
        #     nn.Dropout(),
        #     nn.LeakyReLU(0.1)
        # )
        # self.fc2 = nn.Sequential(nn.Linear(in_features=4096, out_features=7 * 7 * (2 * 5 + self.C)),
        #                          nn.Sigmoid())

        self.conv_out = nn.Sequential(
            SeparableConv2D(in_channels=1024, out_channels=self.out_channel, kernel_size=3, stride=1, padding=3//2),
            nn.Conv2d(in_channels=self.out_channel, out_channels=self.out_channel, kernel_size=3, stride=1, padding=3//2),
            nn.BatchNorm2d(self.out_channel),
            nn.ReLU(True),
            nn.Conv2d(in_channels=self.out_channel, out_channels=self.out_channel, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_layer1 = self.conv_layer1(x)
        conv_layer2 = self.conv_layer2(conv_layer1)
        conv_layer3 = self.conv_layer3(conv_layer2)
        conv_layer4 = self.conv_layer4(conv_layer3)
        conv_layer5 = self.conv_layer5(conv_layer4)
        conv_layer6 = self.conv_layer6(conv_layer5)
        # flatten = self.flatten(conv_layer6)
        # fc1 = self.fc1(flatten)
        # fc2 = self.fc2(fc1)
        # output = fc2.reshape([-1, 7, 7, 2 * 5 + self.C])
        output = self.conv_out(conv_layer6)
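        # (N, C, S, S) -> (N, S, S, C): put the B * 5 + C channels last to match the target layout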
        output = output.permute(0, 2, 3, 1).contiguous()
        return output
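
A quick shape check of the model above (a minimal sketch; class_num=20 is just an example value):

import torch

model = Yolo_v1(class_num=20)
x = torch.randn(1, 3, 448, 448)  # YOLO-v1 takes 448 x 448 inputs
out = model(x)
print(out.shape)  # torch.Size([1, 7, 7, 30]): S = 7, B * 5 + C = 2 * 5 + 20 = 30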

Building the ground truth (GT)

This part of the code lives in tools/dataloader.py.

def encode_target(self, bboxs, img_clas):
        """
        :param bboxs: tensor of boxes, [[x1, y1, x2, y2], ...]
        :param img_clas: class index for each box, e.g. [[0], ...]
        :return: S * S * (B * 5 + class_num)
        """
        # Each image is encoded as a 7 * 7 * (2 * 5 + class_num) target
        target = torch.zeros((self.S, self.S, self.B * 5 + self.class_num))
        # Number of objects in the current image
        n_box = len(bboxs)
        # One-hot class info: img_clas holds each object's class index, which is set to 1
        clas = torch.zeros((n_box, self.class_num))
        for i in range(n_box):
            clas[i, img_clas[i]] = 1
        # Convert the top-left / bottom-right corners into center coordinates plus width and height
        # width
        bboxs[:, 2] = bboxs[:, 2] - bboxs[:, 0]
        # height
        bboxs[:, 3] = bboxs[:, 3] - bboxs[:, 1]
        # center coordinate xc
        bboxs[:, 0] = bboxs[:, 0] + torch.floor(bboxs[:, 2] / 2)
        # center coordinate yc
        bboxs[:, 1] = bboxs[:, 1] + torch.floor(bboxs[:, 3] / 2)

        # YOLO-v1 predicts the center as an offset from the top-left corner of the grid cell
        # containing it, so the value is constrained to 0~1, which speeds up convergence.
        # The top-left corner of a grid cell is simply its index on the feature map:
        # x = xc / w * S - col
        # y = yc / h * S - row
        # Grid cell indices
        col = torch.floor(bboxs[:, 0] / (self.target_size[0] / self.S))
        row = torch.floor(bboxs[:, 1] / (self.target_size[1] / self.S))
        # Offset of the object center from the top-left corner of its grid cell
        x = bboxs[:, 0] / self.target_size[0] * self.S - col
        y = bboxs[:, 1] / self.target_size[1] * self.S - row
        # Width and height are normalized by the image size and square-rooted
        w_sqrt = torch.sqrt(bboxs[:, 2] / self.target_size[0])
        h_sqrt = torch.sqrt(bboxs[:, 3] / self.target_size[1])
        # Confidence is set to 1
        conf = torch.ones_like(col)
        # In YOLO-v1 both boxes of a cell detect the same object, so replicate the info B times
        grid_info = torch.cat([conf.view(-1, 1), x.view(-1, 1), y.view(-1, 1), w_sqrt.view(-1, 1), h_sqrt.view(-1, 1)],
                              dim=1).repeat(1, self.B)
        # Concatenate each object's info into [[conf1, x, y, w_sqrt, h_sqrt, conf2, x, y, w_sqrt, h_sqrt, c1, c2, ...],
        #                                      [...]]
        grid_info = torch.cat([grid_info, clas], dim=1)

        # Write each object's info into its grid cell
        for i in range(n_box):
            row_index = int(row[i].item())
            col_index = int(col[i].item())
            target[row_index, col_index] = grid_info[i].clone()

        return target

When multiple objects fall into the same grid cell, YOLO-v1 can detect only one of them; in that case, the code above keeps whichever bounding box was assigned last.
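
For intuition, a worked example (hypothetical values, assuming S=7, B=2 and target_size=(448, 448)): the box [[96, 96, 160, 160]] has width = height = 64 and center (128, 128). With a cell size of 448 / 7 = 64 it lands in grid cell (row=2, col=2), its offsets are x = y = 128 / 448 * 7 - 2 = 0.0, and w_sqrt = h_sqrt = sqrt(64 / 448) ≈ 0.378, so target[2, 2] becomes [1, 0, 0, 0.378, 0.378, 1, 0, 0, 0.378, 0.378, one-hot class].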

YOLO-v1 loss layer

The YOLO-v1 loss consists of a classification loss, a confidence loss, and a coordinate loss. Each grid cell predicts two boxes, and the rules are:

  • When the grid cell contains an object, the classification loss is computed; the two boxes share one set of class predictions
  • When the grid cell contains no object, only the confidence loss is computed, for both boxes
  • When the grid cell contains an object, compute the IOU between each predicted box and the GT box; the box with the larger IOU is responsible for the object and incurs both the confidence loss and the coordinate loss, while the box with the smaller IOU still incurs a confidence loss, just no coordinate loss
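
These rules match the loss from the paper, where lambda_coord = 5 weights the coordinate terms and lambda_noobj = 0.5 down-weights the background confidence terms:

$$
\begin{aligned}
\mathcal{L} ={}& \lambda_{\text{coord}} \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} \left[ (x_i - \hat{x}_i)^2 + (y_i - \hat{y}_i)^2 + (\sqrt{w_i} - \sqrt{\hat{w}_i})^2 + (\sqrt{h_i} - \sqrt{\hat{h}_i})^2 \right] \\
&+ \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{obj}} (C_i - \hat{C}_i)^2 + \lambda_{\text{noobj}} \sum_{i=0}^{S^2}\sum_{j=0}^{B} \mathbb{1}_{ij}^{\text{noobj}} (C_i - \hat{C}_i)^2 \\
&+ \sum_{i=0}^{S^2} \mathbb{1}_{i}^{\text{obj}} \sum_{c \in \text{classes}} \left( p_i(c) - \hat{p}_i(c) \right)^2
\end{aligned}
$$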

This part of the code is in vision/yolov1_loss.py.

def forward(self, predicts, targets):
        """
        Layout of the last dimension:
        :param predicts: [..., conf1, x, y, w_sqrt, h_sqrt, conf2, x, y, w_sqrt, h_sqrt, c1, c2, ...]
        :param targets: [..., conf1, x, y, w_sqrt, h_sqrt, conf2, x, y, w_sqrt, h_sqrt, c1, c2, ...]
        :return: total loss
        """
        # Split out the first and second predicted boxes
        box1 = predicts[..., :5]
        box2 = predicts[..., 5:10]
        # Mark the grid cells that contain no object
        nobj = 1 - targets[..., 0]

        # Confidence loss for all boxes in cells that contain no object
        nobj_box1_loss = self.confidence_loss(box1[..., 0] * nobj, targets[..., 0] * nobj)
        nobj_box2_loss = self.confidence_loss(box2[..., 0] * nobj, targets[..., 0] * nobj)
        nobj_conf_loss = nobj_box1_loss.sum() + nobj_box2_loss.sum()

        # Compute the IOU of each predicted box against the GT; the responsible box gets a 1 in the mask
        response_mask = self.response_box(predicts, targets)

        # 1. For cells containing an object, coordinate and confidence losses for the responsible box
        box1_coord_loss = self.coord_loss(box1, targets) * targets[..., 0] * response_mask[..., 0]
        box2_coord_loss = self.coord_loss(box2, targets) * targets[..., 0] * response_mask[..., 1]
        boxes_coord_loss = box1_coord_loss.sum() + box2_coord_loss.sum()

        box1_response_conf_loss = self.confidence_loss(box1[..., 0] * response_mask[..., 0] * targets[..., 0], targets[..., 0] * response_mask[..., 0])
        box2_response_conf_loss = self.confidence_loss(box2[..., 0] * response_mask[..., 1] * targets[..., 0], targets[..., 0] * response_mask[..., 1])
        response_conf_loss_obj = box1_response_conf_loss.sum() + box2_response_conf_loss.sum()

        # 2. For cells containing an object, the non-responsible box still incurs a confidence loss
        box1_not_response_conf_loss = self.confidence_loss(box1[..., 0] * response_mask[..., 1] * targets[..., 0], targets[..., 0] * response_mask[..., 1])
        box2_not_response_conf_loss = self.confidence_loss(box2[..., 0] * response_mask[..., 0] * targets[..., 0], targets[..., 0] * response_mask[..., 0])
        not_response_conf_loss_obj = box1_not_response_conf_loss.sum() + box2_not_response_conf_loss.sum()

        obj_conf_loss = response_conf_loss_obj + not_response_conf_loss_obj

        # 3. Classification loss, only for cells that contain an object
        class_loss = ((predicts[..., self.B * 5:] - targets[..., self.B * 5:]) ** 2 * torch.unsqueeze(targets[..., 0], dim=3)).sum()

        print("coord loss: {}, conf_loss_obj: {}, conf_loss_nobj: {}, class_loss: {}".format(self.lambda_coord * boxes_coord_loss, obj_conf_loss, nobj_conf_loss * self.lambda_noobj, class_loss))

        total_loss = self.lambda_coord * boxes_coord_loss + obj_conf_loss + nobj_conf_loss * self.lambda_noobj + class_loss
        return total_loss
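
confidence_loss, coord_loss and response_box are defined in the same file. response_box (not shown here) decodes the two predicted boxes, computes each one's IOU with the GT box, and returns a one-hot mask over the two boxes per grid cell. As a rough sketch of what the two element-wise losses might look like (an assumption inferred from how they are called above, not the repo's exact code):

    def confidence_loss(self, pred_conf, target_conf):
        # Element-wise squared error; the caller applies the masks and sums
        return (pred_conf - target_conf) ** 2

    def coord_loss(self, box, targets):
        # Squared error over (x, y, w_sqrt, h_sqrt); index 0 of each box is confidence.
        # Both GT boxes are identical (see encode_target), so targets[..., 1:5] suffices.
        # Returns a per-cell value so the caller can mask by objectness and responsibility
        return ((box[..., 1:5] - targets[..., 1:5]) ** 2).sum(dim=-1)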