阅读 389

Tensorflow--图像数据处理(用tensorflow对图像进行训练)

1.TFRecord输入数据格式

1.1 TFRecord格式介绍

TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下代码给出了定义:

message Example{     Features features = 1; } message Features{     map<String, Features> feature = 1; } message Feature{         oneof kind{         BytesList bytes_list = 1;         Floatlist float_list = 2;         Int64List int64_list = 3;     } } 复制代码

从以上代码可以看出,tf.train.Example的数据结构是比较简洁的。tf.train.Example中包含了一个从属性名称到取值的字典。其中属性名称为一个字符串,属性的取值可以为字符串、实数列表或者整数列表。

1.2 TFRecord样例程序

# 先定义一个FileName filename = "......" # 创建一个writer来写TFRecord文件 writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples):     # 将图像矩阵转化成一个字符串     image_raw = images[index].toString     # 将一个样例转化成Example Protocol Buffer。并将所有信息写入这个数据结构     example = tf.train.Example(features=tf.train.Features(feature={         'pixels':_int64_feature(pixels),         'label':_int64_feature(np.argmax(labels[index])),         'image_raw':_bytes_feature(image_raw)}))     # 将一个Example写入TFRecord文件     writer.write(example.SerializeToString) writer.close() 复制代码

2.图像数据处理

2.1 图像编码处理

一张RGB色彩模式的图像可以看成一个三维矩阵,矩阵中的每一个数表示了图像上不同位置,不同颜色的亮度。但在存储时,并不是记录矩阵上的数字,而是记录经过压缩编码之后的结果,这时就需要编码、解码函数。以下为tensorFlow示例代码:

# matplotlib.pyplot是一个python的画图工具 import matplotlib.pyplot as plt import tensorflow as tf # 读取图像的原始数据 image_raw_data = tf.gfile.FastGFile("文件路径名", r).read() with tf.Session() as sess:     # 对图像进行jepg格式解码从而得到图相对应的三维矩阵。TensorFlow还提供了     # tf.image_decode_png函数对png格式的图像进行解码。解码之后的结果为一个张量     # 在使用它的取值之前需要明确调用运行的过程     image_data = tf.image.deode_jepg(image_raw_data)     # 接下来就用pyplot可视化工具展示图像 # 将表示一张图像的三维矩阵重新按照jepg格式编码并存入文件中。打开这张图像 # 可以得到和原始图像一样的数据 encoded_image = tf.image.encode_jpeg(img_data) with tf.gfile.GFile("/path/to/output", "wb") as f:     f.write(encoded_image.eval()) 复制代码

2.2 图像大小调整

获取的图像大小是不一的,但神经网络输入的节点的个数是固定的,所以需要在输入之前统一大小。

图像大小调整有两种方式,第一中通过算法使得新的图像尽量保存原始图像上的所有信息。TensorFlow提供了四种不同的方法,并封装到了tf.image.resize_images函数。以下代码为示例程序:

# 首先将图片数据转化为实数类型。这一步将0~255的像素值转化为0.0-1.0范围内的实数。 # 大多数图像处理API支持整书和实数类型的输入。如果输入是整数类型,这些API会在内部将输入转化为实数处理 # 再输出转化为整数。如果有多个处理步骤,在整数和实数之间的反复转化将导致精度损失,因此推荐在图像处理前将其 # 转化为实数类型。下面的样例将略去这一步骤,假设img_data是经过类型转化的图像 img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32) # 通过tf.image.resize_images函数调整图像的大小。这个函数的第一个参数为原始图像 # 第二和第三个参数为调整后图像的大小,method参数给出了调整图像大小的算法 # 注意,如果输入数据是unit8格式,那么输出将是0~255之内的实数,不方便后续处理,故建议转化为实数类型 resized = tf.image.resize_images(img_data, [300, 300], method=0) 复制代码

下表给出了tf.image.resize_images函数的method参数取值对应的图像大小调整算法:

