怎么新建MySQL数据库

发布时间:2022-09-16 作者:admin
阅读:282
在日常操作或是项目的实际应用中,有不少朋友对于“如何理解python神经网络tf.train.batch函数的使用”的问题会存在疑惑,下面小编给大家整理和分享了相关知识和资料,易于大家学习和理解,有需要的朋友可以借鉴参考,下面我们一起来了解一下吧。


当我在快乐的学习SSD训练部分的时候,我发现了一个batch我看不太懂,主要是因为tfrecords的数据读取方式我不理解,所以好好学一下batch吧

tf.train.batch函数

tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

其中:

1、tensors:利用slice_input_producer获得的数据组合。

2、batch_size:设置每次从队列中获取出队数据的数量。

3、num_threads:用来控制线程的数量,如果其值不唯一,由于线程执行的特性,数据获取可能变成乱序。

4、capacity:一个整数,用来设置队列中元素的最大数量

5、allow_samller_final_batch:当其为True时,如果队列中的样本数量小于batch_size,出队的数量会以最终遗留下来的样本进行出队;当其为False时,小于batch_size的样本不会做出队处理。

6、name:名字

测试代码

1、allow_samller_final_batch=True

import pandas as pd
import numpy as np
import tensorflow as tf
# 生成数据
def generate_data():
    num = 18
    label = np.arange(num)
    return label
# 获取数据
def get_batch_data():
    label = generate_data()
    input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)
    label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True)
    return label_batch
# 数据组
label = get_batch_data()
sess = tf.Session()
# 初始化变量
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
# 初始化batch训练的参数
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
    while not coord.should_stop():
        # 自动获取下一组数据
        l = sess.run(label)
        print(l)
except tf.errors.OutOfRangeError:
    print('Done training')
finally:
    coord.request_stop()
coord.join(threads)
sess.close()

运行结果为:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
[17]
Done training

2、allow_samller_final_batch=False

相比allow_samller_final_batch=True,输出结果少了[17]

import pandas as pd
import numpy as np
import tensorflow as tf
# 生成数据
def generate_data():
    num = 18
    label = np.arange(num)
    return label
# 获取数据
def get_batch_data():
    label = generate_data()
    input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)
    label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)
    return label_batch
# 数据组
label = get_batch_data()
sess = tf.Session()
# 初始化变量
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
# 初始化batch训练的参数
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
    while not coord.should_stop():
        # 自动获取下一组数据
        l = sess.run(label)
        print(l)
except tf.errors.OutOfRangeError:
    print('Done training')
finally:
    coord.request_stop()
coord.join(threads)
sess.close()

运行结果为:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
Done training


上述内容具有一定的借鉴价值,感兴趣的朋友可以参考,希望能对大家有帮助,想要了解更多"如何理解python神经网络tf.train.batch函数的使用"的内容,大家可以关注群英网络的其它相关文章。

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。

二维码-群英

长按识别二维码并关注微信

更方便到期提醒、手机管理

7*24 全天候服务

售前 400-678-4567

售后 0668-2555666

售后 400 678 4567

信息安全 0668-2555 118

域名空间 3004329145