RetinaNet目标检测任务


1.数据集介绍

DOTA数据集全称:Dataset for Object deTection in Aerial images

DOTA数据集v1.0共收录2806张4000 × 4000的图片,总共包含188282个目标。

  • DOTA数据集论文介绍:https://arxiv.org/pdf/1711.10398.pdf

  • 数据集官网:https://captain-whu.github.io/DOTA/dataset.html

 

DOTA数据集有三个版本:

* DOTAV1.0

类别数目:15

类别名称:plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field , swimming pool

* DOTAV1.5

类别数目:16

类别名称:plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field, swimming pool , container crane

* DOTAV2.0

类别数目:18

类别名称:plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field, swimming pool, container crane, airport , helipad

 

 (1) 标签

在对数据集进行数据增强时,我们需要知道相关标签文件格式

每个对象有10个数值,前8个代表一个矩形框四个角的坐标,第9个表示对象类别,第10个表示识别难易程度,0表示简单,1表示困难。

下面是一个类似的文件


950.0 851.0 931.0 852.0 932.0 817.0 952.0 817.0 small-vehicle 1

475.0 982.0 456.0 982.0 461.0 841.0 481.0 842.0 large-vehicle 0

424.0 978.0 400.0 982.0 403.0 840.0 426.0 839.0 large-vehicle 0

395.0 984.0 373.0 985.0 376.0 842.0 399.0 843.0 large-vehicle 0

365.0 979.0 344.0 978.0 346.0 839.0 369.0 838.0 large-vehicle 0

337.0 977.0 317.0 977.0 321.0 836.0 339.0 835.0 large-vehicle 0

310.0 978.0 287.0 979.0 286.0 838.0 311.0 838.0 large-vehicle 0

154.0 947.0 250.0 947.0 250.0 971.0 154.0 971.0 large-vehicle 0

140.0 894.0 255.0 894.0 255.0 919.0 140.0 919.0 large-vehicle 0

116.0 862.0 236.0 862.0 236.0 888.0 116.0 888.0 large-vehicle 0

146.0 771.0 269.0 771.0 269.0 796.0 146.0 796.0 large-vehicle 0

136.0 741.0 271.0 741.0 271.0 766.0 136.0 766.0 large-vehicle 0

136.0 713.0 271.0 713.0 271.0 735.0 136.0 735.0 large-vehicle 0

 

可以看出其标签是由四点组成的旋转框。 注:由于可视化代码需要坐标点输入为整数,因此后续给输出对象坐标点进行了取整操作,这会损失一些精度。

(2) 数据增强

相关数据增强, DOTA(obb)数据增加(方法有:改变亮度,加噪声,旋转角度,镜像,平移,裁剪,cutout),相关code如下:

2025.8.3 : 数据增强(dota4点转换)

# -*- coding=utf-8 -*-

# 包括:

#     6. 裁剪(需改变bbox)

#     5. 平移(需改变bbox)

#     1. 改变亮度

#     2. 加噪声

#     3. 旋转角度(需要改变bbox)

#     4. 镜像(需要改变bbox)

#     7. cutout

# 注意:

#     random.seed(),相同的seed,产生的随机数是一样的!!

import time

import random

import cv2 as cv

import os

import math

import numpy as np

from skimage.util import random_noise

from skimage import exposure

import shutil

import imutils

  

# 调整亮度

def changeLight(img,inputtxt,outputiamge,outputtxt):

 

    # random.seed(int(time.time()))

    flag = random.uniform(0.5, 1.5)  # flag>1为调暗,小于1为调亮

    label=round(flag,2)

    (filepath, tempfilename) = os.path.split(inputtxt)

    (filename, extension) = os.path.splitext(tempfilename)

    outputiamge=os.path.join(outputiamge+"/"+filename+"_"+str(label)+".jpg")

    outputtxt=os.path.join(outputtxt+"/"+filename+"_"+str(label)+extension)

 

    ima_gamma=exposure.adjust_gamma(img, 0.5)

 

    shutil.copyfile(inputtxt, outputtxt)

    cv.imwrite(outputiamge,ima_gamma)

 

# 加噪声