Method取值图像大小调整算法
0双线性插值法(Bilinear interpolation)
1最近邻居法(Nearest neighbor interpolation)
2双三次插值法(Bicubic interpolation)
3面积插值法(Area interpolation)

除了将整张图像信息完整保存,TensorFlow还提供了API对图像进行裁剪或者填充。以下为示例代码:

# 通过tf.image.resize_image_with_crop_or_pad函数调整图像的大小。这个函数的 # 第一个参数为原始图像,后面两个参数是调整后的目标图像大小。如果原始图像的尺寸大于 # 目标图像,这个函数会自动在视奏填充全0背景。因为原始图像大小为1797x2673 # 所以第一个命令会自动裁剪,第二个命令会自动填充 croped = tf.image.resize_image_with_crop_or_pad(image_data, 1000, 1000) padded = tf.image_resize_image_with_crop_or_pad(image_data), 3000, 3000) 复制代码

TensorFlow还支持通过比例调整图像大小,以下为示例代码:

# 通过tf.image.central_crop函数可以按比例裁剪图像。该函数的第一个参数为原始图像,第二个为调整比例。这个比例需要是一个(0,1]的实数 central_cropped = tf.image.central_crop(img_data, 0.5) 复制代码

TensorFlow还提供了tf.image.crop_to_bounding_box函数和 tf.image.pad_to_bounding_box函数来裁剪或填充给定区域的图像。不过这两个函数都要求给出来的尺寸满足一定的要求。否则会报错。比如在使用crop函数时,tf要求提供的图像尺寸要大于目标尺寸。

2.3 图像翻转

TensorFlow提供了一些函数来支持对图像的翻转。以下代码实现了上下翻转、左右翻转以及沿对角线翻转,还有随机翻转:

# 上下翻转 flipped = tf.image.flip_up_down(img_data) # 左右翻转 flipped = tf.image.flip_left_right(img_data) # 对角线翻转 transposed = tf.image.transpose_image(img_data) # 随机翻转训练图像 # 以%50概率上下翻转 flipped = tf.image_random_flip_up_down(img_data) # 以%50概率左右翻转 flipped = tf.image_random_flip_left_right(img_data) 复制代码

2.4 图像色彩调整

调整亮度、对比度、饱和度和色相。在训练神经网络时,可以随机调整训练图像的这些属性,从而使得训练得到的模型尽可能小地受到无关因素的影响。tf提供了相关API,以下代码展示如何修改亮度:

# 将图像的亮度-0.5 adjusted = tf.image.adjust_brightness(img_data, -0.5) # 色彩调整的API可能导致像素的实数值超出0.0-1.0的范围,因此在输出最终图像前需要 # 将其值截断在0.0-1.0范围区间内,否则不仅图像无法正常可视化,以此为输入的神经网络 # 的训练质量也可能受到影响 # 如果对图像进行多项处理操作,那么这一截断过程应当在所有处理完成后进行。举个例子, # 假如对图像以此提高亮度和减少对比度,那么第二个操作可能将第一个操作生成的部分 # 过亮的像素拉回到不超过1.0的范围内,因此在第一个操作后不应该立即截断 adjusted = tf.clip_by_value(adjusted, 0.0, 1.0) adjusted = tf.image.adjust_brightness(img_data, 0.5) # 在[-max_delta, max_delta]的范围随机调整图像的亮度 adjusted = tf.image.random_brightness(image, max_delta) 复制代码

以下代码显示了如何调整图像的对比度:

# 将图像的对比度减少到0.5倍 adjusted = tf.image.adjust_contrast(img_data, 0.5) # 将图像的对比度增加5倍 adjusted = tf.image.adjust_contrast(img_data, 5) # 在[lower, upper]的范围内随机调整图的对比度 adjusted = tf.image.random_contrast(image, lower, upper) 复制代码

以下代码展示了如何调整图像的色相:

