*本文由RS components 贊助發表,轉載自DesignSpark部落格原文連結
| 作者/攝影 | 張嘉鈞 |
| 難度 |
★★★★☆(中偏難) |
| 材料表 |
|
DCGAN 到 cDCGAN
先來稍微複習一下DCGAN (深度捲積生成對抗網路),字面上就是利用捲積網路的架構來做生成對抗,主要由生成器與鑑別器所構成,如下圖所示:

生成器會將一組雜訊或稱做潛在空間的張量轉換成一張照片,這張照片再經由鑑別器去判斷圖片是否夠真實,越接近0越假;越接近1越真。
由於我們在訓練的時候其實是沒有載入標籤的!所以他生成的時候都是隨機生成,為了能限制特定的輸出我們必須載入標籤,概念圖就會變成下面這張:

透過標籤的導入,讓生成器知道要生成的對象是哪一個數字,並且鑑別器訓練的目標變成「圖像是否真實」加上「是否符合該類別」,cDCGAN跟DCGAN相比,訓練的結果通常會比較好,因為DCGAN神經網路是盲目的去生成,而cDCGAN則是會將生成的範圍縮小,整體而言會收斂更快且更好。
將標籤合併於資料中
首先我們要先了解如何加入標籤,對於DCGAN來說有兩種加入標籤的方法,第一個是一開始就將圖片或雜訊跟標籤合併;另一個方法則是在深層做合併,讀者們在實作的時候可以自行調整看看差異,那較常見的做法是深層合併,而我寫的也是!
![]() |
![]() |
| 潛層合併,先合併再輸入網路 | 深層合併:各別輸入後再合併 |
其中詳細的差別我還沒涉略到,不過選定了深層合併接著就可以先來實作生成器跟鑑別器了。首先先來建構生成器,可以參考上一篇DCGAN的程式碼,這邊幫大家整理了一張概念圖:

輸入的z是維度為 ( 100, 1, 1) 的雜訊,為了將標籤跟雜訊能合併,必須轉換到相同大小也就是 (1, 1),可以看到這邊 y 的維度是 ( 10, 1, 1 ) 原因在於我們將原先阿拉伯數字的標籤轉成 onehot 編碼格式,如下圖所示。