def gasuss_noise(image,inputtxt,outputiamge,outputtxt,mean=0, var=0.01):

    '''

        添加高斯噪声

        mean : 均值

        var : 方差

    '''

    image = np.array(image/255, dtype=float)

    noise = np.random.normal(mean, var ** 0.5, image.shape)

    out = image + noise

 

    if out.min() < 0:

        low_clip = -1.

    else:

        low_clip = 0.

    out = np.clip(out, low_clip, 1.0)

    out = np.uint8(out*255)

 

    (filepath, tempfilename) = os.path.split(inputtxt)

    (filename, extension) = os.path.splitext(tempfilename)

    outputiamge = os.path.join(outputiamge + "/" + filename + "_gasunoise_" + str(mean)+"_"+str(var)+ ".jpg")

    outputtxt = os.path.join(outputtxt + "/" +  filename + "_gasunoise_" + str(mean)+"_"+str(var)+ extension)

 

    shutil.copyfile(inputtxt, outputtxt)

    cv.imwrite(outputiamge, out)

 

#对比度调整算法

def ContrastAlgorithm(rgb_img,inputtxt, outputiamge, outputtxt):

   img_shape=rgb_img.shape

   temp_imag=np.zeros(img_shape, dtype=float)

   for num in range(0,3):

       # 通过直方图正规化增强对比度

 

       in_image =rgb_img[:,:,num]

       # 求输入图片像素最大值和最小值

       Imax = np.max(in_image)

       Imin = np.min(in_image)

       # 要输出的最小灰度级和最大灰度级

       Omin, Omax = 0, 255

       # 计算a 和 b的值

       a = float(Omax - Omin) / (Imax - Imin)

       b = Omin - a * Imin

       # 矩阵的线性变化

       out_image = a * in_image + b

       # 数据类型的转化

       out_image = out_image.astype(np.uint8)

       temp_imag[:,:,num]=out_image

   (filepath, tempfilename) = os.path.split(inputtxt)

   (filename, extension) = os.path.splitext(tempfilename)

   outputiamge = os.path.join(outputiamge + "/" + filename + "_contrastAlgorithm"  + ".jpg")

   outputtxt = os.path.join(outputtxt + "/" + filename + "_contrastAlgorithm"  + extension)

   shutil.copyfile(inputtxt, outputtxt)

   cv.imwrite(outputiamge, temp_imag)

 

# 旋转

def rotate_img_bbox(img, inputtxt,temp_outputiamge,temp_outputtxt,angle,scale=1):

    nAgree=angle

    size=img.shape

    w=size[1]

    h=size[0]

    for numAngle in range(0,len(nAgree)):

        dRot = nAgree[numAngle] * np.pi / 180

        dSinRot = math.sin(dRot)

        dCosRot = math.cos(dRot)

 

        nw = (abs(np.sin(dRot) * h) + abs(np.cos(dRot) * w)) * scale

        nh = (abs(np.cos(dRot) * h) + abs(np.sin(dRot) * w)) * scale

 

        (filepath, tempfilename) = os.path.split(inputtxt)

        (filename, extension) = os.path.splitext(tempfilename)

        outputiamge = os.path.join(temp_outputiamge + "/" + filename + "_rotate_" + str(nAgree[numAngle])+ ".jpg")

        outputtxt = os.path.join(temp_outputtxt + "/" + filename + "_rotate_" + str(nAgree[numAngle])+ extension)

 

        rot_mat = cv.getRotationMatrix2D((nw * 0.5, nh * 0.5), nAgree[numAngle], scale)

        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))

        rot_mat[0, 2] += rot_move[0]

        rot_mat[1, 2] += rot_move[1]


        # 仿射变换

        rotat_img = cv.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv.INTER_LANCZOS4)

        cv.imwrite(outputiamge,rotat_img)

 

        save_txt=open(outputtxt,'w')

        f = open(inputtxt)

        for line in f.readlines():

            line= line.split(" ")

            x1=float(line[0])

            y1=float(line[1])

            x2 = float(line[2])

            y2 = float(line[3])

            x3 = float(line[4])

            y3 = float(line[5])

            x4 = float(line[6])

            y4 = float(line[7])

            category=str(line[8])

 

            point1 = np.dot(rot_mat, np.array([x1, y1, 1]))

            point2 = np.dot(rot_mat, np.array([x2, y2, 1]))

            point3 = np.dot(rot_mat, np.array([x3, y3, 1]))

            point4 = np.dot(rot_mat, np.array([x4, y4, 1]))

            x1=round(point1[0],3)

            y1 = round(point1[1], 3)

            x2 = round(point2[0], 3)

            y2 = round(point2[1], 3)

            x3 = round(point3[0], 3)

            y3 = round(point3[1], 3)

            x4 = round(point4[0], 3)

            y4 = round(point4[1], 3)

            string = str(x1) + " " + str(y1) + " " + str(x2) + " " + str(y2) + " " + str(x3) + " " + str(

                y3) + " " + str(x4) + " " + str(y4)+" "+category

            save_txt.write(string)

 