# 下面四条命令分别将色相增加0.1, 0.3, 0.6, 0.9 adjusted = tf.image.adjust_hue(img_data, 0.1) adjusted = tf.image.adjust_hue(img_data, 0.3) adjusted = tf.image.adjust_hue(img_data, 0.6) adjusted = tf.image.adjust_hue(img_data, 0.9) # 在[-max_delta, max_delta]的范围内随即调整图像的色相 # max_delta的取值在[0, 0.5]之间 adjusted = tf.image.random_hue(image, maxdelta) 复制代码

以下代码展示了如何调整图像饱和度:

# 将图像的饱和度-5 adjusted = tf.image.adjust_saturation(img_data, -5) #在[lower, upper]内随机调整饱和度 adjusted = f.image.random_saturation(image, lower, upper) 复制代码

TensorFlow还提供了API完成图像标准化操作,就是亮度均值变成0,方差变成1:

# 将代表一张图像的三维矩阵中的数字均值变成0,方差变成1 adjusted = tf.image.per_image_standardization(img_data) 复制代码

2.5 处理标注框

在图像中需要突出的部分添加标注框:

# 将图像缩小一些,这样可视化能让标注框更加清楚 img_data = tf.image.resize_images(img_data, [180, 267], method=1) # tf.image.draw_bounding_boxes函数要求图像矩阵中的数字为实数,所以需要先将 # 图像矩阵转化为实数类型。tf.image.draw_bounding_boxes函数图像的输入是一个batch的数据 # 也就是多张图像组成的四维矩阵,所以需要将解码后的图像矩阵加一维。 batched = tf.expend_dims(tf.image.convert_image_dtype(img_data, tf.float32), 0) # 给出每一张图像的所有标注框,一个标注框有4个数字,分别代表[ymin, xmin, ymax, xmax] # 这里给出的数字都是图像的相对位置。比如在180x267图像中, # [0.35, 0.47, 0.5, 0.56]代表了从(63, 125)到(90, 150)的图像 boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]]) result = tf.imagedraw_bounding_boxes(batched, boxes) 复制代码

随机截取图像上有信息含量的部分也是一个提高模型健壮性的一种方式。这样可以使训练得到的模型不受被识别物体大小的影响。以下代码展示了如何通过tf.image.sample_distorted_bounding_box函数来完成随机截取图像的过程:

boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, o.56]]]) # 可以通过提供标注框的方式来告诉随机截取图像的算法哪些部分是“有信息量”的 # min_object_covered = 0.4表示截取部分至少包含某个标注框%40的内容。 begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(     tf.shape(img_data), bounding_boxes = boxes,     min_object_covered = 0.4) # 通过标注框可视化截取得到的图像 batched = tf.expend_dims(     tf.image.convert_image_dtype(img_data, tf.float32), 0) image_with_box = tf.image.draw_boudning_boxes(batched, bbox_for_draw) # 截取随机出来的图像,因为算法带有随机成分,所以每次结果都会有所不同 distorted_image = tf.slice(img_data, bigin, size) 复制代码

2.6 图像预处理总结

  • 给定一张图像,随即调整图像的色彩,调整亮度、对比度、饱和度,可自定义先后顺序;

  • 如果没有提供标注框,则认为整个图像就是需要关注的部分;

  • 转换图像张量的类型;

  • 随机截取图像,减小需要关注的物体大小对图像识别算法的影响;

  • 将随机截取的图像调整为神经网络输入层的大小,大小调整的算法是随机的;

  • 随机左右翻转图像;

  • 使用一种随机的顺序调整图像色彩。

3.多线程输入数据处理框架

4.png

3.1 队列和多线程

在TensorFlow中,队列和变量类似,都是计算图上有状态的节点。其他的计算节点可以修改它们的状态。对于变量,可以通过赋值操作修改变量的取值。对于队列,修改队列状态的操作主要有Enqueue、EnqueueMany和Dequeue。以下代码展示如何操作一个队列:

import tensorflow as tf # 创建一个先进先出的队列,指定队列最多保存两个元素,并指定类型为整型 q = tf.FIFOQueue(2, "int32") # 使用enqueue_many函数来初始化队列中的元素。类似变量的初始化,使用队列前必须得初始化 init = q.enqueue_many(([0, 10], )) # 使用dequeue将第一个元素出队列 x = q.dequeue() # 后面的就浅显易懂了 y = x + 1 q_inc = q.enqueue(y) with tf.Session() as sess:     init.run()     for _ in range(5):          v, _ = sess.run([x, q_inc])         print(v) 复制代码

