TF_Record data creation and reading (based on Cats vs. Dogs and Cifar10 data)
1、Creating a TF_Record dataset
Opening images with Image.open() consumes a lot of memory. Here I use tf.gfile.FastGFile instead, so I recommend opening images with tf.gfile.FastGFile(path, 'rb').
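The create_tf_example function below relies on two helpers, _int64_feature and _bytes_feature, that are not shown in the post. A minimal sketch of the standard tf.train.Feature wrappers, together with the imports the rest of the code assumes, could look like this:

import os
import glob

import numpy as np
import tensorflow as tf


# Assumed helpers: standard TFRecord feature wrappers referenced below
# as _int64_feature / _bytes_feature but not defined in the original post.
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))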
def create_tf_example(img_list, label, sess):
    # image = Image.open(img_list)      # Reading images with PIL / skimage / cv uses a lot of memory
    # image = image.resize((300, 300))
    # image = image.tobytes()
    with tf.gfile.FastGFile(img_list, 'rb') as fid:
        img = fid.read()
    ## Preprocessing here is possible but slow; it is better to resize the images in advance
    # img = tf.image.decode_png(img, channels=3)                  # Could also be decoded as 1 channel here
    # img = tf.image.resize_image_with_crop_or_pad(img, 40, 40)   # Very slow, but works well: pads with black borders or center-crops
    # img = tf.image.resize_images(img, [40, 40])                 # Faster, but still slow
    # image = sess.run(img)
    # image_bytes = image.tobytes()   # Convert the tensor to bytes; note this format must be decoded differently
    example = tf.train.Example(features=tf.train.Features(
        feature={
            'label': _int64_feature(label),
            'img_raw': _bytes_feature(img),
            # 'width': _int64_feature(width),
            # 'height': _int64_feature(height)
        }))
    return example   # Return an example that can be written to the record
2、Reading the Cat_vs_Dogs data and generating a record:
filepath = '/Users/***/Git_Mac/Cat_Vs_Dog/train/'   # cat_dog root directory
out_dir = 'Record/cat_dogs_2.record'
def Creat_Cats_Vs_Dogs(file_dir, out_dir):
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for file in os.listdir(file_dir):   # Collect the paths of all images
        name = file.split(sep='.')
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print('There are %d cats\nThere are %d dogs' % (len(cats), len(dogs)))

    image_list = np.hstack((cats, dogs))               # Combine the data
    label_list = np.hstack((label_cats, label_dogs))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()                            # Transpose
    np.random.shuffle(temp)                            # Shuffle the data
    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    sess = tf.Session()
    write = tf.python_io.TFRecordWriter(out_dir)
    count = 0
    for img, lbe in zip(image_list, label_list):
        example = create_tf_example(img, lbe, sess)    # Create one example per image and write it
        write.write(example.SerializeToString())
        count += 1
        if count % 1000 == 0:
            print(count)
    sess.close()
    write.close()
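With the paths defined above, generating the record file is a single call (a sketch; replace the '***' placeholder in filepath with your own user directory first):

# Writes one serialized example per training image into out_dir.
Creat_Cats_Vs_Dogs(filepath, out_dir)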
3、Reading the Cifar10 data and generating a record:
path = 'Git_Mac/cifar10'   # data root directory
def Create_Cifar10_Record(path, out_dir):
    write = tf.python_io.TFRecordWriter(out_dir)
    label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]   # Labels
    sess = tf.Session()
    for index, directory in zip(label_list, ['airplane', 'automobile', 'bird', 'cat', 'deer',
                                             'dog', 'frog', 'horse', 'ship', 'truck']):
        img_list = glob.glob(os.path.join(path, '{}/*.png'.format(directory)))   # Collect all image paths
        count = 0
        for img in img_list:
            example = create_tf_example(img, index, sess=sess)
            write.write(example.SerializeToString())
            count += 1
            if count % 100 == 0:   # Track the label and progress
                print(count)
    sess.close()
    write.close()
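A possible call, assuming each Cifar10 class has already been exported as PNG files under path/<class_name>/ to match the glob pattern above (the output file name here is only an example, not from the original post):

# Hypothetical output name; any writable .record path works.
Create_Cifar10_Record(path, out_dir='Record/cifar10.record')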
4、Reading data from the TF_Record format
Read data from the record file and decode it:
def read_and_decode(tfrecords_file, batch_size, shuffle, n_class, one_hot=False):
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            # 'width': tf.FixedLenFeature([], tf.int64),
            # 'height': tf.FixedLenFeature([], tf.int64)
        })
    # image = tf.decode_raw(img_features['img_raw'], tf.uint8)
    # width = tf.cast(img_features['width'], tf.int32)
    # height = tf.cast(img_features['height'], tf.int32)
    img = img_features['img_raw']
    img = tf.image.decode_png(img, channels=3)   # Decode a PNG image; for JPEG use decode_jpeg()
    image = tf.reshape(img, [32, 32, 3])         # 32*32*3 -- change this to match your own image size
    label = tf.cast(img_features['label'], tf.int32)
    # image = tf.reshape(image, [300, 300, 3])
    image = tf.image.per_image_standardization(image)   # Standardize the image

    if shuffle:
        # Whether to shuffle the data order. If capacity is too small the data will not mix well,
        # and shuffled reading uses a lot of memory.
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=64,
            capacity=20000,
            min_after_dequeue=1000)
    else:
        image_batch, label_batch = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=64,
            capacity=2000)

    image_batch = tf.cast(image_batch, tf.float32)   # Convert to tf.float32
    if one_hot:
        # Generate one-hot labels; one-hot labels require a different loss setup
        label_batch = tf.one_hot(label_batch, depth=n_class)
        label_batch = tf.cast(label_batch, dtype=tf.int32)
        label_batch = tf.reshape(label_batch, [batch_size, n_class])
    return image_batch, label_batch
Reading the data with queue-runner threads:
def Read_Record(filepath):
    with tf.Session() as sess:   # Start a session
        image, label = read_and_decode(filepath, batch_size=batch_size,
                                       shuffle=True, n_class=2, one_hot=False)
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()   # Important: coordinates the queue-runner threads
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break
                img, lbe = sess.run([image, label])
                # Add your own model training here
                # plot_images(img, lbe, batch_size=batch_size)
        except tf.errors.OutOfRangeError as e:
            print(e)
        finally:
            coord.request_stop()
        coord.join(threads)
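Read_Record above refers to batch_size, MAX_STEP and the commented-out plot_images, none of which are defined in the post. A minimal sketch of these missing pieces (assumed values and a simple matplotlib viewer, not the original helper) might be:

import matplotlib.pyplot as plt

batch_size = 16   # assumed value
MAX_STEP = 100    # assumed value


def plot_images(images, labels, batch_size):
    # Hypothetical helper: show up to 9 images of a batch with their labels.
    # The images were standardized in read_and_decode, so rescale them for display.
    for i in range(min(batch_size, 9)):
        img = images[i]
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        plt.subplot(3, 3, i + 1)
        plt.imshow(img)
        plt.title(str(labels[i]))
        plt.axis('off')
    plt.show()


# Path from section 2; remember to adjust the reshape size in read_and_decode
# to match the cat/dog images before running this.
Read_Record('Record/cat_dogs_2.record')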