def filp_pic_bboxes(img, inputtxt,outputiamge,outputtxt):

    # ---------------------- 翻转图像 ----------------------

    (filepath, tempfilename) = os.path.split(inputtxt)

    (filename, extension) = os.path.splitext(tempfilename)

    output_vert_flip_img = os.path.join(outputiamge + "/" + filename + "_vert_flip" + ".jpg")

    output_vert_flip_txt = os.path.join(outputtxt + "/" + filename + "_vert_flip"  + extension)

    output_horiz_flip_img = os.path.join(outputiamge + "/" + filename + "_horiz_flip" + ".jpg")

    output_horiz_flip_txt = os.path.join(outputtxt + "/" + filename + "_horiz_flip" + extension)

 

    h,w,_ = img.shape


    #垂直翻转

    vert_flip_img =  cv.flip(img, 1)

    cv.imwrite(output_vert_flip_img,vert_flip_img)


    # 水平翻转

    horiz_flip_img = cv.flip(img, 0)

    cv.imwrite(output_horiz_flip_img,horiz_flip_img)

    # ---------------------- 调整boundingbox ----------------------

 

    save_vert_txt = open(output_vert_flip_txt, 'w')

    save_horiz_txt = open(output_horiz_flip_txt, 'w')

    f = open(inputtxt)

    for line in f.readlines():

        line = line.split(" ")

        x1 = float(line[0])

        y1 = float(line[1])

        x2 = float(line[2])

        y2 = float(line[3])

        x3 = float(line[4])

        y3 = float(line[5])

        x4 = float(line[6])

        y4 = float(line[7])

        category = str(line[8])

 

        horiz_string = str(round(w-x1,3)) + " " + str(y1) + " " + str(round(w-x2,3)) + " " + str(y2) + " " + str(round(w-x3,3)) + " " + str(y3) + " " + str(

            round(w - x4, 3)) + " " + str(y4) + " " + category

        vert_string = str(x1) + " " + str(round(h-y1,3)) + " " + str(x2) + " " + str(round(h-y2,3)) + " " + str(x3) + " " + str(

            round(h - y3, 3)) + " " + str(x4) + " " + str(round(h-y4,3)) + " " + category

 

        save_horiz_txt.write(vert_string)

        save_vert_txt.write(horiz_string)

 

#平移图像