TensorFlow中提供了FIFOQueue和RandomShuffleQueue两种队列,后者会将队列的元素打乱,每次出队列的操作得到的是从当前队列所有元素中随机选择的一个。在TensorFlow中,队列不仅仅是一种数据结构,还是异步计算张量取值的一个重要机制。比如多个线程可以同时向一个队列中写元素,或者同时读取队列中的元素。

TensorFlow提供了tf.Coordinator和tf.QueueRunner两个类来完成多线程协同的功能。t.Coordinator主要用于协同多个线程一起停止,并提供了should_stop、requets_stp和join三个函数。在启动线程前,需要先声明一个tf.Coordinator类,并将这个类传入每一个创建的线程中。启动的线程需要一直查询tf.Coordinator类中提供的should_stop函数,当这个函数的返回值为true时,则当前线程也需要退出。每一个启动的线程都可以通过调用request_stop函数来通知其他线程退出。当某一个线程调用request_stop函数之后,should_stop函数的返回值就被设置成了true,这样其他线程就可以终止了,以下为示例代码:

import tensorflow as tf import numpy as np import threading import time # 线程中运行的程序,这个程序每隔1秒判断是否需要停止并打印自己的ID def MyLoop(coord, worker_id):     # 使用tf.Coordinator类提供的协同工具判断当前线程是否需要停止     while not coord.should_stop():         # 随机停止所有的线程         if np.random.rand() < 0.1:             print("Stoping from id:%d\n" % worker_id)             # 调用crod.request_stop()函数来通知其他线程停止             coord.request_stop(0)         else:             # 打印当前线程的id             print("Working on id:%d\n" % worker_id)             # 暂停1秒             time.sleep(1) # 声明一个tf.train.Coordinator类来协同多个线程 coord = tf.rain.Coordinator() # 声明创建5个线程 threads = [     threading.Thread(target=MyLoop, args=(coord, i, )) for i in range(5)] # 启动所有线程 for t in threads:t.start() # 等待所有线程退出 coord.join(threads) 复制代码

tf.QueueRunner主要用于启动多个线程来操作同一个队列,启动的这些线程可以通过tf.Coordinator来统一管理。

以下代码为tf.QueueRunner和tf.Coordinator来管理多线程队列操作:

import tensorflow as tf # 声明一个先进先出的队列,队列中最多100个元素,类型为实数 queue = tf.FIFOQueue(100, "float") # 定义队列的入队操作 enqueue_op = queue.enqueue([tf.random_normal([1])]) # 使用tf.train.QueueRunner来创建多个线程娙队列的入队操作 # tf.train.QueueRunner的第一个参数给出了被操作的队列,[enqueue_op] * 5 # 表示了需要启动5个线程,每个线程中运行的是enqueue_op操作 qr = tf.train.QueueRunner(queue, [enqueue_op] * 5) # 将定义过的QueueRunner加入TensorFlow计算图上指定的集合 # tf.train.add_queue_runner函数没有指定集合 # qr加入默认的tf.GraphKeys.QUEUE_RUNNERS集合 tf.train.add_queue_runner(qr) # 定义出队操作 out_tensor = queue.dequeue() with tf.Session() as sess:     # 使用tf.train.Coordinator来协同启动线程     coord = tf.train.Coordinator()     '''     使用tf.train.QueueRunner时,需要明确调用tf.train.start_queue_runners来启动所有的线程     否则因为没有线程运行入队操作,当调用出队操作时,程序会一直等待入队操作被运行。     tf.train.start_queue_runners函数会默认启动tf.GraphKeys.QUEUE_RUNNERS集合中所有     的QueueRunner。因为这个函数只支持启动指定集合中的QueueRunner,所以一般来说     tf.train.add_queue_runner函数和tf.train.start_queue_runners函数会指定同一个集合     '''     threads = tf.train.start_queue_runners(sess=sess, coord=coord)     # 获取队列中的取值     for _ in range(3): print(sess.run(out_tensor)[0])          # 使用tf.train.Coordinator来停止所有的线程     coord.request_stop()     coord.join(threads) 复制代码

