TensorFlow中 tf.space_to_depth()函数的用法

目录

一、函数定义

二、解释范例

三、代码验证


一、函数定义

通俗易懂些,就是把输入为[batch, height, width, channels]形式的Tensor,其在height和width维的值将移至depth维


  
  1. space_to_depth(
  2. input,
  3. block_size,
  4. data_format="NHWC",
  5. name=None)

input:就是输入的Tensor

block_size:就是块大小,一般≥2

data_format:数据形式,如下表,默认“NHWC

NHWC [ batch, height, width, channels ]
NCHW [ batch, channels, height, width ]
NCHW_VECT_C qint8 [ batch, channels / 4, height, width, 4 ]

name:名字


*输出张量的深度是block_size * block_size * input_depth


二、解释范例

一共三种Tensor:

(1)输入大小为:[1, 2, 2, 1]


  
  1. test_tensor_1 = [[[[1], [2]],
  2. [[3], [4]]]]

输出大小:[1, 1, 1, 4]

[[[[1 2 3 4]]]]
 

(2)输入大小为:[1, 2, 2, 3]


  
  1. test_tensor_2 = [[[[1, 2, 3], [4, 5, 6]],
  2. [[7, 8, 9], [10, 11, 12]]]]

输出大小:[1, 1, 1, 12]

[[[[ 1  2  3  4  5  6  7  8  9 10 11 12]]]]
 

(3)输入大小为:[1, 4, 4, 1]


  
  1. test_tensor_3 = [[[[1], [2], [5], [6]],
  2. [[3], [4], [7], [8]],
  3. [[9], [10], [13], [14]],
  4. [[11], [12], [15], [16]]]]

输出大小:[1, 2, 2, 4]


  
  1. [[[[ 1 2 3 4]
  2. [ 5 6 7 8]]
  3. [[ 9 10 11 12]
  4. [13 14 15 16]]]]

三、代码验证

 test_space_to_depth.py


  
  1. import tensorflow as tf
  2. import os
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  4. test_tensor_1 = tf.placeholder('float32', [1, 2, 2, 1], name='test_tensor_1')
  5. test_tensor_1 = [[[[1], [2]],
  6. [[3], [4]]]]
  7. test_tensor_1 = tf.space_to_depth(test_tensor_1, 2)
  8. test_tensor_2 = tf.placeholder('float32', [1, 2, 2, 3], name='test_tensor_1')
  9. test_tensor_2 = [[[[1, 2, 3], [4, 5, 6]],
  10. [[7, 8, 9], [10, 11, 12]]]]
  11. test_tensor_2 = tf.space_to_depth(test_tensor_2, 2)
  12. test_tensor_3 = tf.placeholder('float32', [1, 4, 4, 1], name='test_tensor_1')
  13. test_tensor_3 = [[[[1], [2], [5], [6]],
  14. [[3], [4], [7], [8]],
  15. [[9], [10], [13], [14]],
  16. [[11], [12], [15], [16]]]]
  17. test_tensor_3 = tf.space_to_depth(test_tensor_3, 2)
  18. sess = tf.Session()
  19. print(test_tensor_1)
  20. print(sess.run(test_tensor_1))
  21. print(test_tensor_2)
  22. print(sess.run(test_tensor_2))
  23. print(test_tensor_3)
  24. print(sess.run(test_tensor_3))

结果:

文章来源: nickhuang1996.blog.csdn.net,作者:悲恋花丶无心之人,版权归原作者所有,如需转载,请联系作者。

原文链接:nickhuang1996.blog.csdn.net/article/details/89471553

(完)