def shift_pic_bboxes(img, inputtxt,outputiamge,outputtxt):

    # ---------------------- 平移图像 ----------------------

    w = img.shape[1]

    h = img.shape[0]

    x_min = w  # 裁剪后的包含所有目标框的最小的框

    x_max = 0

    y_min = h

    y_max = 0

    f = open(inputtxt)

    for line in f.readlines():

        line = line.split(" ")

        x1 = float(line[0])

        y1 = float(line[1])

        x2 = float(line[2])

        y2 = float(line[3])

        x3 = float(line[4])

        y3 = float(line[5])

        x4 = float(line[6])

        y4 = float(line[7])

        category = str(line[8])

 

        x_min = min(x_min, x1,x2,x3,x4)

        y_min = min(y_min, y1,y2,y3,y4)

        x_max = max(x_max, x1,x2,x3,x4)

        y_max = max(y_max, y1,y2,y3,y4)

 

    d_to_left = x_min  # 包含所有目标框的最大左移动距离

    d_to_right = w - x_max  # 包含所有目标框的最大右移动距离

    d_to_top = y_min  # 包含所有目标框的最大上移动距离

    d_to_bottom = h - y_max  # 包含所有目标框的最大下移动距离

 

    x = random.uniform(-(d_to_left - 1) / 3, (d_to_right - 1) / 3)

    y = random.uniform(-(d_to_top - 1) / 3, (d_to_bottom - 1) / 3)

 

    (filepath, tempfilename) = os.path.split(inputtxt)

    (filename, extension) = os.path.splitext(tempfilename)

    if x>=0 and y>=0 :

        outputiamge = os.path.join(outputiamge + "/" + filename + "_shift_" + str(round(x,3))+"_"+str(round(y,3)) + ".jpg")

        outputtxt = os.path.join(outputtxt + "/" + filename + "_shift_" + str(round(x,3))+"_"+str(round(y,3)) + extension)

    elif x>=0 and y<0:

        outputiamge = os.path.join(

            outputiamge + "/" + filename + "_shift_" + str(round(x, 3)) + "__" + str(round(abs(y), 3)) + ".jpg")

        outputtxt = os.path.join(

            outputtxt + "/" + filename + "_shift_" + str(round(x, 3)) + "__" + str(round(abs(y), 3)) + extension)

    elif x<0 and="" y="">=0:

        outputiamge = os.path.join(

            outputiamge + "/" + filename + "_shift__" + str(round(abs(x), 3)) + "_" + str(round(y, 3)) + ".jpg")

        outputtxt = os.path.join(

            outputtxt + "/" + filename + "_shift__" + str(round(abs(x), 3)) + "_" + str(round(y, 3)) + extension)

    elif x<0 and y<0:

        outputiamge = os.path.join(

            outputiamge + "/" + filename + "_shift__" + str(round(abs(x), 3)) + "__" + str(round(abs(y), 3)) + ".jpg")

        outputtxt = os.path.join(

            outputtxt + "/" + filename + "_shift__" + str(round(abs(x), 3)) + "__" + str(round(abs(y), 3)) + extension)

 

    M = np.float32([[1, 0, x], [0, 1, y]])  # x为向左或右移动的像素值,正为向右负为向左; y为向上或者向下移动的像素值,正为向下负为向上

    shift_img = cv.warpAffine(img, M, (img.shape[1], img.shape[0]))

    cv.imwrite(outputiamge,shift_img)


    # ---------------------- 平移boundingbox ----------------------

    save_txt=open(outputtxt,"w")

    f = open(inputtxt)

    for line in f.readlines():

        line = line.split(" ")

        x1 = float(line[0])

        y1 = float(line[1])

        x2 = float(line[2])

        y2 = float(line[3])

        x3 = float(line[4])

        y3 = float(line[5])

        x4 = float(line[6])

        y4 = float(line[7])

        category = str(line[8])

        shift_str=str(round(x1+x,3))+" "+str(round(y1+y,3))+" "+str(round(x2+x,3))+" "+str(round(y2+y,3))\

                  +" "+str(round(x3+x,3))+" "+str(round(y3+y,3))+" "+str(round(x4+x,3))+" "+str(round(y4+y,3))+" "+category

        save_txt.write(shift_str)

 

#裁剪