3.2 输入文件队列

  • TensorFlow提供了tf.train.match_filenames_once函数来获取符合一个正则表达式的所有文件,得到的文件列表可以通过tf.train.string_input_producer函数进行有效的管理;

  • tf.train.string_input_producer函数会使用初始化时提供的文件列表创建一个输入队列,输入队列中原始的元素为文件列表中的所有文件;

  • 通过设置shuffle参数,tf.train.strin_input_producer函数支持随机打乱文件列表中文件出队的顺序。当shuffle参数为True时,文件在加入队列之前会被打乱顺序,所以出队的顺序也是随机的,随机打乱文件顺序以及加入输入队列的过程会跑在一个单独的线程上,这样不会影响获取文件的速度;

  • 当一个输入队列中的所有文件都被处理完后,它会将初始化时提供的文件列表中的文件全部重新加入队列。

tf.train.string_input_producer函数可以设置num_epochs参数来限制加载初始文件列表的最大轮数,以下代码展示了如何生成样例:

import tensorflow as tf # 创建TFRecord文件的帮助函数 def _int64_feature(value):     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 模拟海量数据情况下将数据写入不同的文件。num_shards定义了总共写入多少个文件。 # instances_per_shard定义了每个文件中有多少数据 num_shards = 2 instance_per_shard = 2 for i in range(num_shards):     '''     将数据分为多少个文件时,可以将不同文件以类似0000n-of-0000m的后缀区分。其中m表示     了数据总共被存在了多少个文件中,n表示当前文件的编号。式样的方式既方便了通过正     则表达式获取文件列表,又在文件名中加入了更多的信息     '''     filename = ('/path/to/data.tfrecords-%.5d-of-%.5d' % (i, num_shards))     writer = tf.python_io.TFRecordWriter(filename)     # 将数据封装成Example结构并写入TFRecord文件     for j in range(instance_per_shard):         example = tf.train.Example(features=tf.train.Features(features={             'i': _int64_feature(i),             'j': _int64_feature(j)         }))         writer.write(example.SerializeToString()) writer.close() 复制代码

上述程序运行后,会在目标文件下生成两个文件,每个文件中都存储了样例,在生成了样例数据之后,以下代码展示了tf.train.match_filenames_once函数和tf.train.string_input_producer函数的使用方法:

import tensorflow as tf # 使用tf.train.match_filenames_once函数获取文件列表 files = tf.train.match_filenames_once("/path/to/data.tfrecords-*") # 通过tf.train.string_input_producer函数创建输入队列,输入队列中的文件列表为 # tf.train.match_filenames_once函数获取的文件列表这里将shuffle参数设置为False # 来避免随机打乱文件顺序。但一般在解决真实问题时,会将shuffle参数设置为True filename_queue = tf.train.string_input_producer(files, shuffle=False) # 解析一个样本 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(     serialized_example,     features={         'i': tf.FixedLenFeature([], tf.int64),         'j': tf.FixedLenFeature([], tf.int64)     } ) with tf.Session() as sess:     # 初始化一些变量     tf.local_variables_initializer().run()     print(sess.run(files))          # 声明tf.train.Coordinator类来协同不同线程,并启动线程     coord = tf.train.Coordinator()     threads = tf.train.start_queue_runners(sess=sess, coord=coord)          # 多次执行获取数据的操作     for i in range(6):         print(sess.run([features['i'], features['j']]))     coord.request_stop()     coord.join(threads) 复制代码

3.3 组合训练数据

将单个数据组织成batch。TensorFlow提供了tf.train.batch和f.train.shuffle_batch函数来将单个的样例组织成bantch的形式输出这两个函数都会生成一个队列,队列的入队操作是生成单个样例的方法,而每次出队得到的是一个batch的样例。它们唯一的区别在于是否会将顺序打乱。以下为代码展示:

