Tensorflow 构建自己的目标检测与识别模型之数据增强(二)

Tensorflow 构建自己的目标检测与识别模型之数据增强(二)

上次的博客中对如何安装Tensorflow Object Detection API的步骤及所遇到的问题进行说明。见链接:https://blog.csdn.net/weixin_41644725/article/details/83007901
接下来,对图像数据进行图像增强。虽然在配置.config文件(后面会说到)时,其中会提到数据增强(data argumentation),但是若是想手动实现,可参考本文,若不想则跳过即可。

1.用labelImage工具生成.xml文件。

该工具的界面如图所示,关于如何安装labelImage,可参考网上的相关博客,在windows和Linux下都有相应的安装过程,此处不叙述安装过程。其中“Open Dir”为打开存储所有图像文件的文件夹。“Change Save Dir”为将生成的.xml文件存储在指定文件夹下面。“Save”表示保存当前的.xml文件。

xml文件的格式如下图所示:

2. xml 转成csv文件

(1)将xml文件转成csv文件代码如下:
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET

def xml_to_csv(path):
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text)
                     )
            xml_list.append(value)
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df

def main():
    xml_path = './xml'           #存储xml的文件夹
    xml_df = xml_to_csv(xml_path)
    xml_df.to_csv('./csv/class.csv', index=None)   #生成csv文件并存储在该路径下
    print('Successfully converted xml to csv.')
    