OneHot編碼主要在於讓標籤離散,如果將標籤都用阿拉伯數字表示,對於神經網路而言他們屬於連續性的數值或許會將前後順序、距離給考慮進去,但是用onehot之後將可以將各類標籤單獨隔開並且對於彼此的距離也會相同。
建立Generator
接下來是程式的部分,如何在神經網路中做分流又合併,其實對於PyTorch而言非常的簡單只要在forward的地方做torch.cat就可以了。首先一樣要先定義網路層,我們定義了三個 Sequential,其中input_x是給圖像用的所以第一層deconv的輸入維度是z_dim;而input_y則是標籤用所以deconv的輸入是label_dim,可以對照上面的圖片看看:
[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%20def%20__init__(self%2C%20z_dim%2C%20label_dim)%3A%0A%20%20%20%20%20%20%20%20super(Generator%2C%20self).__init__()%0A%20%20%20%20%20%20%20%20self.input_x%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20input%20is%20Z%2C%20going%20into%20a%20convolution%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ConvTranspose2d(z_dim%2C%20256%2C%204%2C%201%2C%200%2C%20bias%3DFalse)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.BatchNorm2d(256)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ReLU(True)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20image%20size%20%3D%20%20(1-1)*1%20-%202*0%20%2B%204%20%3D%204%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20self.input_y%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20input%20is%20Z%2C%20going%20into%20a%20convolution%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ConvTranspose2d(%20label_dim%2C%20256%2C%204%2C%201%2C%200%2C%20bias%3DFalse)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.BatchNorm2d(256)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ReLU(True)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20image%20size%20%3D%20%20(1-1)*1%20-%202*0%20%2B%204%20%3D%204%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20self.concat%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20%E5%9B%A0%E7%82%BA%20x%20%E8%B7%9F%20y%20%E6%B0%B4%E5%B9%B3%E5%90%88%E4%BD%B5%E6%89%80%E4%BB%A5%E8%A6%81%E5%86%8D%E4%B9%98%E4%BB%A5%202%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ConvTranspose2d(256*2%2C%20128%2C%204%2C%202%2C%201%2C%20bias%3DFalse)%2C%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.BatchNorm2d(128)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ReLU(True)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20image%20size%20%3D%20%20(4-1)*2%20-%202*1%20%2B%204%20%3D%208%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ConvTranspose2d(%20128%2C%2064%2C%204%2C%202%2C%201%2C%20bias%3DFalse)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.BatchNorm2d(64)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ReLU(True)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20image%20size%20%3D%20%20(8-1)*2%20-%202*1%20%2B%204%20%3D%2016%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.ConvTranspose2d(%2064%2C%201%2C%204%2C%202%2C%203%2C%20bias%3DFalse)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Tanh()%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20image%20size%20%3D%20%20(16-1)*2%20-%202*3%20%2B%204%20%3D%2028%0A%20%20%20%20%20%20%20%20)” message=”” highlight=”” provider=”manual”/]接下來看 forward的部分,可以看到我們在向前傳遞的時候要丟入兩個數值,雜訊跟標籤,將x跟y丟進各自的Sequential中,接著我們使用torch.cat將x, y從橫向 ( dim=1 ) 合併後再進到concat中。
[pastacode lang=”python” manual=”%20%20%20%20def%20forward(self%2C%20x%2C%20y)%3A%0A%20%20%20%20%20%20%20%20x%20%3D%20self.input_x(x)%0A%20%20%20%20%20%20%20%20y%20%3D%20self.input_y(y)%0A%20%20%20%20%20%20%20%20out%20%3D%20torch.cat(%5Bx%2C%20y%5D%20%2C%20dim%3D1)%0A%20%20%20%20%20%20%20%20out%20%3D%20self.concat(out)%0A%20%20%20%20%20%20%20%20return%20out” message=”” highlight=”” provider=”manual”/]接下來可以試著將網路架構顯示出來,我們直接使用print也使用torchsummary來顯示,你可以發現其實你沒辦法看出網路分支再合併的狀況
[pastacode lang=”python” manual=”def%20print_div(text)%3A%0A%20%20%20%20%0A%20%20%20%20div%3D’%5Cn’%0A%20%20%20%20for%20i%20in%20range(60)%3A%20div%20%2B%3D%20%22%3D%22%0A%20%20%20%20div%2B%3D’%5Cn’%0A%20%20%20%20print(%22%7B%7D%20%7B%3A%5E60%7D%20%7B%7D%22.format(div%2C%20text%2C%20div))%0A%20%20%20%20%0A%22%22%22%20Define%20Generator%20%22%22%22%0AG%20%3D%20Generator(100%2C%2010)%0A%20%0A%22%22%22%20Use%20Print%22%22%22%0Aprint_div(‘Print%20Model%20Directly’)%0Aprint(G)%0A%20%0A%22%22%22%20Use%20Torchsummary%20%22%22%22%0Aprint_div(‘Print%20Model%20With%20Torchsummary’)%0Atest_x%20%3D%20(100%2C%201%2C%201)%0Atest_y%20%3D%20(10%2C%201%2C%201)%0Asummary(G%2C%20%5Btest_x%2C%20test_y%5D%2C%20batch_size%3D64)” message=”” highlight=”” provider=”manual”/]


所以我決定使用更圖像化一點的方式來視覺化我們的網路架構,現在有不下10種的圖形化方式,我舉兩個例子:Tensorboard、hiddenlayer。
視覺化模型
Tensorboard 是Google 出的強大視覺化工具,一般的文字、數值、影像、聲音都可以動態的紀錄在上面,一開始只支援Tensorflow 但是 PyTorch 1.2 之後都包含在其中 ( 但是要使用的話還是要先安裝tensorboard ) ,你可以直接從 torch.utils.tensorboard 中呼叫 Tensorboard,首先需要先實體化 SummaryWritter,接著直接使用add_graph即可將圖片存到伺服器上
[pastacode lang=”python” manual=”%22%22%22%20Initial%20Parameters%22%22%22%0Abatch_size%20%3D%201%0Atest_x%20%3D%20torch.rand(batch_size%2C%20100%2C%201%2C%201)%0Atest_y%20%3D%20torch.rand(batch_size%2C%2010%2C%201%2C%201)%0A%20%0Aprint_div(‘Print%20Model%20With%20Tensorboard’)%0Aprint(‘open%20terminal%20and%20input%20%22tensorboard%20–logdir%3Druns%22’)%0Aprint(‘open%20browser%20and%20key%20http%3A%2F%2Flocalhost%3A6006’)%0Awriter%20%3D%20SummaryWriter()%0Awriter.add_graph(G%2C%20(test_x%2C%20test_y))%0Awriter.close()” message=”” highlight=”” provider=”manual”/]接下來要開啟伺服器,在終端機中移動到與程式碼同一層級的位置並且輸入:
[pastacode lang=”python” manual=”%3E%20tensorboard%20%E2%80%93logdir%3D.%2Fruns” message=”” highlight=”” provider=”manual”/]
一開始就可以看到 input > Generator 的箭頭有寫 2 tensor,而這些方塊都可以打開:

