Tensorflow 构建自己的目标检测与识别模型之数据增强(三)

Tensorflow 构建自己的目标检测与识别模型之数据增强(三)

上一篇博客介绍了如何对图像进行数据增强,见链接:https://blog.csdn.net/weixin_41644725/article/details/85678348
在本章中,主要将经过数据增强的图像保存下来,并将边界框信息存入csv文件,以便后面生成tfrecord时使用(后文会提到)。
例如,以下是未采用数据增强时所生成的csv文件:

以下是未采用数据增强时存放图像的文件夹:

这里采用上一篇博客(https://blog.csdn.net/weixin_41644725/article/details/85678348)中所提到的调整图像亮度、裁剪、cutout和旋转等方法。
代码如下:

def creat_image_DA(img_name, img, bboxs, csv_path, img_class):
    """Run four augmentations per bounding box and append their CSV rows.

    For every box in ``bboxs`` the helpers ``changeLight``, ``cutout``,
    ``rotate_img_bbox`` and ``crop_img_bboxes`` are called, in that order.
    Each is called with the source image and a single-element box list.
    Each result has its first and third channels swapped and is blurred
    with a 3x3 Gaussian kernel. It is then written to
    ``./images/<prefix><img_name>``. One CSV row per augmentation is
    appended to ``csv_path`` in the form
    ``filename,width,height,class,xmin,ymin,xmax,ymax``.

    Args:
        img_name: File name (no directory) of the source image; used to
            build the output file names and the CSV ``filename`` column.
        img: Source image array. The loader in this post converts it to RGB
            first, so the channel swap here restores the BGR order that
            ``cv2.imwrite`` expects.
        bboxs: Iterable of boxes, or ``None`` to do nothing.
        csv_path: CSV file the annotation rows are appended to.
        img_class: Class label written into every row.

    NOTE(review): the output file names do not depend on the box. With
    several boxes, only the images produced for the last box survive on
    disk, while the CSV receives rows for every box. Confirm this is
    intended.
    """
    if bboxs is None:
        return
    # Order matters if the helpers draw random numbers; it matches the
    # original sequence: brightness, cutout, rotation, crop.
    augmentations = (
        ('change_light_', changeLight),
        ('cut_out_', cutout),
        ('rotate_', rotate_img_bbox),
        ('crop_', crop_img_bboxes),
    )
    for bbox in bboxs:
        rows = [
            _save_augmented(prefix, augment, img_name, img, [bbox], img_class)
            for prefix, augment in augmentations
        ]
        # The context manager closes the handle even if write() raises.
        with open(csv_path, 'a+') as f:
            f.write(''.join(rows))


def _save_augmented(prefix, augment, img_name, img, list_box, img_class):
    """Apply one augmentation, save the image and return its CSV row."""
    out_img, x_min, y_min, x_max, y_max = augment(img=img, bboxes=list_box)
    # Size is taken before post-processing, which does not change it.
    height, width = out_img.shape[:2]
    b, g, r = cv2.split(out_img)
    out_img = cv2.merge([r, g, b])
    out_img = cv2.GaussianBlur(out_img, (3, 3), 0)
    cv2.imwrite('./images/' + prefix + img_name, out_img)
    fields = [prefix + img_name, width, height, img_class,
              x_min, y_min, x_max, y_max]
    return ','.join(str(v) for v in fields) + '\n'

加载图像数据集时使用如下代码:

def _is_augmented(filename, markers=('change_light', 'cut_out', 'rotate', 'crop', 'shift')):
    """Return True if *filename* contains any augmentation marker."""
    return any(marker in filename for marker in markers)


def load_train(train_path, csv_path):
    """Augment every original (non-augmented) image matched by *train_path*.

    Args:
        train_path: glob pattern for the images, e.g. ``'images/*g'``.
        csv_path: annotation CSV. Each image's boxes are looked up from it
            via ``get_bbox``, and ``creat_image_DA`` appends the rows for
            the augmented copies to it.

    A file is treated as an original only if its name contains none of the
    augmentation markers. This stops a re-run from augmenting previously
    augmented images again.
    """
    import glob
    import os

    print('Going to read training images')
    for fl in glob.glob(train_path):
        # Use the bare file name instead of slicing off a fixed-length
        # directory prefix, so any train_path directory works.
        img_name = os.path.basename(fl)
        if _is_augmented(img_name):
            continue
        img = cv2.imread(fl)
        if img is None:
            # cv2.imread returns None for unreadable / non-image files.
            print('Skipping unreadable file: ' + fl)
            continue
        b, g, r = cv2.split(img)
        img = cv2.merge([r, g, b])
        img = cv2.GaussianBlur(img, (3, 3), 0)
        coords, img_class = get_bbox(img_name, csv_path)
        coords = [coord[:4] for coord in coords]
        creat_image_DA(img_name, img, coords, csv_path, img_class)
def main():
    """Augment the original images under ``images/`` and extend the CSV.

    Paths are relative to the working directory. Images are matched with
    the glob ``images/*g``, and annotations are read from and appended to
    ``./csv/class.csv``.
    """
    # Consistent 4-space indentation (the body previously mixed a tab with
    # spaces, which Python rejects with IndentationError).
    csv_path = './csv/class.csv'
    train_path = 'images/*g'
    load_train(train_path, csv_path)


main()

结果如图所示:

将图像数据集分为训练集和验证集,代码如下:

def split_train_vaild(csv_path, train_ratio=0.8, seed=None,
                      train_csv='data_set/all_train.csv',
                      valid_csv='data_set/all_vaild.csv'):
    """Split an annotation CSV into train / validation CSVs, image by image.

    Rows are grouped by ``filename``, so all boxes of one image land in the
    same split. ``int(n_images * train_ratio)`` randomly chosen images go to
    ``train_csv``; the remaining images go to ``valid_csv``.

    Args:
        csv_path: Input CSV; must contain a ``filename`` column.
        train_ratio: Fraction of images used for training, in [0, 1].
        seed: Optional seed for a private ``numpy.random.RandomState`` so the
            split is reproducible. ``None`` uses the global NumPy RNG.
        train_csv: Output path for the training rows.
        valid_csv: Output path for the validation rows.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].

    Missing parent directories of the output paths are created. A split
    with no images is written as a header-only CSV instead of raising
    (``pd.concat`` refuses an empty list).
    """
    import os

    import numpy as np
    import pandas as pd

    if not 0 <= train_ratio <= 1:
        raise ValueError('train_ratio must be in [0, 1], got {}'.format(train_ratio))

    full_labels = pd.read_csv(csv_path)
    gb = full_labels.groupby('filename')
    grouped_list = [gb.get_group(x) for x in gb.groups]
    len_imge = len(grouped_list)
    rng = np.random if seed is None else np.random.RandomState(seed)
    train_index = rng.choice(len_imge, size=int(len_imge * train_ratio), replace=False)
    test_index = np.setdiff1d(np.arange(len_imge), train_index)
    train = _concat_groups(grouped_list, train_index, full_labels)
    test = _concat_groups(grouped_list, test_index, full_labels)
    print(len(train_index), len(test_index))
    for frame, out_path in ((train, train_csv), (test, valid_csv)):
        out_dir = os.path.dirname(out_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        frame.to_csv(out_path, index=False)


def _concat_groups(groups, indices, template):
    """Concatenate the selected per-image frames; empty frame if none."""
    import pandas as pd

    parts = [groups[i] for i in indices]
    # iloc[0:0] keeps the columns, so an empty split still gets a header.
    return pd.concat(parts) if parts else template.iloc[0:0]
csv_path = 'csv/class.csv'  # annotation CSV written by the augmentation step
split_train_vaild(csv_path=csv_path)

然后生成tfrecord格式,代码如下:

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
import pandas as pd
import tensorflow as tf
from PIL import Image
#from object_detection.utils import dataset_util
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
# Register command-line flags for the CSV input and TFRecord output paths.
# NOTE(review): in the code shown here, nothing reads FLAGS.csv_input or
# FLAGS.output_path. The __main__ block passes hard-coded paths to main()
# instead, so these definitions only register the flags.
flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS

 def class_text_to_int(row_label):
    if row_label == 'class1':
        return 1
    elif row_label == class2':
        return 2
    elif row_label == 'class3':
        return 3
    elif row_label == 'class4':
        return 4
    elif row_label == 'class5':
        return 5
    elif row_label == 'class6':
        return 6
    else:
        print('NONE: ' + row_label)
        # None
def split(df, group):
    """Group *df* by column *group* into ``data(filename, object)`` records.

    Args:
        df: Annotation DataFrame, one row per box.
        group: Column to group by; ``main`` passes ``'filename'``.

    Returns:
        A list with one ``data`` namedtuple per distinct key. ``filename``
        is the key and ``object`` is the sub-DataFrame of its rows.
    """
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    # Iterating gb.groups yields each key once; no need to zip the keys
    # with themselves.
    return [data(key, gb.get_group(key)) for key in gb.groups]
def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: ``data(filename, object)`` record from ``split``. ``object``
            holds the CSV rows (``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``class``) for ``filename``.
        path: Directory that ``group.filename`` is resolved against.

    Returns:
        A ``tf.train.Example`` holding the encoded image bytes, its size,
        and the box coordinates divided by the image width/height.
    """
    img_path = os.path.join(path, '{}'.format(group.filename))
    print(img_path)
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = Image.open(io.BytesIO(encoded_jpg))
    width, height = image.size
    # The CSV filename already includes its extension (the file is opened
    # under exactly that name above), so store it unchanged.
    filename = '{}'.format(group.filename).encode('utf8')
    # Record the container format PIL detected (e.g. 'jpeg', 'png') rather
    # than a hard-coded value.
    image_format = (image.format or 'jpeg').lower().encode('utf8')
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for _, row in group.object.iterrows():
        # float() keeps sub-pixel coordinates; int() would truncate them.
        xmins.append(float(row['xmin']) / width)
        xmaxs.append(float(row['xmax']) / width)
        ymins.append(float(row['ymin']) / height)
        ymaxs.append(float(row['ymax']) / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(csv_input, output_path, imgPath):
    """Write every image listed in *csv_input* into a TFRecord file.

    Args:
        csv_input: CSV with one row per box; rows are grouped by ``filename``.
        output_path: Destination ``.record`` file.
        imgPath: Directory the CSV filenames are resolved against.
    """
    examples = pd.read_csv(csv_input)
    grouped = split(examples, 'filename')
    # The context manager closes the writer even if building an example
    # raises part-way through.
    with tf.python_io.TFRecordWriter(output_path) as writer:
        for group in grouped:
            tf_example = create_tf_example(group, imgPath)
            writer.write(tf_example.SerializeToString())
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    imgPath = 'images/all_images'
    # (CSV input, TFRecord output) pairs: training set first, then the
    # validation set. The "vaild" spelling matches the CSV names written by
    # split_train_vaild.
    for csv_input, output_path in (
            ('data_set/all_train.csv', 'data_set/all_train.record'),
            ('data_set/all_vaild.csv', 'data_set/all_vaild.record'),
    ):
        main(csv_input, output_path, imgPath)

此处需要注意下面这部分代码:数据集中有几个类别,就设置几个对应的分支(类别名到编号的映射)。

def class_text_to_int(row_label, label_map=None):
    """Map a class name from the CSV ``class`` column to its integer id.

    Ids start at 1. Keep this mapping in sync with the classes in the
    dataset (one entry per class). Presumably it must also match the label
    map used for training — TODO confirm.

    Args:
        row_label: Class name as stored in the CSV.
        label_map: Optional ``{name: id}`` dict that replaces the default
            ``class1``..``class6`` -> 1..6 mapping.

    Returns:
        The integer id, or ``None`` (after printing ``NONE: <label>``) when
        the label is not in the mapping.
    """
    if label_map is None:
        label_map = {'class1': 1, 'class2': 2, 'class3': 3,
                     'class4': 4, 'class5': 5, 'class6': 6}
    label_id = label_map.get(row_label)
    if label_id is None:
        # format() also works for non-str labels (e.g. NaN from pandas).
        print('NONE: {}'.format(row_label))
    return label_id
全部评论

相关推荐

贺兰星辰:不要漏个人信息,除了简历模板不太好以外你这个个人简介是不是太夸大了...
点赞 评论 收藏
分享
Yushuu:你的确很厉害,但是有一个小问题:谁问你了?我的意思是,谁在意?我告诉你,根本没人问你,在我们之中0人问了你,我把所有问你的人都请来 party 了,到场人数是0个人,誰问你了?WHO ASKED?谁问汝矣?誰があなたに聞きましたか?누가 물어봤어?我爬上了珠穆朗玛峰也没找到谁问你了,我刚刚潜入了世界上最大的射电望远镜也没开到那个问你的人的盒,在找到谁问你之前我连癌症的解药都发明了出来,我开了最大距离渲染也没找到谁问你了我活在这个被辐射蹂躏了多年的破碎世界的坟墓里目睹全球核战争把人类文明毁灭也没见到谁问你了😆
点赞 评论 收藏
分享
评论
点赞
收藏
分享
牛客网
牛客企业服务