tfrecord是tensorflow中常用的数据打包格式,这篇文章给大家介绍的就是关于tfrecord文件的生成和读取,本文有具体以及步骤,具有的一定的参考价值,需要的朋友可以参考学习。
训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型。为进一步理解tfrecord文件,本例先将6幅图像及其标签转换为tfrecord文件,然后读取tfrecord文件,重现6幅图像及其标签。
1、生成tfrecord文件
import os import numpy as np import tensorflow as tf from PIL import Image filenames = [ 'images/cat/1.jpg', 'images/cat/2.jpg', 'images/dog/1.jpg', 'images/dog/2.jpg', 'images/pig/1.jpg', 'images/pig/2.jpg',] labels = {'cat':0, 'dog':1, 'pig':2} def int64_feature(values): if not isinstance(values, (tuple, list)): values = [values] return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) with tf.Session() as sess: output_filename = os.path.join('images/train.tfrecords') with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: for filename in filenames: #读取图像 image_data = Image.open(filename) #图像灰度化 image_data = np.array(image_data.convert('L')) #将图像转化为bytes image_data = image_data.tobytes() #读取label label = labels[filename.split('/')[-2]] #生成protocol数据类型 example = tf.train.Example(features=tf.train.Features(feature={'image': bytes_feature(image_data), 'label': int64_feature(label)})) tfrecord_writer.write(example.SerializeToString())
2、读取tfrecord文件
import tensorflow as tf import matplotlib.pyplot as plt from PIL import Image # 根据文件名生成一个队列 filename_queue = tf.train.string_input_producer(['images/train.tfrecords']) reader = tf.TFRecordReader() # 返回文件名和文件 _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) # 获取图像数据 image = tf.decode_raw(features['image'], tf.uint8) # 恢复图像原始尺寸[高,宽] image = tf.reshape(image, [60, 160]) # 获取label label = tf.cast(features['label'], tf.int32) with tf.Session() as sess: # 创建一个协调器,管理线程 coord = tf.train.Coordinator() # 启动QueueRunner, 此时文件名队列已经进队 threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(6): image_b, label_b = sess.run([image, label]) img = Image.fromarray(image_b, 'L') plt.imshow(img) plt.axis('off') plt.show() print(label_b) # 通知其他线程关闭 coord.request_stop() # 其他所有线程关闭之后,这一函数才能返回 coord.join(threads)
以上就是关于怎样实现tfrecord文件生成与读取的操作介绍,希望文本对大家学习有帮助,想要了解更多tfrecord文件生成与读取的内容大家可以关注其他相关文章。
文本转载自脚本之家免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
长按识别二维码并关注微信
更方便到期提醒、手机管理