開啟後你可以看到更細部的資訊,也很清楚就可以看到支線合併的狀況。

每一次捲積後的形狀大小也都有顯示出來:

接下來簡單介紹一下hiddenlayer ,它不能用來取代高級API像是tensorboard之類的,它僅僅就是用來顯示神經網路模型,但是非常的輕巧所以我個人蠻愛使用它的,首先要先透過pip安裝hiddenlayer、graphviz:
[pastacode lang=”python” manual=”%3E%20pip%20install%20hiddenlayer%0A%3E%20Pip%20install%20graphviz” message=”” highlight=”” provider=”manual”/]如果是用Jetson Nano的話,建議用 apt去裝 graphviz
[pastacode lang=”python” manual=”%24%20sudo%20apt-get%20install%20graphviz” message=”” highlight=”” provider=”manual”/]接著用 build_graph就能產生圖像也能直接儲存:
[pastacode lang=”python” manual=”%22%22%22%20Initial%20Parameters%22%22%22%0Abatch_size%20%3D%201%0Atest_x%20%3D%20torch.rand(batch_size%2C%20100%2C%201%2C%201)%0Atest_y%20%3D%20torch.rand(batch_size%2C%2010%2C%201%2C%201)%0A%20%0Aprint_div(‘Print%20Model%20With%20HiddenLayer’)%0Ag_graph%20%3D%20hl.build_graph(G%2C%20(test_x%2C%20test_y))%0Ag_graph.save(‘.%2Fimages%2FG_grpah’%2C%20format%3D%22jpg%22)%0Ag_graph” message=”” highlight=”” provider=”manual”/]因為太長了所以我截成兩半方便觀察,這邊就可以注意到前面的ConvTranspose、BatchNorm、ReLU是分開的,之後才合併這邊還特別給了一個Concat的方塊,我喜歡使用它的原因是簡單明瞭,捲積後的維度也都有寫下來,並且直接執行就可以看到結果,不用像Tensorboard還要再開啟服務。

建立Discriminator

