阅读 314

tensorflow 1.X迁移至tensorflow2 的代码写法

本文主要介绍了tensorflow 1.X迁移至tensorflow2 的代码写法,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

目录
  • 目标:代码改写成tf2格式

  • tf1和tf2区别:

  • 改写内容:

    • tf.placeholder

    • tf.Sess,sess.run

  • 具体例子1:

    • 结论:

      目标:代码改写成tf2格式

      把tensorflow 1.X中的代码,迁移到tensorflow2中。一些常见的改写经验。包括sess,tf.placeholder, tf.InteractiveSession(),tf.Session()
      tensorflow2相比于tensorflow 1.x版本有较大的变化,且网上现在好多文章的代码都是基于tf1.x版本的,学会简单的转换,帮助我们看代码。
      整体来说,tensorflow2更加简洁了。
      本文将持续更新中。

      当然用tf.compat.v1也能解决部分问题。但是不推荐,毕竟tf2才是未来。

      tf1和tf2区别:

      1、tf1基于图模式,tf2基于eager模式,tf2对程序员更友好,更像是函数,更方便调试。

      2、tf2更向keras靠拢,对分布式训练的支持更好。

      改写内容:

      tf.placeholder

      tensorflow 1.x版本中的placeholder,在tf2中已经被取消,在tf2中,可以用tf.keras.Inputs代替。
      示例:
      tf1中

      1
      input_ids = tf.placeholder(dtype=tf.int32, shape=[None])

      tf2中,改写为:

      1
      input_ids = tf.keras.Input(dtype=tf.int32, shape=[None])

      tf.Sess,sess.run


      • tensorflow 1.x由于是基于静态图机制(Graph Execution),需要先构造图,然后才真正运行,因此需要用显示调用Session后,才会真正触发计算。对调试代码非常不利。

      • tensorflow 2.x默认是基于动态图机制(Eager Execution),就像常规函数一样,调用时就触发计算。对调试代码非常方便。

      所以,tf1中session部分代码,可以全部去掉。

      示例:

      tf1中

      1
      2
      sess = tf.InteractiveSession()
      sess.run(tf.global_variables_initializer())

      tf2中,改写为:直接不要

      具体例子1:

      tf1的代码:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      import tensorflow as tf
      import numpy as np
       
      # 定义一个未知变量input_ids用于存储索引
      input_ids = tf.placeholder(dtype=tf.int32, shape=[None])
       
      # 定义一个已知变量embedding,是一个5*3的矩阵
      embedding = a = np.asarray([[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]])
       
      # 根据input_ids中的id,查找embedding中对应的元素
      input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
       
      sess = tf.InteractiveSession()
      sess.run(tf.global_variables_initializer())
      # print(embedding.eval())
      print(sess.run(input_embedding, feed_dict={input_ids: [1, 2, 3, 0, 3, 2, 1]}))

      改写tf2代码:

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      import tensorflow as tf
      import numpy as np
       
      # 定义一个未知变量input_ids用于存储索引
      input_ids = tf.keras.Input(dtype=tf.int32, shape=[None])
       
      # 定义一个已知变量embedding,是一个5*3的矩阵
      embedding = a = np.asarray([[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]])
       
      input_ids = np.array([1, 2, 3, 0, 3, 2, 1])
      # 根据input_ids中的id,查找embedding中对应的元素
      input_embedding = tf.nn.embedding_lookup(embedding, input_ids)
      print(input_embedding)

      可见,tf2代码简洁明了不少,多动手试试,就能体会。

      结论:

      深刻体会tf2带来的变革。
      1、体会静态图和动态图的差别
      2、体会对分布式训练的优化(未来写)
      3、体会模型训练的便利性(直接用compile等,keras的便利性。)

      到此这篇关于tensorflow 1.X迁移至tensorflow2 的代码写法的文章就介绍到这了

      原文链接:https://blog.csdn.net/weixin_43290383/article/details/121896799

      伪原创工具 SEO网站优化  https://www.237it.com/ 


      文章分类
      代码人生
      版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
      相关推荐