CNNで畳み込み/プーリング後のテンソルのサイズ(Shape)を確認する
CNN(畳み込みニューラルネットワーク)で畳み込み、プーリング後のテンソルのサイズは非常にわかりにくいです。
そこで、Tensorオブジェクトのget_shape()メソッドを使用すると、テンソルのサイズを簡単に確認する事が可能です。
※全コードは後述する参考文献を参照してください。
[5x5]ストライド1でパディングあり
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='SAME') print(conv1.get_shape())
(?, 28, 28, 32)
[5x5]ストライド1でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID') print(conv1.get_shape())
(?, 24, 24, 32)
[3x3]ストライド1でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID') print(conv1.get_shape())
(?, 26, 26, 32)
[3x3]ストライド2でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='VALID') print(conv1.get_shape())
(?, 13, 13, 32)
[3x3]ストライド2でパディングあり
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='SAME') print(conv1.get_shape())
(?, 14, 14, 32)
[3x3]ストライド3でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='VALID') print(conv1 .get_shape())
(?, 9, 9, 32)
[3x3]ストライド3でパディングあり
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='SAME') print(conv1 .get_shape())
(?, 10, 10, 32)
最後に
これらの例は畳み込みでしたが、プーリングでも同様に確認可能です。
参考文献
TensorFlowではじめるDeepLearning実践入門のサンプルコード
スポンサーリンク
関連記事
前の記事: | TensorFlow.jsのHello World [WebでAIモデルを実行する] |
次の記事: | 画像内の物体を検出するObject Detection APIの使用方法 [TensorFlow] |
公開日:2018年08月02日 最終更新日:2018年08月24日
記事NO:02710
この記事を書いた人
![]() | 💻 ITスキル・経験 サーバー構築からWebアプリケーション開発。IoTをはじめとする電子工作、ロボット、人工知能やスマホ/OSアプリまで分野問わず経験。 画像処理/音声処理/アニメーション、3Dゲーム、会計ソフト、PDF作成/編集、逆アセンブラ、EXE/DLLファイルの書き換えなどのアプリを公開。詳しくは自己紹介へ |
プチモンテ代表、アーティスト名:プチモンテ | |
🎵 音楽制作 BGMは楽器(音源)さえあれば、何でも制作可能。歌モノは主にロック、バラード、ポップスを制作。歌詞は叙情詩、叙情的な楽曲が多い。楽曲制作は2023年12月中旬 ~ |
オリジナル曲を始めました✨
YouTubeで各楽曲を公開しています🌈
https://www.youtube.com/@petitmonte