main()
(2)得到该图像中对应类的边界框(bounding box),代码如下:
import os
import cv2
import pandas as pd
import matplotlib.pyplot as plt
def get_bbox(image_name,csv_path):
    full_labels = pd.read_csv(csv_path)
    selected_value = full_labels[full_labels.filename == image_name]
    images_bbox = []
    img_class = ''
    for index,row in selected_value.iterrows():
        list_bbox = []
        list_bbox.append(row['xmin'])
        list_bbox.append(row['ymin'])
        list_bbox.append(row['xmax'])
        list_bbox.append(row['ymax'])
        list_bbox.append(image_name)
        img_class = row['class']
        images_bbox.append(list_bbox)
    return images_bbox,img_class
    
 img_path = '023.jpg'
 csv_path = ''./csv/class.csv''
 img = cv2.imread(img_path)
 b, g, r = cv2.split(img)
 img = cv2.merge([r, g, b])
 image = cv2.GaussianBlur(img, (3, 3), 0)
 coords = get_bbox(img_path)
 coords = [coord[:4] for coord in coords]
 for i in range(len(coords)):
     bbox = coords[i]
     x_min = bbox[0]
     y_min = bbox[1]
     x_max = bbox[2]
     y_max = bbox[3]
     cv2.rectangle(image, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
 plt.subplot(111), plt.imshow(image), plt.title('original', fontsize='medium')
 plt.show()

输出结果如下:

3.图像数据增强

(1)调整图像亮度

代码如下:

 import os
 import cv2
 import pandas as pd
 import matplotlib.pyplot as plt
    '''调整亮度'''
 def changeLight(img,bboxes):
        flag = random.uniform(1.5, 2)  # flag>1为调暗,小于1为调亮
        img = exposure.adjust_gamma(img, flag)
        cv2.imwrite('./1.jpg', img)
        img = cv2.imread('./1.jpg')
        os.remove('./1.jpg')
        for i in range(len(bboxes)):
            bbox = bboxes[i]
            x_min = bbox[0]
            y_min = bbox[1]
            x_max = bbox[2]
            y_max = bbox[3]
            cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
        return img
    img_path = '023.jpg'
    img = cv2.imread(img_path)
    b, g, r = cv2.split(img)
    img = cv2.merge([r, g, b])
    img = cv2.GaussianBlur(img, (3, 3), 0)
    image = cv2.GaussianBlur(img, (3, 3), 0)
    coords = get_bbox(img_path)
    coords = [coord[:4] for coord in coords]
    for i in range(len(coords)):
        bbox = coords[i]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[2]
        y_max = bbox[3]
        cv2.rectangle(image, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
    '''调整亮度'''
    change_light_img = changeLight(img=img, bboxes=coords)
    plt.subplot(121), plt.imshow(image), plt.title('original', fontsize='medium')
    plt.subplot(122), plt.imshow(change_light_img), plt.title('change light', fontsize='medium')
    plt.show()

输出结果如下:

(2)cutout

代码如下:

    '''cutout'''
    def cutout(img, bboxes, length=100, n_holes=1, threshold=0.5):
        '''
        原版本:https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
        Randomly mask out one or more patches from an image.
        Args:
            img : a 3D numpy array,(h,w,c)
            bboxes : 框的坐标
            n_holes (int): Number of patches to cut out of each image.
            length (int): The length (in pixels) of each square patch.
        '''
        def cal_iou(boxA, boxB):
            '''
            boxA, boxB为两个框,返回iou
            boxB为bouding box
            '''
            # determine the (x, y)-coordinates of the intersection rectangle
            xA = max(boxA[0], boxB[0])
            yA = max(boxA[1], boxB[1])
            xB = min(boxA[2], boxB[2])
            yB = min(boxA[3], boxB[3])
            if xB <= xA or yB <= yA:
                return 0.0
            # compute the area of intersection rectangle
            interArea = (xB - xA + 1) * (yB - yA + 1)
            # compute the area of both the prediction and ground-truth
            # rectangles
            boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
            boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
            # compute the intersection over union by taking the intersection
            # area and dividing it by the sum of prediction + ground-truth
            # areas - the interesection area
            iou = interArea / float(boxAArea + boxBArea - interArea)
            #iou = interArea / float(boxBArea)
             # return the intersection over union value
            return iou
        # 得到h和w
        if img.ndim == 3:
            h, w, c = img.shape
        else:
            _, h, w, c = img.shape
        mask = np.ones((h, w, c), np.float32)
        for n in range(n_holes):
            chongdie = True  # 看切割的区域是否与box重叠太多
            while chongdie:
                y = np.random.randint(h)
                x = np.random.randint(w)
                y1 = np.clip(y - length // 2, 0,
                            h)  # numpy.clip(a, a_min, a_max, out=None), clip这个函数将将数组中的元素限制在a_min, a_max之间,大于a_max的就使得它等于 a_max,小于a_min,的就使得它等于a_min
                y2 = np.clip(y + length // 2, 0, h)
                x1 = np.clip(x - length // 2, 0, w)
                x2 = np.clip(x + length // 2, 0, w)
                chongdie = False
                for box in bboxes:
                    if cal_iou([x1, y1, x2, y2], box) > threshold:
                        chongdie = True
                        break
            mask[y1: y2, x1: x2, :] = 0.
        # mask = np.expand_dims(mask, axis=0)
        img = img * mask
        for i in range(len(bboxes)):
            bbox = bboxes[i]
            x_min = bbox[0]
            y_min = bbox[1]
            x_max = bbox[2]
            y_max = bbox[3]
            cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
        cv2.imwrite('./1.jpg', img)
        img = cv2.imread('./1.jpg')
        os.remove('./1.jpg')
        return img
 img_path = '023.jpg'
 img = cv2.imread(img_path)
 b, g, r = cv2.split(img)
 img = cv2.merge([r, g, b])
 img = cv2.GaussianBlur(img, (3, 3), 0)
 image = cv2.GaussianBlur(img, (3, 3), 0)
 coords = get_bbox(img_path)
 coords = [coord[:4] for coord in coords]
 for i in range(len(coords)):
     bbox = coords[i]
     x_min = bbox[0]
     y_min = bbox[1]
     x_max = bbox[2]
     y_max = bbox[3]
     cv2.rectangle(image, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
 '''调整亮度'''
 cut_out_img = cutout(img=img, bboxes=coords)
 plt.subplot(121), plt.imshow(image), plt.title('original', fontsize='medium')
 plt.subplot(122), plt.imshow(cut_out_img), plt.title('cutout', fontsize='medium')
 plt.show()

输出结果如下:

(3)旋转

代码如下:

'''旋转'''
def rotate_img_bbox(img, bboxes, angle=5, scale=1.):
    '''
    参考:https://blog.csdn.net/u014540717/article/details/53301195crop_rate
    输入:
        img:图像array,(h,w,c)
        bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
        angle:旋转角度
        scale:默认1
    输出:
        rot_img:旋转后的图像array
        rot_bboxes:旋转后的boundingbox坐标list
    '''
    # ---------------------- 旋转图像 ----------------------
    w = img.shape[1]
    h = img.shape[0]
    # 角度变弧度
    rangle = np.deg2rad(angle)  # angle in radians
    # now calculate new image width and height
    nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
    nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
    # ask OpenCV for the rotation matrix
    rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
    # calculate the move from the old center to the new center combined
    # with the rotation
    rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
    # the move only affects the translation, so update the translation
    # part of the transform
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]
    # 仿射变换
    rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
    # ---------------------- 矫正bbox坐标 ----------------------
    # rot_mat是最终的旋转矩阵
    # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下
    rot_bboxes = list()
    for bbox in bboxes:
        xmin = bbox[0]
        ymin = bbox[1]
        xmax = bbox[2]
        ymax = bbox[3]
        point1 = np.dot(rot_mat, np.array([(xmin + xmax) / 2, ymin, 1]))
        point2 = np.dot(rot_mat, np.array([xmax, (ymin + ymax) / 2, 1]))
        point3 = np.dot(rot_mat, np.array([(xmin + xmax) / 2, ymax, 1]))
        point4 = np.dot(rot_mat, np.array([xmin, (ymin + ymax) / 2, 1]))
        # 合并np.array
        concat = np.vstack((point1, point2, point3, point4))
        # 改变array类型
        concat = concat.astype(np.int32)
        # 得到旋转后的坐标
        rx, ry, rw, rh = cv2.boundingRect(concat)
        rx_min = rx
        ry_min = ry
        rx_max = rx + rw
        ry_max = ry + rh
        # 加入list中
        rot_bboxes.append([rx_min, ry_min, rx_max, ry_max])
    for i in range(len(rot_bboxes)):
        bbox = rot_bboxes[i]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[2]
        y_max = bbox[3]
        cv2.rectangle(rot_img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
    cv2.imwrite('./1.jpg', rot_img)
    rot_img = cv2.imread('./1.jpg')
    os.remove('./1.jpg')
    return rot_img
 img_path = '023.jpg'
 img = cv2.imread(img_path)
 b, g, r = cv2.split(img)
 img = cv2.merge([r, g, b])
 img = cv2.GaussianBlur(img, (3, 3), 0)
 image = cv2.GaussianBlur(img, (3, 3), 0)
 coords = get_bbox(img_path)
 coords = [coord[:4] for coord in coords]
 for i in range(len(coords)):
     bbox = coords[i]
     x_min = bbox[0]
     y_min = bbox[1]
     x_max = bbox[2]
     y_max = bbox[3]
     cv2.rectangle(image, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
 '''调整亮度'''
rotate_img = rotate_img_bbox(img=img, bboxes=coords)
plt.subplot(121), plt.imshow(image), plt.title('original', fontsize='medium')
plt.subplot(122), plt.imshow(rotate_img), plt.title('rotate', fontsize='medium')
plt.show()

输出结果如下:

(4)裁剪

代码如下:

'''裁剪'''
def crop_img_bboxes(img, bboxes):
    '''
    裁剪后的图片要包含所有的框
    输入:
        img:图像array
        bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
    输出:
        crop_img:裁剪后的图像array
        crop_bboxes:裁剪后的bounding box的坐标list
    '''
    # ---------------------- 裁剪图像 ----------------------
    w = img.shape[1]
    h = img.shape[0]
    x_min = w  # 裁剪后的包含所有目标框的最小的框
    x_max = 0
    y_min = h
    y_max = 0
    for bbox in bboxes:
        x_min = min(x_min, bbox[0])
        y_min = min(y_min, bbox[1])
        x_max = max(x_max, bbox[2])
        y_max = max(y_max, bbox[3])
    d_to_left = x_min  # 包含所有目标框的最小框到左边的距离
    d_to_right = w - x_max  # 包含所有目标框的最小框到右边的距离
    d_to_top = y_min  # 包含所有目标框的最小框到顶端的距离
    d_to_bottom = h - y_max  # 包含所有目标框的最小框到底部的距离
    # 随机扩展这个最小框
    crop_x_min = int(x_min - random.uniform(0, d_to_left))
    crop_y_min = int(y_min - random.uniform(0, d_to_top))
    crop_x_max = int(x_max + random.uniform(0, d_to_right))
    crop_y_max = int(y_max + random.uniform(0, d_to_bottom))
    # 随机扩展这个最小框 , 防止别裁的太小
    # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left))
    # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top))
    # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right))
    # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom))
    # 确保不要越界
    crop_x_min = max(0, crop_x_min)
    crop_y_min = max(0, crop_y_min)
    crop_x_max = min(w, crop_x_max)
    crop_y_max = min(h, crop_y_max)
    crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]
    # ---------------------- 裁剪boundingbox ----------------------
    # 裁剪后的boundingbox坐标计算
    crop_bboxes = list()
    for bbox in bboxes:
        crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min, bbox[2] - crop_x_min, bbox[3] - crop_y_min])
    for i in range(len(crop_bboxes)):
        bbox = crop_bboxes[i]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[2]
        y_max = bbox[3]
        cv2.rectangle(crop_img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
    cv2.imwrite('./1.jpg', crop_img)
    crop_img = cv2.imread('./1.jpg')
    os.remove('./1.jpg')
    return crop_img
 img_path = '023.jpg'
 img = cv2.imread(img_path)
 b, g, r = cv2.split(img)
 img = cv2.merge([r, g, b])
 img = cv2.GaussianBlur(img, (3, 3), 0)
 image = cv2.GaussianBlur(img, (3, 3), 0)
 coords = get_bbox(img_path)
 coords = [coord[:4] for coord in coords]
 for i in range(len(coords)):
     bbox = coords[i]
     x_min = bbox[0]
     y_min = bbox[1]
     x_max = bbox[2]
     y_max = bbox[3]
     cv2.rectangle(image, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
 '''调整亮度'''
crop_img = crop_img_bboxes(img=img, bboxes=coords)
plt.subplot(121), plt.imshow(image), plt.title('original', fontsize='medium')
plt.subplot(122), plt.imshow(crop_img), plt.title('crop', fontsize='medium')
plt.show()

输出结果如下:

全部评论

相关推荐

10-04 17:25
门头沟学院 Java
snqing:Java已经饱和了,根本不缺人。随便一个2000工资的都200人起投递
点赞 评论 收藏
分享
10-27 17:26
东北大学 Java
点赞 评论 收藏
分享
oppo 应用软开 22*15+0.5*12
拿到了ssp完美:真的坎坷,但是你至少拿到这么多offer了!
点赞 评论 收藏
分享
评论
点赞
收藏
分享
牛客网
牛客企业服务