TF_Record格式数据制作、读取 (基于猫狗大战、cifar10数据)

1、制作TF_Record数据集

  试用Image.open()打开图片会占用很大的内存,在这里我试用的是tf.gfile.Gfile,所以建议大家试用tf.gfile.Gfile(path,‘rb’)打开图片;

def create_tf_example(img_list,label,sess):
    #image = Image.open(img_list)   # 使用PIL skimage cv 读取图片 占用内存较大
    #image = image.resize((300,300))
    #image = image.tobytes()

    with tf.gfile.FastGFile(img_list,'rb') as fid:
        img = fid.read()

    ##  数据预处理   但是速度较慢   还是提前转换好尺寸大小较好
    #img = tf.image.decode_png(img, channels=3)  # 这里,也可以解码为 1 通道
    #img = tf.image.resize_image_with_crop_or_pad(img,40,40)   # 预处理速度很慢  但是效果较好   补充黑边 或 中心裁剪
    #img = tf.image.resize_images(img,[40,40])   # 速度较快   但还是很慢
    #image = sess.run(img)
    #image_bytes = image.tobytes()   # 将张量转换为 bytes   注意这种格式的解码方式不一样

    example = tf.train.Example(features=tf.train.Features(
        feature={
            'label': _int64_feature(label),
            'img_raw': _bytes_feature(img),
            #'width':_int64_feature(width),
            #'height':_int64_feature(height)
        }))

    return example    # 返回一个可写入的example

2、读取Cat_vs_Dogs数据,并生成record:

filepath = '/Users/***/Git_Mac/Cat_Vs_Dog/train/'     #cat_dog 根目录  下面
out_dir = 'Record/cat_dogs_2.record' 
def Creat_Cats_Vs_Dogs(file_dir,out_dir):
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir):   #读取所有图片的路径
        name = file.split(sep='.')
        if name[0]=='cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('There are %d cats\nThere are %d dogs' %(len(cats), len(dogs)))

    image_list = np.hstack((cats, dogs))    #组合数据
    label_list = np.hstack((label_cats, label_dogs))

    temp = np.array([image_list, label_list])
    temp = temp.transpose()   # 转置
    np.random.shuffle(temp)   # 打乱数据

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    sess = tf.Session()
    write = tf.python_io.TFRecordWriter(out_dir)
    count =0
    for img,lbe in zip(image_list,label_list):

        example = create_tf_example(img,lbe,sess)   #每个图片生成一个example  并写入
        write.write(example.SerializeToString())

        count += 1
        if(count % 1000 == 0):
            print(count)

3、读取Cifar10数据,并生成record:

path = 'Git_Mac/cifar10'   # data 根目录
def Create_Cifar10_Record(path,out_dir):
    write = tf.python_io.TFRecordWriter(out_dir)

    label_list = [0,1,2,3,4,5,6,7,8,9]  # 标签
    sess = tf.Session()

    for index,directory in zip(label_list,['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']):
        img_list = glob.glob(os.path.join(path,'{}/*.png'.format(directory)))  #读取所有的图片路径
        count = 0

        for img in img_list:
            example = create_tf_example(img,index,sess=sess)
            write.write(example.SerializeToString())
            count+=1
            if(count %100 ==0):   # 查看标签  和  进度
                print(count)

    sess.close()
    write.close()

4、从TF_Record格式中读取数据

   从record格式中读取数据并解码  

