tensorflow 1.X迁移至tensorflow2 的代码写法
本文主要介绍了tensorflow 1.X迁移至tensorflow2 的代码写法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
目录
目标:代码改写成tf2格式
tf1和tf2区别:
改写内容:
tf.placeholder
tf.Sess,sess.run
具体例子1:
结论:
目标:代码改写成tf2格式
把tensorflow 1.X中的代码,迁移到tensorflow2中。一些常见的改写经验。包括sess,tf.placeholder, tf.InteractiveSession(),tf.Session()
tensorflow2相比于tensorflow 1.x版本有较大的变化,且网上现在好多文章的代码都是基于tf1.x版本的,学会简单的转换,帮助我们看代码。
整体来说,tensorflow2更加简洁了。
本文将持续更新中。
当然用tf.compat.v1也能解决部分问题。但是不推荐,毕竟tf2才是未来。
tf1和tf2区别:
1、tf1基于图模式,tf2基于eager模式,tf2对程序员更友好,更像是函数,更方便调试。
2、tf2更向keras靠拢,对分布式训练的支持更好。
改写内容:
tf.placeholder
tensorflow 1.x版本中的placeholder,在tf2中已经被取消,在tf2中,可以用tf.keras.Inputs代替。
示例:
tf1中
1 | input_ids = tf.placeholder(dtype = tf.int32, shape = [ None ]) |
tf2中,改写为:
1 | input_ids = tf.keras. Input (dtype = tf.int32, shape = [ None ]) |
tf.Sess,sess.run
tensorflow 1.x由于是基于静态图机制(Graph Execution),需要先构造图,然后才真正运行,因此需要用显示调用Session后,才会真正触发计算。对调试代码非常不利。
tensorflow 2.x默认是基于动态图机制(Eager Execution),就像常规函数一样,调用时就触发计算。对调试代码非常方便。
所以,tf1中session部分代码,可以全部去掉。
示例:
tf1中
1 2 | sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) |
tf2中,改写为:直接不要
具体例子1:
tf1的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | import tensorflow as tf import numpy as np # 定义一个未知变量input_ids用于存储索引 input_ids = tf.placeholder(dtype = tf.int32, shape = [ None ]) # 定义一个已知变量embedding,是一个5*3的矩阵 embedding = a = np.asarray([[ 0.1 , 0.2 , 0.3 ], [ 1.1 , 1.2 , 1.3 ], [ 2.1 , 2.2 , 2.3 ], [ 3.1 , 3.2 , 3.3 ], [ 4.1 , 4.2 , 4.3 ]]) # 根据input_ids中的id,查找embedding中对应的元素 input_embedding = tf.nn.embedding_lookup(embedding, input_ids) sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) # print(embedding.eval()) print (sess.run(input_embedding, feed_dict = {input_ids: [ 1 , 2 , 3 , 0 , 3 , 2 , 1 ]})) |
改写tf2代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 | import tensorflow as tf import numpy as np # 定义一个未知变量input_ids用于存储索引 input_ids = tf.keras. Input (dtype = tf.int32, shape = [ None ]) # 定义一个已知变量embedding,是一个5*3的矩阵 embedding = a = np.asarray([[ 0.1 , 0.2 , 0.3 ], [ 1.1 , 1.2 , 1.3 ], [ 2.1 , 2.2 , 2.3 ], [ 3.1 , 3.2 , 3.3 ], [ 4.1 , 4.2 , 4.3 ]]) input_ids = np.array([ 1 , 2 , 3 , 0 , 3 , 2 , 1 ]) # 根据input_ids中的id,查找embedding中对应的元素 input_embedding = tf.nn.embedding_lookup(embedding, input_ids) print (input_embedding) |
可见,tf2代码简洁明了不少,多动手试试,就能体会。
结论:
深刻体会tf2带来的变革。
1、体会静态图和动态图的差别
2、体会对分布式训练的优化(未来写)
3、体会模型训练的便利性(直接用compile等,keras的便利性。)
到此这篇关于tensorflow 1.X迁移至tensorflow2 的代码写法的文章就介绍到这了
原文链接:https://blog.csdn.net/weixin_43290383/article/details/121896799
伪原创工具 SEO网站优化 https://www.237it.com/