ホーム > カテゴリ > Python・人工知能・Django >

CNNで畳み込み/プーリング後のテンソルのサイズ(Shape)を確認する

Python/TensorFlowの使い方(目次)

CNN(畳み込みニューラルネットワーク)で畳み込み、プーリング後のテンソルのサイズは非常にわかりにくいです。

そこで、Tensorオブジェクトのget_shape()メソッドを使用すると、テンソルのサイズを簡単に確認する事が可能です。

※全コードは後述する参考文献を参照してください。

[5x5]ストライド1でパディングあり

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='SAME')

print(conv1.get_shape())

(?, 28, 28, 32)

[5x5]ストライド1でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID')

print(conv1.get_shape())

(?, 24, 24, 32)

[3x3]ストライド1でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID')

print(conv1.get_shape())

(?, 26, 26, 32)

[3x3]ストライド2でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='VALID')

print(conv1.get_shape())

(?, 13, 13, 32)

[3x3]ストライド2でパディングあり

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='SAME')

print(conv1.get_shape())

(?, 14, 14, 32)

[3x3]ストライド3でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='VALID')

print(conv1 .get_shape())

(?, 9, 9, 32)

[3x3]ストライド3でパディングあり

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='SAME')

print(conv1 .get_shape())

(?, 10, 10, 32)

最後に

これらの例は畳み込みでしたが、プーリングでも同様に確認可能です。

参考文献

TensorFlowではじめるDeepLearning実践入門のサンプルコード





関連記事



公開日:2018年08月02日 最終更新日:2018年08月24日
記事NO:02710