ホーム > カテゴリ > Python・人工知能・Django >

CNNで畳み込み/プーリング後のテンソルのサイズ(Shape)を確認する

Python/TensorFlowの使い方(目次)

CNN(畳み込みニューラルネットワーク)で畳み込み、プーリング後のテンソルのサイズは非常にわかりにくいです。

そこで、Tensorオブジェクトのget_shape()メソッドを使用すると、テンソルのサイズを簡単に確認する事が可能です。

※全コードは後述する参考文献を参照してください。

[5x5]ストライド1でパディングあり

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='SAME')

print(conv1.get_shape())

(?, 28, 28, 32)

[5x5]ストライド1でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID')

print(conv1.get_shape())

(?, 24, 24, 32)

[3x3]ストライド1でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID')

print(conv1.get_shape())

(?, 26, 26, 32)

[3x3]ストライド2でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='VALID')

print(conv1.get_shape())

(?, 13, 13, 32)

[3x3]ストライド2でパディングあり

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='SAME')

print(conv1.get_shape())

(?, 14, 14, 32)

[3x3]ストライド3でパディングなし

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='VALID')

print(conv1 .get_shape())

(?, 9, 9, 32)

[3x3]ストライド3でパディングあり

x = tf.placeholder(tf.float32, [None, 784])
img = tf.reshape(x,[-1,28,28,1])
f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1))
conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='SAME')

print(conv1 .get_shape())

(?, 10, 10, 32)

最後に

これらの例は畳み込みでしたが、プーリングでも同様に確認可能です。

参考文献

TensorFlowではじめるDeepLearning実践入門のサンプルコード





関連記事



公開日:2018年08月02日 最終更新日:2018年08月24日
記事NO:02710


この記事を書いた人

💻 ITスキル・経験
サーバー構築からWebアプリケーション開発。IoTをはじめとする電子工作、ロボット、人工知能やスマホ/OSアプリまで分野問わず経験。

画像処理/音声処理/アニメーション、3Dゲーム、会計ソフト、PDF作成/編集、逆アセンブラ、EXE/DLLファイルの書き換えなどのアプリを公開。詳しくは自己紹介へ
プチモンテ代表、アーティスト名:プチモンテ
🎵 音楽制作
BGMは楽器(音源)さえあれば、何でも制作可能。歌モノは主にロック、バラード、ポップスを制作。歌詞は叙情詩、叙情的な楽曲が多い。楽曲制作は2023年12月中旬 ~

オリジナル曲を始めました✨

YouTubeで各楽曲を公開しています🌈
https://www.youtube.com/@petitmonte

【男性ボーカル】DA・KA・RA | 新たな明日が風と共に訪れる

【男性、女性ボーカル】時空を超越する先に | 時空と風の交響曲

【女性、男性ボーカル】絆 | 穏やかな心に奏でる旋律