跟建立Generator的概念相似,我們要個別處理輸入的圖片跟標籤,所以一樣宣告兩個 Sequential 個別處理接著再將輸出 concate 在一起,主要要注意的是 y 的輸入為度為 (10, 28, 28):
[pastacode lang=”python” manual=”import%20torch%0Aimport%20torch.nn%20as%20nn%0Afrom%20torchsummary%20import%20summary%0A%20%0Aclass%20Discriminator(nn.Module)%3A%0A%20%20%20%20%0A%20%20%20%20def%20__init__(self%2C%20c_dim%3D1%2C%20label_dim%3D10)%3A%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20super(Discriminator%2C%20self).__init__()%0A%20%0A%20%20%20%20%20%20%20%20self.input_x%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Input%20size%20%3D%201%20%2C28%20%2C%2028%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(c_dim%2C%2064%2C%20(4%2C4)%2C%202%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.LeakyReLU()%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20self.input_y%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Input%20size%20%3D%2010%20%2C28%20%2C%2028%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(label_dim%2C%2064%2C%20(4%2C4)%2C%202%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.LeakyReLU()%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20self.concate%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Input%20size%20%3D%2064%2B64%20%2C14%20%2C%2014%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(64*2%20%2C%2064%2C%20(4%2C4)%2C%202%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.LeakyReLU()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Input%20size%20%3D%20(14-4%2B2)%2F2%20%2B1%20%3D%207%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(64%2C%20128%2C%203%2C%202%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.LeakyReLU()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Input%20size%20%3D%20(7-3%2B2)%2F2%20%2B1%20%3D%204%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(128%2C%201%2C%204%2C%202%2C%200)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.Sigmoid()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20output%20size%20%3D%20(4-4)%2F2%20%2B1%20%3D%201%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20def%20forward(self%2C%20x%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20x%20%3D%20self.input_x(x)%0A%20%20%20%20%20%20%20%20y%20%3D%20self.input_y(y)%0A%20%20%20%20%20%20%20%20out%20%3D%20torch.cat(%5Bx%2C%20y%5D%20%2C%20dim%3D1)%0A%20%20%20%20%20%20%20%20out%20%3D%20self.concate(out)%0A%20%20%20%20%20%20%20%20return%20out%0A%20%20%20%20%0AD%20%3D%20Discriminator(1%2C%2010)%0Atest_x%20%3D%20torch.rand(64%2C%201%2C28%2C28)%0Atest_y%20%3D%20torch.rand(64%2C%2010%2C28%2C28)%0A%20%0Awriter%20%3D%20SummaryWriter()%0Awriter.add_graph(D%2C%20(test_x%2C%20test_y))%0Awriter.close()%0A%20%0Ahl.build_graph(D%2C%20(test_x%2C%20test_y))” message=”” highlight=”” provider=”manual”/]
視覺化的結果如下:

數據處理
神經網路都建置好就可以準備來訓練啦!當然第一步要先將數據處理好,那我個人自學神經網路的過程我覺得最難的就是數據處理了,這次數據處理有2個部分:
- 宣告固定的雜訊跟標籤用來預測用
- 將標籤轉換成onehot格式 ( scatter )
Onehot數據處理,在torch中可以直接使用scatter的方式,我在程式註解的地方有推薦一篇文章大家可以去了解scatter的概念,至於這邊我先附上實驗的程式碼:
[pastacode lang=”python” manual=”%22%22%22%20OneHot%20%E6%A0%BC%E5%BC%8F%20%E4%B9%8B%20scatter%20%E6%87%89%E7%94%A8%22%22%22%0A%22%22%22%20%E8%B6%85%E5%A5%BD%E7%90%86%E8%A7%A3%E7%9A%84%E5%9C%96%E5%BD%A2%E5%8C%96%E6%95%99%E5%AD%B8%20https%3A%2F%2Fmedium.com%2F%40yang6367%2Funderstand-torch-scatter-b0fd6275331c%20%22%22%22%0A%20%0Alabel%20%3Dtorch.tensor(%5B1%2C5%2C6%2C9%5D)%0Aprint(label%2C%20label.shape)%0A%20%0A%20%0Aa%20%3D%20torch.zeros(10).scatter_(0%2C%20label%2C%201)%0Aprint(a)%0A%20%0Aprint(‘%5Cn%5Cn’)%0Alabel_%3Dlabel.unsqueeze(1)%0Aprint(label_%2C%20label_.shape)%0Ab%20%3D%20torch.zeros(4%2C10).scatter_(1%2C%20label_%2C%201)%0Aprint(b)” message=”” highlight=”” provider=”manual”/]
接下來我們將兩個部分分開處理,先來處理測試用的雜訊跟標籤,測試用圖片為每個類別各10張,所以總共有100張圖片代表是100組雜訊及對應label:
[pastacode lang=”python” manual=”%22%22%22%20%E7%94%A2%E7%94%9F%E5%9B%BA%E5%AE%9A%E8%B3%87%E6%96%99%EF%BC%8C%E6%AF%8F%E5%80%8B%E9%A1%9E%E5%88%A510%E5%BC%B5%E5%9C%96(%E9%9B%9C%E8%A8%8A)%20%E4%BB%A5%E5%8F%8A%20%E5%B0%8D%E6%87%89%E7%9A%84%E6%A8%99%E7%B1%A4%EF%BC%8C%E7%94%A8%E6%96%BC%E8%A6%96%E8%A6%BA%E5%8C%96%E7%B5%90%E6%9E%9C%20%22%22%22%0Atemp_noise%20%3D%20torch.randn(label_dim%2C%20z_dim)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20(10%2C%20100)%2010%E5%BC%B5%E5%9C%96%0Afixed_noise%20%3D%20temp_noise%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%0Afixed_c%20%3D%20torch.zeros(label_dim%2C%201)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20(10%2C%201%20)%2010%E5%80%8B%E6%A8%99%E7%B1%A4%0A%20%0Afor%20i%20in%20range(9)%3A%0A%20%20%20%20fixed_noise%20%3D%20torch.cat(%5Bfixed_noise%2C%20temp_noise%5D%2C%200)%20%20%20%20%23%E5%B0%87%E6%AF%8F%E5%80%8B%E9%A1%9E%E5%88%A5%E7%9A%84%E5%8D%81%E5%BC%B5%E9%9B%9C%E8%A8%8A%E4%BE%9D%E5%BA%8F%E5%90%88%E4%BD%B5%EF%BC%8C%E7%B6%AD%E5%BA%A61%E6%9C%83%E8%87%AA%E5%8B%95boardcast%0A%20%20%20%20temp%20%3D%20torch.ones(label_dim%2C%201)%20%2B%20i%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%E4%BE%9D%E5%BA%8F%E5%B0%87%E6%A8%99%E7%B1%A4%E5%B0%8D%E6%87%89%E4%B8%8A%200~9%0A%20%20%20%20fixed_c%20%3D%20torch.cat(%5Bfixed_c%2C%20temp%5D%2C%200)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%E5%B0%87%E6%A8%99%E7%B1%A4%E4%B9%9F%E4%BE%9D%E5%BA%8F%E5%90%88%E4%BD%B5%0A%20%0Afixed_noise%20%3D%20fixed_noise.view(-1%2C%20z_dim%2C%201%2C%201)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%E7%94%B1%E6%96%BC%E6%98%AF%E6%8D%B2%E7%A9%8D%E6%89%80%E4%BB%A5%E6%88%91%E5%80%91%E8%A6%81%E5%B0%87%E5%BD%A2%E7%8B%80%E8%BD%89%E6%8F%9B%E6%88%90%E4%BA%8C%E7%B6%AD%E7%9A%84%0Aprint(‘Predict%20Noise%3A%20’%2C%20fixed_noise.shape)%0Aprint(‘Predict%20Label%20(before)%3A%20’%2C%20fixed_c.shape%2C%20’%5Ct%5Ct%5Ct’%2C%20fixed_c%5B50%5D)%20%20%20%20%0A%20%0A%22%22%22%20%E9%87%9D%E5%B0%8D%20lael%20%E5%81%9A%20onehot%20%22%22%22%0Afixed_label%20%3D%20torch.zeros(100%2C%20label_dim)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%E5%85%88%E7%94%A2%E7%94%9F%20%5B100%2C10%5D%20%E7%9A%84%E5%85%A80%E5%BC%B5%E9%87%8F%EF%BC%8C100%E5%80%8B%E6%A8%99%E7%B1%A4%EF%BC%8C%E6%AF%8F%E5%80%8B%E6%A8%99%E7%B1%A4%E7%B6%AD%E5%BA%A6%E6%98%AF%2010%0Afixed_label.scatter_(1%2C%20fixed_c.type(torch.LongTensor)%2C%201)%20%20%20%23%E8%BD%89%E6%88%90%20onehot%E7%B7%A8%E7%A2%BC%20(1%2C%20)%20-%3E%20(10%2C%20)%0Afixed_label%20%3D%20fixed_label.view(-1%2C%20label_dim%2C%201%2C%201)%20%20%20%20%20%20%20%20%20%20%23%E8%BD%89%E6%8F%9B%E5%BD%A2%E7%8B%80%20(10%2C%201%2C%201%20)%20%0Aprint(‘Predict%20Label%20(onehot)%3A%20’%2Cfixed_label.shape%2C%20’%5Ct%5Ct’%2C%20fixed_label%5B50%5D.view(1%2C-1)%2C%20’%5Cn’)” message=”” highlight=”” provider=”manual”/]我在顯示的時候有將形狀從 (10,1)變成(1,10) 來方便做觀察:
接下來要幫訓練的數據做前處理,處理方式跟前面雷同,主要差別在要餵給鑑別器的標籤 ( fill ) 處理方式比較不同,從結果圖就能看的出來彼此不同的地方:
[pastacode lang=”python” manual=”%22%22%22%20%E5%B9%AB%E6%A8%99%E7%B1%A4%E5%81%9A%E5%89%8D%E8%99%95%E7%90%86%EF%BC%8Conehot%20for%20g%2C%20fill%20for%20d%20%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20%E7%94%A2%E7%94%9F%20(10%2C10)%2010%E5%80%8B%E6%A8%99%E7%B1%A4%EF%BC%8C%E7%B6%AD%E5%BA%A6%E7%82%BA10%20(onehot)%0A%20%0Aprint(‘Train%20G%20label%3A’%2Conehot%5B1%5D.shape%2C%20’%5Cn’%2C%20onehot%5B1%5D%2C%20’%5Cn’)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20%E5%81%87%E8%A8%AD%E6%88%91%E5%80%91%E8%A6%81%E5%8F%96%E5%BE%97%E6%A8%99%E7%B1%A4%201%20%E7%9A%84%20onehot%20(10%2C1%2C1)%EF%BC%8C%E7%9B%B4%E6%8E%A5%E8%BC%B8%E5%85%A5%E7%B4%A2%E5%BC%95%201%0A%20%0Afill%20%3D%20torch.zeros(%5Blabel_dim%2C%20label_dim%2C%20image_size%2C%20image_size%5D)%20%20%20%20%23%20%E7%94%A2%E7%94%9F%20(10%2C%2010%2C%2028%2C%2028)%20%E6%84%8F%E5%8D%B3%2010%E5%80%8B%E6%A8%99%E7%B1%A4%20%E7%B6%AD%E5%BA%A6%E9%83%BD%E6%98%AF%20(10%2C28%2C28)%0Afor%20i%20in%20range(label_dim)%3A%0A%20%20%20%20fill%5Bi%2C%20i%2C%20%3A%2C%20%3A%5D%20%3D%201%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20%E8%88%89%E4%BE%8B%20%E6%A8%99%E7%B1%A4%205%EF%BC%8C%E7%AC%AC%E4%B8%80%E5%80%8B%5B%5D%E4%BB%A3%E8%A1%A8%E6%A8%99%E7%B1%A45%EF%BC%8C%E7%AC%AC%E4%BA%8C%E5%80%8B%5B%5D%E4%BB%A3%E8%A1%A8onehot%E7%82%BA1%E7%9A%84%E4%BD%8D%E7%BD%AE%20%0Aprint(‘Train%20D%20Label%3A%20’%2C%20fill.shape)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%0Aprint(‘%5Cn’%2C%20fill%5B1%5D.shape%2C%20’%5Cn’%2C%20fill%5B1%5D)%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20%E5%81%87%E8%A8%AD%E6%88%91%E5%80%91%E8%A6%81%E5%8F%96%E5%BE%97%E6%A8%99%E7%B1%A4%201%20%E7%9A%84%20onehot%20(10%2C28%2C28)” message=”” highlight=”” provider=”manual”/]
開始訓練-起手式
一樣從基本的參數開始宣告起,流程個別是:基本參數、數據載入、建立訓練相關的東西(模型、優化器、損失)、開始訓練。
[pastacode lang=”python” manual=”%22%22%22%20%E5%9F%BA%E6%9C%AC%E5%8F%83%E6%95%B8%20%22%22%22%0Aepoch%20%3D%2010%0Alr%20%3D%201e-5%0Abatch%20%3D%204%0Adevice%20%3D%20torch.device(‘cuda%3A0’%20if%20torch.cuda.is_available()%20else%20’cpu’)%0Az_dim%20%3D%20100%20%20%20%20%20%20%20%20%23%20latent%20Space%0Ac_dim%20%3D%201%20%20%20%20%20%20%20%20%20%20%23%20Image%20Channel%0Alabel_dim%20%3D%2010%20%20%20%20%20%23%20label%20%0A%20%0A%20%0A%22%22%22%20%E5%8F%96%E5%BE%97%E6%95%B8%E6%93%9A%E9%9B%86%E4%BB%A5%E5%8F%8ADataLoader%20%22%22%22%0Atransform%20%3D%20trans.Compose(%5B%0A%20%20%20%20trans.ToTensor()%2C%0A%20%20%20%20trans.Normalize((0.5%2C)%2C(0.5%2C))%2C%0A%5D)%0A%20%0Atrain_set%20%3D%20dset.MNIST(root%3D’.%2Fmnist_data%2F’%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20train%3DTrue%2C%20transform%3Dtransform%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20download%3DTrue)%0A%20%0Atest_set%20%3D%20dset.MNIST(root%3D’.%2Fmnist_data%2F’%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20train%3DFalse%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20transform%3Dtransform%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20download%3DFalse)%0A%20%0Atrain_loader%20%3D%20torch.utils.data.DataLoader(%0A%20%20%20%20dataset%20%3D%20train_set%2C%0A%20%20%20%20batch_size%20%3D%20batch%2C%0A%20%20%20%20shuffle%3DTrue%2C%0A%20%20%20%20drop_last%3DTrue%0A)%0A%20%0Atest_loader%20%3D%20torch.utils.data.DataLoader(%0A%20%20%20%20dataset%20%3D%20test_set%2C%0A%20%20%20%20batch_size%20%3D%20batch%2C%0A%20%20%20%20shuffle%3DFalse%20%20%20%20%20%20%20%20%0A)%0A%20%0A%22%22%22%20%E8%A8%93%E7%B7%B4%E7%9B%B8%E9%97%9C%20%22%22%22%0A%20%0AD%20%3D%20Discriminator(c_dim%2C%20label_dim).to(device)%0AG%20%3D%20Generator(z_dim%2C%20label_dim).to(device)%0Aloss_fn%20%3D%20nn.BCELoss()%0AD_opt%20%3D%20optim.Adam(D.parameters()%2C%20lr%3D%20lr)%0AG_opt%20%3D%20optim.Adam(G.parameters()%2C%20lr%3D%20lr)%0AD_avg_loss%20%3D%20%5B%5D%0AG_avg_loss%20%3D%20%5B%5D%0A%20%0Aimg%20%3D%20%5B%5D%0Als_time%20%3D%20%5B%5D” message=”” highlight=”” provider=”manual”/]開始訓練 – 手動更新學習率
會手動更新主要原因在於其實GAN的訓練並不是那麼的順利,如果速度太快會導致震盪嚴重訓練生成效果極差,所以GAN普遍的學習率都會更新並且都蠻低的,這邊我也稍微調整一下:
[pastacode lang=”python” manual=”%20%20%20%20%22%22%22%20%E7%9C%8B%E5%88%B0%E5%BE%88%E5%A4%9A%E7%AF%84%E4%BE%8B%E9%83%BD%E6%9C%89%E6%89%8B%E5%8B%95%E8%AA%BF%E6%95%B4%E5%AD%B8%E7%BF%92%E7%8E%87%20%22%22%22%0A%20%20%20%20if%20epoch%20%3D%3D%208%3A%0A%20%20%20%20%20%20%20%20G_opt.param_groups%5B0%5D%5B’lr’%5D%20%2F%3D%205%0A%20%20%20%20%20%20%20%20D_opt.param_groups%5B0%5D%5B’lr’%5D%20%2F%3D%205″ message=”” highlight=”” provider=”manual”/]開始訓練 – 訓練D、G
一樣參考上一篇的DCGAN來改良,主要差別在於需要引入label,並且需要將label轉換成onehot格式,其中
鑑別器 (D) 的訓練步驟一樣先學真實圖片給予標籤1 再學生成圖片給予標籤 0,生成圖片的部分要產生對應的亂數label,丟入G的時候是從先前寫的 onehot 中提取對應的onehot格式標籤而丟入D的時候是從 fill 中提取~
生成器 (G) 的訓練方式就是把D的後半段拿出來用,但是標籤需要改成 1,因為它的目的是要騙過D!
[pastacode lang=”python” manual=”%22%22%22%20%E8%A8%93%E7%B7%B4%20D%20%22%22%22%0A%20%0AD_opt.zero_grad()%0A%20%0Ax_real%20%3D%20data.to(device)%0Ay_real%20%3D%20torch.ones(batch%2C%20).to(device)%0Ac_real%20%3D%20fill%5Blabel%5D.to(device)%0A%20%0Ay_real_predict%20%3D%20D(x_real%2C%20c_real).squeeze()%20%20%20%20%20%20%20%20%23%20(-1%2C%201%2C%201%2C%201)%20-%3E%20(-1%2C%20)%0Ad_real_loss%20%3D%20loss_fn(y_real_predict%2C%20y_real)%0Ad_real_loss.backward()%0A%20%0Anoise%20%3D%20torch.randn(batch%2C%20z_dim%2C%201%2C%201%2C%20device%20%3D%20device)%0Anoise_label%20%3D%20(torch.rand(batch%2C%201)%20*%20label_dim).type(torch.LongTensor).squeeze()%0Anoise_label_onehot%20%3D%20onehot%5Bnoise_label%5D.to(device)%20%20%20%23%E9%9A%A8%E6%A9%9F%E7%94%A2%E7%94%9Flabel%20(-1%2C%20)%0A%20%0Ax_fake%20%3D%20G(noise%2C%20noise_label_onehot)%20%20%20%20%20%20%20%23%20%E7%94%9F%E6%88%90%E5%81%87%E5%9C%96%0Ay_fake%20%3D%20torch.zeros(batch%2C%20).to(device)%20%20%20%20%23%20%E7%B5%A6%E4%BA%88%E6%A8%99%E7%B1%A4%200%0Ac_fake%20%3D%20fill%5Bnoise_label%5D.to(device)%20%20%20%20%20%20%20%23%20%E8%BD%89%E6%8F%9B%E6%88%90%E5%B0%8D%E6%87%89%E7%9A%84%2010%2C28%2C28%20%E7%9A%84%E6%A8%99%E7%B1%A4%0A%20%0Ay_fake_predict%20%3D%20D(x_fake%2C%20c_fake).squeeze()%0Ad_fake_loss%20%3D%20loss_fn(y_fake_predict%2C%20y_fake)%0Ad_fake_loss.backward()%0AD_opt.step()%0A%20%0A%22%22%22%20%E8%A8%93%E7%B7%B4%20G%20%22%22%22%0A%20%0AG_opt.zero_grad()%0A%20%0Anoise%20%3D%20torch.randn(batch%2C%20z_dim%2C%201%2C%201%2C%20device%20%3D%20device)%0Anoise_label%20%3D%20(torch.rand(batch%2C%201)%20*%20label_dim).type(torch.LongTensor).squeeze()%0Anoise_label_onehot%20%3D%20onehot%5Bnoise_label%5D.to(device)%20%20%20%23%E9%9A%A8%E6%A9%9F%E7%94%A2%E7%94%9Flabel%20(-1%2C%20)%0A%20%0Ax_fake%20%3D%20G(noise%2C%20noise_label_onehot)%0A%23y_fake%20%3D%20torch.ones(batch%2C%20).to(device)%20%20%20%20%23%E9%80%99%E9%82%8A%E7%9A%84%20y_fake%20%E8%B7%9F%E4%B8%8A%E8%BF%B0%E7%9A%84%20y_real%20%E4%B8%80%E6%A8%A3%EF%BC%8C%E9%83%BD%E6%98%AF%201%20%20%0Ac_fake%20%3D%20fill%5Bnoise_label%5D.to(device)%0A%20%0Ay_fake_predict%20%3D%20D(x_fake%2C%20c_fake).squeeze()%0Ag_loss%20%3D%20loss_fn(y_fake_predict%2C%20y_real)%20%20%20%20%23%E7%9B%B4%E6%8E%A5%E7%94%A8%20y_real%20%E6%9B%B4%E7%9B%B4%E8%A7%80%0Ag_loss.backward()%0AG_opt.step()%0A%20%0AD_loss.append(d_fake_loss.item()%20%2B%20d_real_loss.item())%0AG_loss.append(g_loss.item())” message=”” highlight=”” provider=”manual”/]成果
起初我在第五次迭代的時候調整了學習率結果原本 1 到 5 學習的都不錯,到第 6次的時候開始有了偏差,所以真的不能亂調學習率阿~

下面是迭代15次的成果,感覺上比參考的gihub還要差了一些,仔細看了一下應該是D的結構跟learning rate的調整有差,大家可以再自己調整看看。

訓練時間比較
一樣都是 10 個 epoch ,Jetson Nano所需要的時間大約是 1 小時 40 分鐘,其實還算是蠻快的,大家可以試試看 CPU 去跑跑看就可以知道差異了。
![]() |
![]() |
結語
最後相信大家到看完這篇以及上一篇DCGAN已經對生成對抗網路有一定的熟悉度了,接下來我們可以找些GAN的github的範例來玩玩看並且增加應用。
*本文由RS components 贊助發表,轉載自DesignSpark部落格原文連結(本篇文章完整範例程式請至原文下載)