def crop_img_bboxes(img, inputtxt,outputiamge,outputtxt):

    # ---------------------- 裁剪图像 ----------------------

    w = img.shape[1]

    h = img.shape[0]

    x_min = w  # 裁剪后的包含所有目标框的最小的框

    x_max = 0

    y_min = h

    y_max = 0

    f = open(inputtxt)

    for line in f.readlines():

        line = line.split(" ")

        x1 = float(line[0])

        y1 = float(line[1])

        x2 = float(line[2])

        y2 = float(line[3])

        x3 = float(line[4])

        y3 = float(line[5])

        x4 = float(line[6])

        y4 = float(line[7])

        category = str(line[8])

 

        x_min = min(x_min, x1, x2, x3, x4)

        y_min = min(y_min, y1, y2, y3, y4)

        x_max = max(x_max, x1, x2, x3, x4)

        y_max = max(y_max, y1, y2, y3, y4)

 

    d_to_left = x_min  # 包含所有目标框的最小框到左边的距离

    d_to_right = w - x_max  # 包含所有目标框的最小框到右边的距离

    d_to_top = y_min  # 包含所有目标框的最小框到顶端的距离

    d_to_bottom = h - y_max  # 包含所有目标框的最小框到底部的距离

    # 随机扩展这个最小框

    crop_x_min = int(x_min - random.uniform(0, d_to_left))

    crop_y_min = int(y_min - random.uniform(0, d_to_top))

    crop_x_max = int(x_max + random.uniform(0, d_to_right))

    crop_y_max = int(y_max + random.uniform(0, d_to_bottom))

    # 随机扩展这个最小框 , 防止别裁的太小

    # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left))

    # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top))

    # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right))

    # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom))

    # 确保不要越界

    crop_x_min = max(0, crop_x_min)

    crop_y_min = max(0, crop_y_min)

    crop_x_max = min(w, crop_x_max)

    crop_y_max = min(h, crop_y_max)

    crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]

 

    (filepath, tempfilename) = os.path.split(inputtxt)

    (filename, extension) = os.path.splitext(tempfilename)

    outputiamge = os.path.join(outputiamge + "/" + filename + "_crop_" + str(crop_x_min) + "_" +

                               str(crop_y_min) + "_"+str(crop_x_max) + "_" +

                               str(crop_y_max) +".jpg")

    outputtxt = os.path.join(outputtxt + "/" + filename + "_crop_" + str(crop_x_min) + "_" +

                               str(crop_y_min) + "_"+str(crop_x_max) + "_" +

                               str(crop_y_max)  + extension)

    cv.imwrite(outputiamge,crop_img)

    # ---------------------- 裁剪boundingbox ----------------------

    # 裁剪后的boundingbox坐标计算

    save_txt = open(outputtxt, "w")

    f = open(inputtxt)

    for line in f.readlines():

        line = line.split(" ")

        x1 = float(line[0])

        y1 = float(line[1])

        x2 = float(line[2])

        y2 = float(line[3])

        x3 = float(line[4])

        y3 = float(line[5])

        x4 = float(line[6])

        y4 = float(line[7])

        category = str(line[8])

        crop_str = str(round(x1-crop_x_min, 3)) + " " + str(round(y1-crop_y_min, 3)) + " " + str(round(x2-crop_x_min, 3)) + " " + str(\

            round(y2-crop_y_min, 3)) + " " + str(round(x3-crop_x_min, 3)) + " " + str(round(y3-crop_y_min, 3)) + " " + str(\

            round(x4-crop_x_min, 3)) + " " + str(round(y4-crop_y_min, 3)) + " " + category

        save_txt.write(crop_str)

 

if __name__=='__main__':

    inputiamge="./split_data_overlap512/agumnet_data/images"

    inputtxt = "./split_data_overlap512/agumnet_data/labels"

    outputiamge = "./split_data_overlap512/agumnet_data/images"

    outputtxt = "./split_data_overlap512/agumnet_data/labels"

    angle=[30,60,90,120,150,180]

    tempfilename=os.listdir(inputiamge)

    for file in tempfilename:

        (filename,extension)=os.path.splitext(file)

        input_image=os.path.join(inputiamge+"/"+file)

        input_txt=os.path.join(inputtxt+"/"+filename+".txt")

 

        img = cv.imread(input_image)

        #图像亮度变化

        #changeLight(img,input_txt,outputiamge,outputtxt)

        #加高斯噪声

        #gasuss_noise(img, input_txt, outputiamge, outputtxt, mean=0, var=0.001)

        #改变图像对比度

        #ContrastAlgorithm(img, input_txt, outputiamge, outputtxt)

        #图像旋转

        #rotate_img_bbox(img, input_txt, outputiamge, outputtxt, angle)

        #图像镜像

        filp_pic_bboxes(img, input_txt, outputiamge, outputtxt)

        #平移

        #shift_pic_bboxes(img, input_txt, outputiamge, outputtxt)

        #剪切

        #crop_img_bboxes(img, input_txt, outputiamge, outputtxt)

    print("###finished!!!")