def read_and_decode(tfrecords_file, batch_size, shuffle,n_class,one_hot = False):
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  

    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            #'width': tf.FixedLenFeature([], tf.int64),
            #'height': tf.FixedLenFeature([], tf.int64)
        })

    #image = tf.decode_raw(img_features['img_raw'], tf.uint8)
    #width = tf.cast(img_features['width'], tf.int32)
    #height = tf.cast(img_features['height'], tf.int32)

    img = img_features['img_raw']
    img = tf.image.decode_png(img, channels=3)  # 解码图片  png格式   jpg 使用 decode_jpeg()
    image = tf.reshape(img, [32, 32, 3])   # 32*32*3   这个需要根据你自己的格式进行修改
    label = tf.cast(img_features['label'], tf.int32) 
    #image = tf.reshape(image, [300,300,3])
    image = tf.image.per_image_standardization(image)  # 标准化处理

    if shuffle:         # 是否打乱数据顺序  如果capacity设置过小 会导致数据混合不完全 打乱数据读取会占用很多内存
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size = batch_size,
            num_threads= 64,
            capacity = 20000,
            min_after_dequeue = 1000)
    else:
        image_batch, label_batch= tf.train.batch(
            [image,label],
            batch_size = batch_size,
            num_threads = 64,
            capacity= 2000)

    image_batch = tf.cast(image_batch, tf.float32)   # 转换为tf.float32 格式

    if(one_hot == True):    # 生成one_hot格式标签  one_hot格式标签  对应不同的loss 设置方式
        label_batch = tf.one_hot(label_batch, depth= n_class)
        label_batch = tf.cast(label_batch, dtype=tf.int32)
        label_batch = tf.reshape(label_batch, [batch_size, n_class])

    return image_batch, label_batch

 线程读取数据

def Read_Record(filepath):
    with tf.Session() as sess: #开始一个会话
        image,label = read_and_decode(filepath,batch_size=batch_size,shuffle=True,n_class=2,one_hot=False)

        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        coord=tf.train.Coordinator()    # 很重要的
        threads= tf.train.start_queue_runners(coord=coord)
        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break  img,lbe = sess.run([image,label])
                # 添加你自己的模型 teain 
                #plot_images(img,lbe,batch_size=batch_size)

        except tf.errors.OutOfRangeError as e:
            print(e)
        finally:
            coord.request_stop()

        coord.join(threads)

全部评论

相关推荐

会飞的猿:我看你想进大厂,我给你总结一下学习路线吧,java语言方面常规八股要熟,那些java的集合,重点背hashmap八股吧,jvm类加载机制,运行时分区,垃圾回收算法,垃圾回收器CMS、G1这些,各种乐观锁悲观锁,线程安全,threadlocal这些。在进阶一些的比如jvm参数,内存溢出泄漏排查,jvm调优。我这里说的只是冰山一角,详细八股可以去网上找,这不用去买,都免费资源。mysql、redis可以去看小林coding,我看你简历上写了,你一定要熟,什么底层b+树、索引结构、innodb、mvcc、undo log、redo log、行级锁表级锁,这些东西高频出现,如果面试官问我这些我都能笑出来。消息队列rabbitmq也好kafka也好,学一种就行,什么分区啊副本啊确认机制啊怎么保证不重复消费、怎么保证消息不丢失这些基本的一定要会,进阶一点的比如LEO、高水位线、kafka和rocketmq底层零拷贝的区别等等。计算机网络和操作系统既然你是科班应该理解起来问题不大,去看小林coding这两块吧,深度够了。spring boot的八股好好看看吧,一般字节腾讯不这么问,其他的java大厂挺爱问的,什么循环依赖啥的去网上看看。数据结构的话科班应该问题不大,多去力扣集中突击刷题吧。项目的话其实说白了还是结合八股来,想一想你写的这些技术会给你挖什么坑。除此之外,还有场景题、rpc、设计模式、linux命令、ddd等。不会的就别往简历上写了,虽然技术栈很多的话好看些,但背起来确实累。总结一下,多去实习吧,多跳槽,直到跳到一个不错的中厂做跳板,这是一条可行的进大厂的路线。另外,只想找个小厂的工作的话,没必要全都照这些准备,太累了,重点放在框架的使用和一些基础八股吧。大致路线就这样,没啥太多难度,就是量大,你能达到什么高度取决于你对自己多狠,祝好。
点赞 评论 收藏
分享
醒工硬件:1学校那里把xxxxx学院去了,加了学院看着就不像本校 2简历实习和项目稍微精简一下。字太多,面试官看着累 3第一个实习格式和第二个实习不一样。建议换行 4项目描述太详细了,你快把原理图贴上来了。比如可以这样描述:使用yyyy芯片,使用xx拓扑,使用pwm控制频率与占空比,进行了了mos/电感/变压器选型,实现了xx功能 建议把技术栈和你做的较为有亮点的工作归纳出来 5熟悉正反激这个是真的吗
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务