# 这里example结构中i表示一个样例的特征向量 # 比如一张图像的矩阵,而j表示该样例对应的标签 example, label = features['i'], features['j'] # 一个batch中样例的个数 batch_size = 3 # 组合样例的队列中最多可以存储的样例个数。这个队列如果太大,那么需要占用很多内存资源 # 如果太小,那么出队操作可能会因为没有数据而被阻碍,从而导致训练效率降低,一般来说, # 这个队列的大小会和每一个batch的大小相关,下面给出了设置队列大小的一种方式 capacity = 1000 + 3 * batch_size # 使用tf.train.batch来组合样例。[example, label]给出了需要组合的元素 # 一般example和label分别代表训练样本和这个样本对应的正确标签。batch_size参数给出了 # 每个batch中样例的个数。capacity给出了队列的最大容量。当队列长度等于容量时, # TensorFlow将自动重新启动入队操作 example_batch, label_batch = tf.train.batch(     [example, label], batch_size=batch_size, capacity=capacity ) with tf.Session() as sess:     tf.initialize_all_variables().run()     coord = tf.train.Coordinator()     threads = tf.train.start_queue_runners(sess=sess, coord=coord)          # 获取并打印组合之后的样例。在真实问题中,这个输出一般会作为神经网络的输入     for i in range(2):         cur_example_batch, cur_label_batch = sess.run(             [example_batch, label_batch]         )         print(cur_example_batch, cur_label_batch)     coord.request_stop()     coord.join(threads) 复制代码

下面代码展示了tf.train.shuffle_batch函数的使用方法:

example, label = features['i'], features['j'] ''' 使用tf.train.shuffle_batch函数来组合样例。tf.train.shuffle_batch函数 的参数大部分都和tf.train.batch函数相似,但是min_after_dequeue参数限制了出队时队列中元素的 最少个数。当队列中元素个数太少时,随机打乱样例顺序的作用就不大了。所以 tf.train.shuffle_batch函数提供了限制出队时最少元素的个数来保证随机打乱顺序的作用,当出队函数被调用 但是队列元素个数不够时,出队操作将等待更多的元素入队才会完成, 如果min_after_dequeue参数被设定,capacity也应该相应调整来满足性能需求 ''' example_batch, label_batch = tf.train.shuffle_batch(     [example, label], batch_size=batch_size,     capacity=capacity, min_after_dequeue=30 ) 复制代码

  • 通过设置tf.train.shuffle_batch函数中的num_threads参数,可以指定多个线程同时执行入队操作。当num_threads参数大于1时,多个线程会同时读取一个文件中的不同样例并进行预处理;

  • 对于tf.train.shuffle_batch函数,不同的线程会读取同一个文件。如果一个文件中的样例比较相似,那么神经网络的训练效果可能会受到影响,所以在使用tf.train.shuffle_batch函数时,需要尽量将同一个FRecord文件中的样例随机打乱。而使用tf.train.shuffle_batch_join函数时,不同线程会读取不同文件,如果读取数据的线程数比总文件数还大,那么多个线程可能会读取同一个文件中相近部分的数据。而且多个线程读取多个文件可能会导致过多的硬盘寻址,从而使得读取效率降低。

3.4 输入数据处理框架

