colab怎么提高训练速度,特别对于第一步训练时间特别长 colab切换为1的版本

一、必看:

谷歌给我们提供了yunGPU,挺好用的,但是坑多。我voc2007+voc2012数据集,训练集总共20000多张图片,第一步训练就需要花费2个多小时,总共也才6个小时左右,所以提高训练速度很重要。训练速度慢的原因主要是,每一批次都需要去google drive上去读,所以说第一步特别慢。

注意:1、图片数量过少,可以直接复制到创建的文件下

%cp -av 源文件 目标文件夹

如果,数据量大,还是按照下面步骤来。先压缩,在上传,在复制,在解压。

二、解决思路是,

1、先在自己电脑上压缩成.zip文件,上传到google drive(谷歌云盘)
2、在把zip文件复制到工作目录,在继续训练,速度大大提高,第一步训练提高到了8分钟左右
一、创建文件,工作目录。
上传压缩数据集,这些大家都会,这里主要说一下第二步。读取数据慢主要是图片,所以这里只需要把图片弄成压缩文件就可以了,这里需要提一下,2个G的压缩文件上传到谷歌云盘大概10分钟,解压大概2分钟,速度已经很快了。

一、创建文件,工作目录。

!mkdir train_local


结果:

二、复制压缩文件到创建的根目录

%cp -av /content/drive/MyDrive/Voc/VOC2012/JPEGImages.rar /content/train_local

三、解压

!pip install pyunpack
!pip install patool
from pyunpack import Archive
Archive('/content/train_local/JPEGImages.rar').extractall('/content/train_local')

四、复制代码训练,训练可以解决。

三、 补充:由于图片路径更改,对于的voc_txt文件更改。

一、方法一:在进行xml文件转txt文件进行修改路径,修改对应的路径,需要跟换的图片路径:annotation

import os
import argparse
import xml.etree.ElementTree as ET

def convert_voc_annotation(data_path, data_type, anno_path, use_difficult_bbox=True):

    classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
               'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
               'train', 'tvmonitor']
    img_inds_file = os.path.join(data_path, 'ImageSets', 'Main1', data_type + '.txt')
    with open(img_inds_file, 'r') as f:
        txt = f.readlines()
        image_inds = [line.strip() for line in txt]

    with open(anno_path, 'a') as f:
        for image_ind in image_inds:
            image_path = os.path.join(data_path, 'JPEGImages', image_ind + '.jpg')
            annotation = os.path.join('/content/train_local/JPEGImages/'+image_ind + '.jpg')
            label_path = os.path.join(data_path, 'Annotations', image_ind + '.xml')
            root = ET.parse(label_path).getroot()
            objects = root.findall('object')
            for obj in objects:
                difficult = obj.find('difficult').text.strip()
                if (not use_difficult_bbox) and(int(difficult) == 1):
                    continue
                bbox = obj.find('bndbox')
                class_ind = classes.index(obj.find('name').text.lower().strip())
                xmin = int(round((float(bbox.find('xmin').text.strip())), 2))
                xmax = int(round((float(bbox.find('xmax').text.strip())), 2))
                ymin = int(round(float((bbox.find('ymin').text.strip())), 2))
                ymax = int(round(float((bbox.find('ymax').text.strip())), 2))
                annotation += ' ' + ','.join([str(xmin), str(ymin), str(xmax), str(ymax), str(class_ind)])
            print(annotation)
            f.write(annotation + "\n")
    return len(image_inds)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="D:/tensorflow-yolov3/VOCtest_06-Nov-2007/VOCdevkit/VOC2007")
    parser.add_argument("--train_annotation", default="D:/tensorflow-yolov3/VOCtrainval_11-May-2012 (1)/VOCdevkit/VOC2012/data/dataset/voc_train1.txt")
    parser.add_argument("--test_annotation",  default="D:/tensorflow-yolov3/VOCtrainval_11-May-2012 (1)/VOCdevkit/VOC2012/data/dataset/voc_test.txt")
    flags = parser.parse_args()

    if os.path.exists(flags.train_annotation):os.remove(flags.train_annotation)
    if os.path.exists(flags.test_annotation):os.remove(flags.test_annotation)

    num1 = convert_voc_annotation(os.path.join('D:/tensorflow-yolov3/VOCtrainval_11-May-2012 (1)/VOCdevkit/VOC2012/'), 'train', flags.train_annotation, True)



二、方法二、直接在对应的txt文件进行修改

import os
import argparse
import xml.etree.ElementTree as ET    

#保存.txt路径
save_file=open('/content/drive/MyDrive/voc_2007_crack/Voc2007_1/data/new_data_1.txt','w')  
#需要修改的TXT文件  
img_inds_file = os.path.join("/content/drive/MyDrive/voc_2007_crack/Voc2007_1/data/dataset/voc_train.txt")

with open(img_inds_file, 'r') as f:
  txt = f.readlines()
  image_inds = [line.strip() for line in txt]
  for image_ind in image_inds:
    annoth = '/content/drive/MyDrive/voc_2007_crack/Voc2007_1/JPEGImages/'
    annoth+=image_ind.split('\\')[-1]
    save_file.write(annoth+'\n')

三、conda 切换为tensorflow 1.x

% tensorflow_version 1.x
全部评论

相关推荐

只写bug的程序媛:才15,我招行20多万,建设银行50多万,说放弃就放弃
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务