```

 2.模型

RetinaNet 原始论文为发表于 2017 ICCV 的 Focal Loss for Dense Object Detection。one-stage 网络首次超越 two-stage 网络,拿下了 best student paper,仅管其在网络结构部分并没有颠覆性贡献。


图示 AI 生成的内容可能不正确。

 

2.1 backbone 部分

RetinaNet 网络详细结构如下所示,与 FPN 不同,FPN 会使用 C2,而 RetinaNet 则没有,因为 C2 生成的 P2 会占用更多的计算资源,所以作者直接从 C3 开始生产 P3。关于 backbone 部分和 FPN 部分基本类似,所以具体细节部分就不细讲了。第二个不同点在于 P6 这个地方,原论文是在 C5 的基础上生成的(最大池化下采样得到的),这里是根据 pytorch 官方的实现绘制的,是通过 3 $\times$ 3 的卷积层来实现下采样的。第三个不同是 FPN 是从 P2 到 P6,而 RetinaNet 是从 P3 到 P7。


一些文字和图片的手机截图 AI 生成的内容可能不正确。

 

上图也给出了 P3 到 P7 上使用的 scale 和 ratios。在 FPN 中每个特征层上使用了一个 scale 和三个 ratios。在 RetinaNet 中是三个 scale 和三个 ratios 共计 9 个 anchor。 注意,这里 scale 等于 32 对应的 anchor 的面积是 32 的平方的。所以在 RetinaNet 中最小的 scale 是 32,最大的则是接近 813。

 

2.2 预测器部分

由于 RetinaNet 是一个 one-stage 的网络,所以不用 ROI pooling,直接使用如下图所示的权重共享的基于卷积操作的预测器。预测器分为两个分支,分别预测每个 anchor 所属的类别,以及目标边界框回归参数。最后的 kA 中 k 是检测目标的类别个数,注意这里的 k 不包含背景类别,对于 PASCAL VOC 数据集的话就是 20。这里的 A 是预测特征层在每一个位置生成的 anchor 的个数,在这里就是 9。(现在基本都是这样的类别不可知 anchor 回归参数预测,也可以理解为每一类共享了同一个 anchor 回归参数预测器)


 

2.3 正负样本匹配

针对每一个 anchor 与事先标注好的 GT box 进行比对,如果 iou 大于 0.5 则是正样本,如果某个 anchor 与所有的 GT box 的 iou 值都小于 0.4,则是负样本。其余的进行舍弃。


图片包含 表格 AI 生成的内容可能不正确。

 

2.4 损失计算

本文一个核心的贡献点就是 focal loss。总损失依然分为两部分,一部分是分类损失,一部分是回归损失。Focal loss 比较独特的一个点就是正负样本都会来计算分类损失,然后仅对正样本进行回归损失的计算。回归损失在 SSD 以及 Faster R-CNN 中都有讲解。


图示 AI 生成的内容可能不正确。



2.5 Focal Loss

为了实现正负样本的比例均衡,不至于整个训练过程被负样本“淹没”,一般采取抽样的方法,将正负样本比例控制在1:3,从而在正负样本间保持合理的比例。因为 one-stage 只有一个阶段,产生的候选框相比 two-stage 要多太多。通常需要大约100K个位置(例如 SSD 的 8700+ 个位置),且这里面正样本几个十几个,少之又少。即使你抽样了,最后在训练过程中,还是会惊奇的发现,整个过程还是被大量容易区分的负样本,也就是背景所主导。Focal loss 则是一个动态缩放的交叉熵损失,一言以蔽之,通过一个动态缩放因子,可以动态降低训练过程中易区分样本的权重,从而将 loss 的重心快速聚焦在那些难区分的样本上 (注意:难以区分的样本不一定是正样本)。


图示 AI 生成的内容可能不正确。

 

2.6 Cross Entropy Loss

Focal loss的起源是二分类交叉熵 CE,它的形式是这样的:

文本, 白板 AI 生成的内容可能不正确。

在上式中,y的取值有 1 和 -1 两种,代表前景和背景。p的取值范围是 [0,1],是模型预测的属于前景的概率,为了表示方便,定义一个pt

文本, 信件 AI 生成的内容可能不正确。

综合(1)(2)两个式子就可以得到:

文本 AI 生成的内容可能不正确。

CE 曲线就是下图中的蓝色曲线,可以看到,相比较其他曲线,蓝色线条是变化最平缓的,即使在 p > 0.5  (已经属于很好区分的样本)的情况下,它的损失相对于其他曲线仍然是高的,尽管它相对于自己前面的已经下降很多了,但是当数量巨大的易区分样本损失相加,就会主导我们的训练过程,所以要进一步增加前后损失的大小比。


图示 AI 生成的内容可能不正确。

 

2.7 Balanced Cross Entropy

Balanced Cross Entropy 是常见的解决类不平衡的方法,其思想是引入一个权重因子α∈ [0,1],当类标签是1时,权重因子是α当类标签是-1时,权重因子是1-α。同样为了表示方便,用αt表示权重因子,那么此时的损失函数被改写为:

文本 AI 生成的内容可能不正确。

2.8 Focal Loss的计算

Balanced Cross Entropy 解决了正负样本的比例失衡问题(positive/negative examples),但是这种方法仅仅解决了正负样本之间的平衡问题,并没有区分简单还是难分样本(easy/hard examples)。当容易区分的负样本的泛滥时,整个训练过程都是围绕容易区分的样本进行(小损失积少成多超过大损失),而被忽略的难区分的样本才是训练的重点。作者新引入了一个调制因子,公式如下:

图片包含 图标 AI 生成的内容可能不正确。

其中 γ 也是一个参数,范围在 [0,5]。观察上式可以发现,当pt趋向于 1 时,说明该样本比较容易区分,整个调制因子(1-pt)^γ是趋向于 0 的,也就是 loss 的贡献值会很小;如果某样本被错分,pt很小,那么此时调制因子(1-pt)^γ是趋向 1 的,对 loss 没有大的影响(相对于基础的交叉熵),参数 γ 能够调整权重衰减的速率。从下面这张图可以看出,当γ = 0的时候,$FL $就是原来的交叉熵损失 CE,随着γ的增大,调整速率也在变化,实验表明,在γ= 2时,效果最佳。

图示 AI 生成的内容可能不正确。

 

结合正负样本平衡以及难易样本平衡,最终的 Focal loss 形式如下:

它的功能可以解释为:通过αt可以抑制正负样本的数量失衡,通过γ可以控制简单/难区分样本数量失衡。

对于 Focal loss,总结如下:

* 无论是前景类还是背景类,pt越大,权重 (1-pt)^γ就越小,即简单样本的损失可以通过权重进行抑制;

* αt 用于调节正负样本损失之间的比例,前景类别使用αt时,对应背景类别使用1-αt;

* γ和 αt的最优值是相互影响的,所以在评估准确度时需要把两者组合起来调节。作者在论文中给出γ= 2, αt =0.25 $时,ResNet-101+FPN 作为 backbone 的 RetinaNet 有最优的性能。这里αt =0.25 正样本的权重小,负样本的权重大有利于压低负样本的分类损失,尽可能将负样本的损失压低。

 

3.模型量化

模型量化(Model Quantization)就是通过某种方法将浮点模型转为定点模型。比如说原来的模型里面的权重(weight)都是float32,通过模型量化,将模型变成权重(weight)都是int的定点模型。

 

3.1 量化工具MOCA

这是一套基于mqbench的量化工具研发的,适用于光计算中的量化工具以及部署模型onnx的转换,这里的onnx模型中的算子是自定义的光计算硬件可执行的算子,适用于8/4/3/2bit的定点的量化工具。

 

 3.2 环境依赖

2.环境安装

1.主要是安装mmrotate repo和moca_cv(mqbench和omac),具体安装包的版本在如下:

 

Package                                     Version                  Editable project location

--------------------------------    ------------                 ------------------------------

accimage                         0.2.0

addict                           2.4.0

aliyun-python-sdk-core           2.15.2

aliyun-python-sdk-kms            2.16.5

appdirs                          1.4.4

backports-datetime-fromisoformat  2.0.0

backports.weakref                1.0.post1

certifi                          2024.7.4

cffi                             1.17.1

charset-normalizer               3.3.2

click                            8.1.7

cmake                            3.21.0

colorama                         0.4.6

coloredlogs                      15.0.1

coverage                         7.4.3

crcmod                           1.7

cryptography                     43.0.1

cycler                           0.12.1

Cython                           3.0.0

e2cnn                            0.2.3

easydict                         1.9

et-xmlfile                       1.1.0

exceptiongroup                   1.2.1

filelock                         3.14.0

flatbuffers                      24.3.25

graphviz                         0.20.3

humanfriendly                    10.0

idna                             3.7

imageio                          2.9.0

importlib_metadata               8.5.0

iniconfig                        2.0.0

Jinja2                           3.1.4

jmespath                         0.10.0

joblib                           1.2.0

jsonpickle                       3.0.1

kiwisolver                       1.4.5

lgt_license                      1.2.3

lightning-utilities              0.11.7

llvmlite                         0.31.0

lvis                             0.5.3

Mako                             1.3.2

Markdown                         3.7

markdown-it-py                   3.0.0

MarkupSafe                       2.1.3

matplotlib                       3.3.3

mdurl                            0.1.2

mkl-fft                          1.3.8

mkl-random                       1.2.4

mkl-service                      2.4.0

mmcv                             2.0.0rc4

mmcv-full                        1.7.2

mmdet       2.28.2     /local/miniconda/envs/moca_cv/lib/python3.8/site-packages

mmengine                         0.10.4

mmrotate                         0.3.4        /root/dota_workspace/mmrotate

model-index                      0.1.11

mpmath                           1.3.0

MQBench 0.0.6        /root/workspace_moca_shankun_new/moca/QuantForCV/mqbench

networkx                         3.1

ninja                            1.11.1.1

numba                            0.48.0

numpy                            1.22.4

onnx                             1.16.0

onnxruntime                      1.19.2

onnxsim                          0.4.36

opencv-python                    4.1.2.30

opendatalab                      0.0.10

openmim                          0.3.9

openpyxl                         3.0.9

openxlab                         0.1.1

ordered-set                      4.1.0

osimulator                       1.2.3

oss2                             2.17.0

packaging                        24.0

pandas                           0.25.3

Pillow                           6.2.1

pip                              23.3.1

platformdirs                     3.10.0

pluggy                           1.5.0

prettytable                      3.10.0

protobuf                         3.20.3

psutil                           5.9.0

pycocotools                      2.0.7

pycparser                        2.22

pycryptodome                     3.18.0

pycuda                           2024.1

Pygments                         2.18.0

pyparsing                        3.1.2

pytest                           7.4.3

pytest-cov                       4.1.0

pytest-html                      4.0.2

pytest-metadata                  3.1.1

pytest-runner                    6.0.1

python-dateutil                  2.9.0.post0

pytools                          2024.1.6

pytz                             2023.4

PyWavelets                       1.4.1

PyYAML                           6.0.1

requests                         2.28.2

rich                             13.4.2

scikit-image                     0.19.3

scikit-learn                     1.3.2

scipy                            1.6.3

setuptools                       60.2.0

shapely                          2.0.6

SharedArray                      3.2.1

six                              1.16.0

sympy                            1.13.2

tabulate                         0.9.0

tensorboardX                     1.8

termcolor                        2.4.0

terminaltables                   3.1.10

threadpoolctl                    3.5.0

tifffile                         2023.7.10

tomli                            2.0.1

torch                            1.10.1+cu113

torchaudio                       0.10.1+cu113

torchmetrics                     1.4.1

torchvision                      0.11.2+cu113

tqdm                             4.65.2

typing_extensions                4.11.0

urllib3                          1.26.20

wcwidth                          0.2.13

wheel                            0.43.0

yapf                             0.40.2

zipp                             3.20.2

注意:这里的mqbench是安装MOCA这套工具

 

3.3 执行程序

QAT量化执行程序,低bit训练量化

首先需要配置QAT,在train.py文件里面/root/dota_workspace/mmrotate/mmrotate/apis/train.py


文本 AI 生成的内容可能不正确。



 文本 AI 生成的内容可能不正确。 




文本 AI 生成的内容可能不正确。

 

QAT训练4bit模型,QAT训练脚本:

CUDA_VISIBLE_DEVICES=0 python -u tools/train.py configs/

rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90.py

--resume-from checkpoints

rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90-1da1ec9c.pth --work-dir logs/train_epoch_35/

 > train_log_epoch35.log 2>&1 &

 

配置文件:

rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90.py

 

预训练模型:

rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90-1da1ec9c.pth

 

训练的log保存:

logs/train_epoch_35/

 

3.4 训练结果

(1)对于1024x1024的输入:

 

图形用户界面 AI 生成的内容可能不正确。

 

(2)对于512x512的输入:


图片包含 图形用户界面 AI 生成的内容可能不正确。

 

4. 在光计算模拟器上的测试结果

推理脚本执行:

CUDA_VISIBLE_DEVICES=3 python -u tools/test_qat.py 

configs/rotated_retinanet/ rotated_retinanet_obb_r50_fpn_1x_dota_ms_rr_le90.py

checkpoints/epoch_15.pth --eval mAP > test_alllayers.log 2>&1 &

 

4.1 对于1024x1024的输入,采用compass模拟器


图形用户界面 AI 生成的内容可能不正确。

 

4.2 对于512x512的输入,采用compass模拟器


图形用户界面 AI 生成的内容可能不正确。

 

4.3 对于1024x1024的输入,采用pace2模拟器


图形用户界面 AI 生成的内容可能不正确。

 

4.4 对于512x512的输入,采用pace2模拟器


图形用户界面 AI 生成的内容可能不正确。


4.5 在光计算模拟器上的测试可视化


城市街道与高楼大厦的景色 AI 生成的内容可能不正确。


图片包含 盒子, 蛋糕, 桌子, 大 AI 生成的内容可能不正确。