import tensorflow as tf import temp # 创建文件列表,并通过文件列表创建输入文件队列。在调用输入数据处理流程前,需要统一 # 所有原始数据的格式并将它们存储到TFRecord文件中。下面给出的文件列表应该包含所有 # 提供训练数据的TFRecord files = tf.train.match_filenames_once("/path/to/file_pattern-*") filename_queue = tf.train.string_input_producer(files, shuffle=False) # 解析数据,这里假设Image中存的是图像的原始数据,label为该样例对应的标签 # height、width和channels给出了图片的维度 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(     serialized_example,     features={         'image': tf.FixedLenFeature([], tf.string),         'label': tf.FixedLenFeature([], tf.int64),         'height': tf.FixedLenFeature([], tf.int64),         'width': tf.FixedLenFeature([], tf.int64),         'channels': tf.FixedLenFeature([], tf.int64),     } ) image, label = features['image'], features['label'] height, width = features['height'], features['width'] channels = features['channels'] # 从原始图像数据解析出像素矩阵,并根据图像尺寸还原图像 decoded_image = tf.decode_raw(image, tf.unit8) decoded_image.set_shape([height, width, channels]) image_size = 299 distorted_image = temp.preprocess_for_train(decoded_image, image_size, image_size, None) # 将处理后的图像和标签数据通过tf.train.shuffle_batch整理成神经网络训练时需要的batch min_after_dequeue = 10000 batch_size = 100 capacity = min_after_dequeue + 3 * batch_size image_batch, label_batch = tf.train.shuffle_batch(     [distorted_image, label], batch_size=batch_size, capacity=capacity,     min_after_dequeue=min_after_dequeue ) # 定义神经网络的结构及优化过程。image_batch可以作为输入提供给神经网络的输入层。 # label_batch则提供了输入batch中样例的正确答案 learning_rate = 0.01 logit = inference(image_batch) loss = calc_loss(logit, label_batch) train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) # 声明会话并运行神经网络的优化过程 with tf.Session() as sess:     # 神经网络训练准备工作,包括变量初始化,线程启动     sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))     coord = tf.train.Coordinator()     threads = tf.train.start_queue_runners(sess=sess, coord=coord)          # 神经网络训练过程     TRAINING_ROUNDS = 5000     for i in range(TRAINING_ROUNDS):         sess.run(train_step)          # 停止所有线程     coord.request_stop()     coord.join(threads) 复制代码

输入数据处理流程:

  • 第一步,获取存储训练数据的文件列表;

  • 第二步,通过tf.train.string_input_producer函数,选择性地将文件列表中文件的顺序打乱;

  • 第三步,获取图像数据并进行预处理,会通过tf.train.shuffle_batch提供的机制并行地跑在多个线程中;

  • 第四步,tf.train.shuffle_batch函数将处理好的单个输入样例整理成batch提供给神经网络的输入层。

4.数据集(DataSet)

4.1 数据集的基本用法

在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个TFRecord格式文件,或者经过sharding的一系列文件,等等。由于训练数据通常无法全部写入内存中,从数据集中读取数据时需要使用一个迭代器(iterator)按顺序进行读取。与队列相似,数据集也是计算图上的一个节点。

利用数据集读取数据有以下步骤:

  1. 定义数据集的构造方法;

  2. 定义遍历器;

  3. 使用get_next()方法从遍历器中读取数据张量,作为计算图其他部分的输入。

在真实项目中,训练数据通常是保存在硬盘文件上的。这时可以用TextLineDataset来更方便的读取数据:

import tensorflow as tf # 从文本文件创建数据集,假定每行文字是一个训练例子。这里可以提供多个文件 input_files = ["/path/to/input_file1", "/path/to/input_file2"] dataset = tf.data.TextLineDataset(input_files) # 定义迭代器用于遍历数据集 iterator = dataset.make_one_shot_iterator() # 这里用get_next()返回一个字符串类型的张量,代表文件的一行 x = iterator.get_next() with tf.Session() as sess:     for i in range(3):         print(sess.run(x)) 复制代码

在图像相关任务中,输入数据通常以TFRecoed数据格式存储,这时可以用TFRecordDataset来读取数据。与文本文件不同,每一个TFRecord都有自己不同的feature格式,因此在读取TFRecord时,需要提供一个parse函数来解析所读取的TFRecoed的数据格式:

import tensorflow as tf # 解析一个TFRecord的方法,record时从文件列表中读取的一个样例 def parser(record):     features = tf.parse_single_example(         record,         features={             'feat1': tf.FixedLenFeature([], tf.int64),             'feat2': tf.FixedLenFeature([], tf.int64),         }     )     return features['feat1'], features['feat2'] # 从TFRecord文件创建数据集 input_files = ["/path/to/input_file1", "/path/to/input_file2"] dataset = tf.data.TFRecordDataset(input_files) # map函数表示对数据集中的每一条数据进行调用相应方法。使用TFRecordDataset读出来的是二进制 # 数据,这里需要通过map()来调用parser()对二进制数据进行解析。类似的,map()函数也可以用来 # 完成其他数据与处理工作 dataset = dataset.map(parser) # 定义迭代器 iterator = dataset.make_one_shot_iterator() # feat1, feat2是parser()返回的一维int64型张量,可以作为输入用于进一步计算 feat1, feat2 = iterator.get_next() with tf.Session as sess:     for i in range(10):         f1, f2 = sess.run(feat1, feat2) 复制代码

上面这个例子使用了最简单的one_shot_iterator来遍历数据集。在使用one_shot_iterator时,数据集上所有的参数必须确定,因此one_shot_iterator不需要特别的初始化过程。如果需要用到placeholder来初始化数据集,就需要用到initializable_iterator。以下为示例:

import tensorflow as tf # 解析一个TFRecord的方法,record时从文件列表中读取的一个样例 def parser(record):     features = tf.parse_single_example(         record,         features={             'feat1': tf.FixedLenFeature([], tf.int64),             'feat2': tf.FixedLenFeature([], tf.int64),         }     )     return features['feat1'], features['feat2'] # 从TFRecord文件创建数据集,具体文件路径是一个placeholder,稍后再提供具体路径 input_files = tf.placeholder(tf.string) dataset = tf.data.TFRecordDataset(input_files) dataset = dataset.map(parser) # 定义迭代器 iterator = dataset.make_initializable_iterator() # feat1, feat2是parser()返回的一维int64型张量,可以作为输入用于进一步计算 feat1, feat2 = iterator.get_next() with tf.Session as sess:     sess.run(iterator.initializer, feed_dict={input_files:                                                   ["/path/to/iput_file1", ["/path/to/input_file2"]]})     while True:         try:             sess.run([feat1, feat2])         except tf.errors.OutOfRangeError:             break 复制代码

除了以上两种迭代器,谷歌还提供了reinitializable_iterator和feedable_iterator两种更加灵活的迭代器,前者可以多次initialize用于遍历不同的数据来源,而后者可以用feed_dict的方式动态指定运行哪个Iterator。

4.2 数据集的高层操作

dataset = dataset.map(parser) 复制代码

map是在数据集上进行操作的最常用的方法之一,在这里,map(parser)方法表示对数据集中的每一条数据调用参数中指定的parser方法。对每一条数据进行处理后,map将处理后的数据包装成一个新的数据集返回。map函数非常灵活。如下,在数据集框架中,可以用map来对每一条数据调用preprocess_for_train方法:

dataset = dataset.map(lambda x :preprocess_for_train(x, image_size, image_size, None)) 复制代码

在上面的代码中,lambda表达式的作用是将原来有4个参数的函数转化为只有1个参数的函数。

dataset = dataset.shuffle(buffer_size) # 随机打乱顺序 dataset = dataset.batch(batch_size) # 将数据组合成batch 复制代码

其中shuffle方法的参数buffer_size等效于tf.train.shuffle_batch的min_after_dequence参数。shuffle算法在内部使用一个缓冲区保存buffer_size条数据,每读入一条新数据时,从这个缓冲区中随机选择一条数据进行输出。

repeat是另一个常用的操作方法,这个方法将数据集中的数据复制多份,其中每一份数据被称为一个epoch。

dataset = dataset.repeat(N) # 将数据集重复N份 复制代码

除了这些方法以外,数据集还提供了其他多种操作,例如,concatentate()将两个数据集顺序连接起来,take(N)从数据集中读取前N项数据,skip(N)在数据集中跳过前N项数据,flap_map从多个数据集中轮流读取数据。


作者:RIP_Mamba
链接:https://juejin.cn/post/7035407965658021902

 伪原创工具 SEO网站优化  https://www.237it.com/ 


文章分类
代码人生
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