利用Jetson Nano、Google Colab實作CycleGAN:將拍下來的照片、影片轉換成梵谷風格 – 訓練、預測以及應用篇

*本文由RS components 贊助發表,轉載自DesignSpark部落格原文連結 

作者/攝影 嘉鈞
難度

★★★★☆(中偏難)

材料表

RK-NVIDIA® Jetson Nano™ Developer Kit B01 套件

訓練CycleGAN

首先先取得訓練資料:

[pastacode lang=”python” manual=”from%20tqdm%20import%20tqdm%0Aimport%20torchvision.utils%20as%20vutils%0A%0Atotal_len%20%3D%20len(dataA_loader)%20%2B%20len(dataB_loader)%0A%0Afor%20epoch%20in%20range(epochs)%3A%20%0A%20%20%20%20progress_bar%20%3D%20tqdm(enumerate(zip(dataA_loader%2C%20dataB_loader))%2C%20total%20%3D%20total_len)%20%0A%20%20%20%20for%20idx%2C%20data%20in%20progress_bar%3A%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20define%20training%20data%20%26%20label%20%23%23%23%23%23%23%23%23%23%23%23%23%20%0A%20%20%20%20%20%20%20%20real_A%20%3D%20data%5B0%5D%5B0%5D.to(device)%20%20%20%20%23%20vangogh%20image%0A%20%20%20%20%20%20%20%20real_B%20%3D%20data%5B1%5D%5B0%5D.to(device)%20%20%20%20%23%20real%20picture” message=”” highlight=”” provider=”manual”/]

我們要先訓練G,總共有三個標準要來衡量生成器:

1.是否能騙過鑑別器 (Adversial Loss ):

對於G_B2A來說,將A轉換成B之後給予1的標籤,並且計算跟real_B 之間的距離。

[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20Train%20G%20%23%23%23%23%23%23%23%23%23%23%23%23%20%0A%20%20%20%20%20%20%20%20optim_G.zero_grad()%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20Train%20G%20-%20Adversial%20Loss%20%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20fake_A%20%3D%20G_B2A(real_B)%0A%20%20%20%20%20%20%20%20fake_out_A%20%3D%20D_A(fake_A)%20%0A%20%20%20%20%20%20%20%20fake_B%20%3D%20G_A2B(real_A)%0A%20%20%20%20%20%20%20%20fake_out_B%20%3D%20D_B(fake_B)%0A%0A%20%20%20%20%20%20%20%20real_label%20%3D%20torch.ones(%20(fake_out_A.size())%20%2C%20dtype%3Dtorch.float32).to(device)%0A%20%20%20%20%20%20%20%20fake_label%20%3D%20torch.zeros(%20(fake_out_A.size())%20%2C%20dtype%3Dtorch.float32).to(device)%20%0A%20%20%20%20%20%20%20%20adversial_loss_B2A%20%3D%20MSE(fake_out_A%2C%20real_label)%0A%20%20%20%20%20%20%20%20adversial_loss_A2B%20%3D%20MSE(fake_out_B%2C%20real_label)%0A%20%20%20%20%20%20%20%20adv_loss%20%3D%20adversial_loss_B2A%20%2B%20adversial_loss_A2B” message=”” highlight=”” provider=”manual”/]

2.是否能重新建構 (Consistency Loss):

舉例 G_B2A(real_B) 產生風格A的圖像 (fake_A) 後,再丟進 G_A2B(fake_A) 重新建構成B風格的圖像 (rec_B),並且計算 real_B 跟 rec_B之間的差距。

[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20G%20-%20Consistency%20Loss%20(Reconstruction)%20%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20rec_A%20%3D%20G_B2A(fake_B)%0A%20%20%20%20%20%20%20%20rec_B%20%3D%20G_A2B(fake_A)%20%0A%20%20%20%20%20%20%20%20consistency_loss_B2A%20%3D%20L1(rec_A%2C%20real_A)%0A%20%20%20%20%20%20%20%20consistency_loss_A2B%20%3D%20L1(rec_B%2C%20real_B)%20%0A%20%20%20%20%20%20%20%20rec_loss%20%3D%20consistency_loss_B2A%20%2B%20consistency_loss_A2B” message=”” highlight=”” provider=”manual”/]

3.是否能保持一致 (Identity Loss):

以G_A2B來說,是否在丟入 real_B的圖片後,確實能輸出 B風格的圖片,是否能保持原樣?

[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20G%20-%20Identity%20%20Loss%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20idt_A%20%3D%20G_B2A(real_A)%0A%20%20%20%20%20%20%20%20idt_B%20%3D%20G_A2B(real_B)%20%0A%20%20%20%20%20%20%20%20identity_loss_A%20%3D%20L1(idt_A%2C%20real_A)%0A%20%20%20%20%20%20%20%20identity_loss_B%20%3D%20L1(idt_B%2C%20real_B)%20%0A%20%20%20%20%20%20%20%20idt_loss%20%3D%20identity_loss_A%20%2B%20identity_loss_B” message=”” highlight=”” provider=”manual”/]

接著訓練D,它只要將自己的本份顧好就好了,也就是「能否分辨得出該風格的成像是否真實」。

[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20Train%20D%20%23%23%23%23%23%23%23%23%23%23%23%23%20%0A%20%20%20%20%20%20%20%20optim_D.zero_grad()%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20D%20-%20Adversial%20D_A%20Loss%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20%0A%20%20%20%20%20%20%20%20real_out_A%20%3D%20D_A(real_A)%0A%20%20%20%20%20%20%20%20real_out_A_loss%20%3D%20MSE(real_out_A%2C%20real_label)%20%0A%20%20%20%20%20%20%20%20fake_out_A%20%3D%20D_A(fake_A_sample.push_and_pop(fake_A))%0A%20%20%20%20%20%20%20%20fake_out_A_loss%20%3D%20MSE(real_out_A%2C%20fake_label)%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20loss_DA%20%3D%20real_out_A_loss%20%2B%20fake_out_A_loss%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20D%20-%20Adversial%20D_B%20Loss%20%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20real_out_B%20%3D%20D_B(real_B)%0A%20%20%20%20%20%20%20%20real_out_B_loss%20%3D%20MSE(real_out_B%2C%20real_label)%20%0A%20%20%20%20%20%20%20%20fake_out_B%20%3D%20D_B(fake_B_sample.push_and_pop(fake_B))%0A%20%20%20%20%20%20%20%20fake_out_B_loss%20%3D%20MSE(fake_out_B%2C%20fake_label)%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20loss_DB%20%3D%20(%20real_out_B_loss%20%2B%20fake_out_B_loss%20)%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20D%20-%20Total%20Loss%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20loss_D%20%3D%20(%20loss_DA%20%2B%20loss_DB%20)%20*%200.5%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20%20Backward%20%26%20Update%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20loss_D.backward()%0A%20%20%20%20%20%20%20%20optim_D.step()” message=”” highlight=”” provider=”manual”/]

最後我們可以將一些資訊透過tqdm印出來

[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20progress%20info%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20%20%20%20%20progress_bar.set_description(%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22%5B%7Bepoch%7D%2F%7Bepochs%20-%201%7D%5D%5B%7Bidx%7D%2F%7Blen(dataloader)%20-%201%7D%5D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22Loss_D%3A%20%7B(loss_DA%20%2B%20loss_DB).item()%3A.4f%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22Loss_G%3A%20%7Bloss_G.item()%3A.4f%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22Loss_G_identity%3A%20%7B(idt_loss).item()%3A.4f%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22loss_G_GAN%3A%20%7B(adv_loss).item()%3A.4f%7D%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22loss_G_cycle%3A%20%7B(rec_loss).item()%3A.4f%7D%22)” message=”” highlight=”” provider=”manual”/]

接著訓練GAN非常重要的環節就是要記得儲存權重,因為說不定訓練第100回合的效果比200回合的還要好,所以都會傾向一定的回合數就儲存一次。儲存的方法很簡單大家可以上PyTorch的官網查看,大致上總共有兩種儲存方式:

1.儲存模型結構以及權重

[pastacode lang=”python” manual=”torch.save(%20model%20)” message=”” highlight=”” provider=”manual”/]

2.只儲存權重

[pastacode lang=”python” manual=”torch.save(%20model.static_dict()%20)” message=”” highlight=”” provider=”manual”/]

而我採用的方式是只儲存權重,這也是官方建議的方案:

[pastacode lang=”python” manual=”%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20if%20i%20%25%20log_freq%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20vutils.save_image(real_A%2C%20f%22%7Boutput_path%7D%2Freal_A_%7Bepoch%7D.jpg%22%2C%20normalize%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20vutils.save_image(real_B%2C%20f%22%7Boutput_path%7D%2Freal_B_%7Bepoch%7D.jpg%22%2C%20normalize%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20fake_A%20%3D%20(%20G_B2A(%20real_B%20).data%20%2B%201.0%20)%20*%200.5%0A%20%20%20%20%20%20%20%20%20%20%20%20fake_B%20%3D%20(%20G_A2B(%20real_A%20).data%20%2B%201.0%20)%20*%200.5%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20vutils.save_image(fake_A%2C%20f%22%7Boutput_path%7D%2Ffake_A_%7Bepoch%7D.jpg%22%2C%20normalize%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20vutils.save_image(fake_B%2C%20f%22%7Boutput_path%7D%2Ffake_A_%7Bepoch%7D.jpg%22%2C%20normalize%3DTrue)%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20torch.save(G_A2B.state_dict()%2C%20f%22weights%2FnetG_A2B_epoch_%7Bepoch%7D.pth%22)%0A%20%20%20%20torch.save(G_B2A.state_dict()%2C%20f%22weights%2FnetG_B2A_epoch_%7Bepoch%7D.pth%22)%0A%20%20%20%20torch.save(D_A.state_dict()%2C%20f%22weights%2FnetD_A_epoch_%7Bepoch%7D.pth%22)%0A%20%20%20%20torch.save(D_B.state_dict()%2C%20f%22weights%2FnetD_B_epoch_%7Bepoch%7D.pth%22)%0A%0A%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%20Update%20learning%20rates%20%23%23%23%23%23%23%23%23%23%23%23%23%0A%20%20%20%20lr_scheduler_G.step()%0A%20%20%20%20lr_scheduler_D.step()%0A%0A%23%23%23%23%23%23%23%23%23%23%23%23%20save%20last%20check%20pointing%20%23%23%23%23%23%23%23%23%23%23%23%23%0Atorch.save(netG_A2B.state_dict()%2C%20f%22weights%2FnetG_A2B.pth%22)%0Atorch.save(netG_B2A.state_dict()%2C%20f%22weights%2FnetG_B2A.pth%22)%0Atorch.save(netD_A.state_dict()%2C%20f%22weights%2FnetD_A.pth%22)%0Atorch.save(netD_B.state_dict()%2C%20f%22weights%2FnetD_B.pth%22)%20%20%20″ message=”” highlight=”” provider=”manual”/]

測試

其實測試非常的簡單,跟著以下的步驟就可以完成:

1.導入函式庫

[pastacode lang=”python” manual=”import%20os%0Aimport%20torch%0Aimport%20torchvision.datasets%20as%20dsets%0Afrom%20torch.utils.data%20import%20DataLoader%0Aimport%20torchvision.transforms%20as%20transforms%0Afrom%20tqdm%20import%20tqdm%0Aimport%20torchvision.utils%20as%20vutils” message=”” highlight=”” provider=”manual”/]

2.將測試資料建一個數據集並透過DataLoader載入:

這邊我創了一個Custom資料夾存放我自己的數據,並且新建了一個output資料夾方便察看結果。

[pastacode lang=”python” manual=”%20%20batch_size%20%3D%2012%0A%20%20device%20%3D%20’cuda%3A0’%20if%20torch.cuda.is_available()%20else%20’cpu’%0A%0A%20%20transform%20%3D%20transforms.Compose(%20%5Btransforms.Resize((256%2C256))%2C%0A%20%20%20%20%20%20%20%20transforms.ToTensor()%2C%0A%20%20%20%20%20%20%20%20transforms.Normalize(mean%3D%5B0.5%2C%200.5%2C%200.5%5D%2C%20std%3D%5B0.5%2C%200.5%2C%200.5%5D)%5D)%0A%0A%20%20root%20%3D%20r’vangogh2photo’%0A%20%20targetC_path%20%3D%20os.path.join(root%2C%20’custom’)%0A%20%20output_path%20%3D%20os.path.join(‘.%2F’%2C%20r’output’)%0A%0A%20%20if%20os.path.exists(output_path)%20%3D%3D%20False%3A%0A%20%20%20%20os.mkdir(output_path)%0A%20%20%20%20print(‘Create%20dir%20%3A%20’%2C%20output_path)%0A%0A%20%20dataC_loader%20%3D%20DataLoader(dsets.ImageFolder(targetC_path%2C%20transform%3Dtransform)%2C%20batch_size%3Dbatch_size%2C%20shuffle%3DTrue%2C%20num_workers%3D4)” message=”” highlight=”” provider=”manual”/]

3.實例化生成器、載入權重 (load_static_dict)、選擇模式 ( train or eval ),如果選擇 eval,PyTorch會將Drop給自動關掉;因為我只要真實照片轉成梵谷所以只宣告了G_B2A:

[pastacode lang=”python” manual=”%20%20%23%20get%20generator%0A%20%20G_B2A%20%3D%20Generator().to(device)%0A%0A%20%20%23%20Load%20state%20dicts%0A%20%20G_B2A.load_state_dict(torch.load(os.path.join(%22weights%22%2C%20%22netG_B2A.pth%22)))%0A%0A%20%20%23%20Set%20model%20mode%0A%20%20G_B2A.eval()” message=”” highlight=”” provider=”manual”/]

4.開始進行預測:

取得資料>丟進模型取得輸出>儲存圖片

[pastacode lang=”python” manual=”progress_bar%20%3D%20tqdm(enumerate(dataC_loader)%2C%20total%3Dlen(dataC_loader))%0A%0A%20%20for%20i%2C%20data%20in%20progress_bar%3A%0A%E3%80%80%E3%80%80%E3%80%80%23%20get%20data%0A%20%20%20%20%20%20real_images_B%20%3D%20data%5B0%5D.to(device)%0A%0A%20%20%20%20%20%20%23%20Generate%20output%0A%20%20%20%20%20%20fake_image_A%20%3D%200.5%20*%20(G_B2A(real_images_B).data%20%2B%201.0)%0A%0A%20%20%20%20%20%20%23%20Save%20image%20files%0A%20%20%20%20%20%20vutils.save_image(fake_image_A.detach()%2C%20f%22%7Boutput_path%7D%2FFakeA_%7Bi%20%2B%201%3A04d%7D.jpg%22%2C%20normalize%3DTrue)%0A%0A%20%20%20%20%20%20progress_bar.set_description(f%22Process%20images%20%7Bi%20%2B%201%7D%20of%20%7Blen(dataC_loader)%7D%22)” message=”” highlight=”” provider=”manual”/]

5.去output查看結果:

可能是因為我只有訓練100回合,梵谷風格的細節線條還沒學起來,大家可以嘗試再訓練久一點,理論上200回合就會有不錯的成果了!

 

ORIGINAL

 

TRANSFORM

 

好的,那現在已經會建構、訓練以及預測了,接下來我們來想個辦法應用它!講到Style Transfer的應用,第一個就想到微軟大大提供的Style Transfer Azure Website。

 

Azure 的Style Transfer

網站連結  https://styletransfers.azurewebsites.net/

 

這種拍一張照片就可以直接做轉換的感覺真的很棒!所以我們理論上也可以透過簡單的opencv程式來完成這件事情,再實作之前先去體驗看看Style Transfer 。

按下Create就能進來這個頁面,透過點擊Capture就可以拍照進行轉換也可以點擊Upload a picture上傳照片,總共有4種風格可以選擇:

感覺真的超級酷的!所以我們也來試著實作類似的功能。

 

在JetsonNano中進行風格轉換

 

1.首先要將權重放到Jetson Nano中

我新增了一個weights資料夾並且將pth放入其中,此外還在同一層級新增了jupyter book的程式:

2.重建生成器並導入權重值

這邊可能會有版本問題,像我就必須升級成Torch 1.6版本,而安裝PyTorch的方法我會放在文章結尾補述,回歸正題,還記得剛剛我儲存的時候只有儲存權重對吧,所以我們必須建一個跟當初訓練一模一樣的模型才能匯入哦!所以來複製一下之前寫的生成器吧!

[pastacode lang=”python” manual=”import%20torch%0Afrom%20torch%20import%20nn%0Afrom%20torchsummary%20import%20summary%0A%0A%0Adef%20conv_norm_relu(in_dim%2C%20out_dim%2C%20kernel_size%2C%20stride%20%3D%201%2C%20padding%3D0)%3A%0A%20%20%20%20%0A%20%20%20%20layer%20%3D%20nn.Sequential(nn.Conv2d(in_dim%2C%20out_dim%2C%20kernel_size%2C%20stride%2C%20padding)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.InstanceNorm2d(out_dim)%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.ReLU(True))%0A%20%20%20%20return%20layer%0A%0Adef%20dconv_norm_relu(in_dim%2C%20out_dim%2C%20kernel_size%2C%20stride%20%3D%201%2C%20padding%3D0%2C%20output_padding%3D0)%3A%0A%20%20%20%20%0A%20%20%20%20layer%20%3D%20nn.Sequential(nn.ConvTranspose2d(in_dim%2C%20out_dim%2C%20kernel_size%2C%20stride%2C%20padding%2C%20output_padding)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.InstanceNorm2d(out_dim)%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.ReLU(True))%0A%20%20%20%20return%20layer%0A%0Aclass%20ResidualBlock(nn.Module)%3A%0A%20%20%20%20%0A%20%20%20%20def%20__init__(self%2C%20dim%2C%20use_dropout)%3A%0A%20%20%20%20%20%20%20%20super(ResidualBlock%2C%20self).__init__()%0A%20%20%20%20%20%20%20%20res_block%20%3D%20%5Bnn.ReflectionPad2d(1)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20conv_norm_relu(dim%2C%20dim%2C%20kernel_size%3D3)%5D%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20if%20use_dropout%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20res_block%20%2B%3D%20%5Bnn.Dropout(0.5)%5D%0A%20%20%20%20%20%20%20%20res_block%20%2B%3D%20%5Bnn.ReflectionPad2d(1)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(dim%2C%20dim%2C%20kernel_size%3D3%2C%20padding%3D0)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.InstanceNorm2d(dim)%5D%0A%0A%20%20%20%20%20%20%20%20self.res_block%20%3D%20nn.Sequential(*res_block)%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20return%20x%20%2B%20self.res_block(x)%0A%0Aclass%20Generator(nn.Module)%3A%0A%20%20%20%20%0A%20%20%20%20def%20__init__(self%2C%20input_nc%3D3%2C%20output_nc%3D3%2C%20filters%3D64%2C%20use_dropout%3DTrue%2C%20n_blocks%3D6)%3A%0A%20%20%20%20%20%20%20%20super(Generator%2C%20self).__init__()%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%23%20%E5%90%91%E4%B8%8B%E6%8E%A1%E6%A8%A3%0A%20%20%20%20%20%20%20%20model%20%3D%20%5Bnn.ReflectionPad2d(3)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20conv_norm_relu(input_nc%20%20%20%2C%20filters%20*%201%2C%207)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20conv_norm_relu(filters%20*%201%2C%20filters%20*%202%2C%203%2C%202%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20conv_norm_relu(filters%20*%202%2C%20filters%20*%204%2C%203%2C%202%2C%201)%5D%0A%0A%20%20%20%20%20%20%20%20%23%20%E9%A0%B8%E8%84%96%E5%B1%A4%0A%20%20%20%20%20%20%20%20for%20i%20in%20range(n_blocks)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20model%20%2B%3D%20%5BResidualBlock(filters%20*%204%2C%20use_dropout)%5D%0A%0A%20%20%20%20%20%20%20%20%23%20%E5%90%91%E4%B8%8A%E6%8E%A1%E6%A8%A3%0A%20%20%20%20%20%20%20%20model%20%2B%3D%20%5Bdconv_norm_relu(filters%20*%204%2C%20filters%20*%202%2C%203%2C%202%2C%201%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20dconv_norm_relu(filters%20*%202%2C%20filters%20*%201%2C%203%2C%202%2C%201%2C%201)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.ReflectionPad2d(3)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(filters%2C%20output_nc%2C%207)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.Tanh()%5D%0A%0A%20%20%20%20%20%20%20%20self.model%20%3D%20nn.Sequential(*model)%20%20%20%20%23%20model%20%E6%98%AF%20list%20%E4%BD%86%E6%98%AF%20sequential%20%E9%9C%80%E8%A6%81%E5%B0%87%E5%85%B6%E9%80%8F%E9%81%8E%20%2C%20%E5%88%86%E5%89%B2%E5%87%BA%E4%BE%86%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20return%20self.model(x)” message=”” highlight=”” provider=”manual”/]

接下來要做實例化模型並導入權重:

[pastacode lang=”python” manual=”def%20init_model()%3A%0A%20%20%20%20%0A%20%20%20%20device%20%3D%20’cuda%3A0’%20if%20torch.cuda.is_available()%20else%20’cpu’%0A%20%20%20%20G_B2A%20%3D%20Generator().to(device)%0A%20%20%20%20G_B2A.load_state_dict(torch.load(os.path.join(%22weights%22%2C%20%22netG_B2A.pth%22)%2C%20map_location%3Ddevice%20))%0A%20%20%20%20G_B2A.eval()%0A%20%20%20%20%0A%20%20%20%20return%20G_B2A” message=”” highlight=”” provider=”manual”/]

3.在Colab中拍照

我先寫了一個副函式來進行模型的預測,丟進去的圖片記得也要做transform,將大小縮放到256、轉換成tensor以及正規化,這部分squeeze目的是要模擬成有batch_size的格式:

[pastacode lang=”python” manual=”def%20test(G%2C%20img)%3A%20%0A%20%20%20%20device%20%3D%20’cuda%3A0’%20if%20torch.cuda.is_available()%20else%20’cpu’%20%0A%20%20%20%20transform%20%3D%20transforms.Compose(%5Btransforms.Resize((256%2C256))%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20transforms.ToTensor()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20transforms.Normalize(mean%3D%5B0.5%2C%200.5%2C%200.5%5D%2C%20std%3D%5B0.5%2C%200.5%2C%200.5%5D)%5D)%0A%0A%20%20%20%20data%20%3D%20transform(img).to(device)%20%0A%20%20%20%20data%20%3D%20data.unsqueeze(0)%20%0A%20%20%20%20out%20%3D%20(0.5%20*%20(G(data).data%20%2B%201.0)).squeeze(0)%20%20%20%20return%20out” message=”” highlight=”” provider=”manual”/]

我們接著使用OpenCV來完成拍照,按下q離開,按下s進行儲存,那我們可以在按下s的時候進行風格轉換,存下兩種風格的圖片,這邊要注意的是PyTorch吃的是PIL的圖檔格式,所以還必須將OpenCV的nparray格式轉換成PIL.Image格式:

[pastacode lang=”python” manual=”if%20__name__%3D%3D’__main__’%3A%0A%20%20%20%20%0A%20%20%20%20G%20%3D%20init_model()%0A%20%20%20%20%0A%20%20%20%20trans_path%20%3D%20’test_transform.jpg’%0A%20%20%20%20org_path%20%3D%20’test_original.jpg’%0A%20%20%20%20%0A%20%20%20%20cap%20%3D%20cv2.VideoCapture(0)%0A%0A%20%20%20%20while(True)%3A%0A%0A%20%20%20%20%20%20%20%20ret%2C%20frame%20%3D%20cap.read()%0A%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20cv2.imshow(‘webcam’%2C%20frame)%0A%0A%20%20%20%20%20%20%20%20key%20%3D%20cv2.waitKey(1)%0A%0A%20%20%20%20%20%20%20%20if%20key%3D%3Dord(‘q’)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20cap.release()%0A%20%20%20%20%20%20%20%20%20%20%20%20cv2.destroyAllWindows()%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%0A%20%20%20%20%20%20%20%20elif%20key%3D%3Dord(‘s’)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20output%20%3D%20test(G%2C%20Image.fromarray(frame))%0A%20%20%20%20%20%20%20%20%20%20%20%20style_img%20%3D%20np.array(output.cpu()).transpose(%5B1%2C2%2C0%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20org_img%20%3D%20cv2.resize(frame%2C%20(256%2C%20256))%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20cv2.imwrite(trans_path%2C%20style_img*255)%0A%20%20%20%20%20%20%20%20%20%20%20%20cv2.imwrite(org_path%2C%20org_img)%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%20%20%20%20%0A%20%20%20%20cap.release()%0A%20%20%20%20cv2.destroyWindow(‘webcam’)%0A%20%20%20%20″ message=”” highlight=”” provider=”manual”/]

執行的畫面如下:

最後再將兩種風格照片合併顯示出來:

[pastacode lang=”python” manual=”%20%20%20%20res%20%3D%20np.concatenate((style_img%2C%20org_img%2F255)%2C%20axis%3D1)%0A%20%20%20%20cv2.imshow(‘res’%2Cres%20)%0A%0A%20%20%20%20cv2.waitKey(0)%0A%20%20%20%20cv2.destroyAllWindows()” message=”” highlight=”” provider=”manual”/]

在Jetson Nano中做即時影像轉換

概念跟拍照轉換雷同,這邊我們直接在取得到攝影機的圖像之後就做風格轉換,我額外寫了一個判斷,按下t可以進行風格轉換,並且用cv2.putText將現在風格的標籤顯示在左上角。

[pastacode lang=”python” manual=”if%20__name__%3D%3D’__main__’%3A%0A%20%20%20%20%0A%20%20%20%20G%20%3D%20init_model()%20%0A%20%20%20%20cap%20%3D%20cv2.VideoCapture(0)%20%0A%20%20%20%20change_style%20%3D%20False%0A%20%20%20%20save_img_name%20%3D%20’test.jpg’%0A%20%20%20%20cv2text%20%3D%20”%0A%0A%20%20%20%20while(True)%3A%20%0A%20%20%20%20%20%20%20%20ret%2C%20frame%20%3D%20cap.read()%0A%20%20%20%20%20%20%20%20%23%20Do%20Something%20Cool%20%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%20%0A%20%20%20%20%20%20%20%20if%20change_style%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20style_img%20%3D%20test(G%2C%20Image.fromarray(frame))%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20np.array(style_img.cpu()).transpose(%5B1%2C2%2C0%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20cv2text%20%3D%20’Style%20Transfer’%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20frame%0A%20%20%20%20%20%20%20%20%20%20%20%20cv2text%20%3D%20’Original’%0A%20%20%20%20%20%20%20%20%20%20%20%20%0A%20%20%20%20%20%20%20%20out%20%3D%20cv2.resize(out%2C%20(512%2C%20512))%0A%20%20%20%20%20%20%20%20out%20%3D%20cv2.putText(out%2C%20f’%7Bcv2text%7D’%2C%20(20%2C%2040)%2C%20cv2.FONT_HERSHEY_SIMPLEX%20%2C%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%201%2C%20(255%2C%20255%2C%20255)%2C%202%2C%20cv2.LINE_AA)%20%0A%0A%20%20%20%20%20%20%20%20%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%23%20%0A%0A%20%20%20%20%20%20%20%20cv2.imshow(‘webcam’%2C%20out)%20%0A%20%20%20%20%20%20%20%20key%20%3D%20cv2.waitKey(1)%20%0A%20%20%20%20%20%20%20%20if%20key%3D%3Dord(‘q’)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%20%20%20%20%20%20%20%20elif%20key%3D%3Dord(‘s’)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20change_style%3D%3DTrue%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20cv2.imwrite(save_img_name%2Cout*255)%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20cv2.imwrite(save_img_name%2Cout)%20%0A%20%20%20%20%20%20%20%20elif%20key%3D%3Dord(‘t’)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20change_style%20%3D%20False%20if%20change_style%20else%20True%0A%20%20%20%20cap.release()%0A%20%20%20%20cv2.destroyAllWindows()” message=”” highlight=”” provider=”manual”/]

即時影像風格轉換成果

結語

這次GAN影像風格轉換的部分就告一段落了,利用Colab來訓練風格轉換的範例真的還是偏硬了一點,雖然我們只有訓練100回合但也跑了半天多一點了,但是!GAN就是個需要耐心的模型,不跑個三天兩夜他是不會給你多好的成效的。

至於在Inference的部分,Jetson Nano還是擔當起重要的角色,稍微有一些延遲不過還算是不錯的了,或許可以考慮透過ONNX轉換成TensorRT再去跑,應該又會加快許多了,下一次又會有什麼GAN的範例大家可以期待一下,或者留言跟我說。

 

補充 – Nano 安裝Torch 1.6的方法

首先,JetPack版本要升級到4.4哦!不然CUDA核心不同這部分官網就有升級教學所以就不多贅述了。

將PyTorch等相依套件更新至1.6版本:

[pastacode lang=”python” manual=”%24%20wget%20https%3A%2F%2Fnvidia.box.com%2Fshared%2Fstatic%2Fyr6sjswn25z7oankw8zy1roow9cy5ur1.whl%20-O%20torch-1.6.0rc2-cp36-cp36m-linux_aarch64.whl%0A%24%20sudo%20apt-get%20install%20python3-pip%20libopenblas-base%20libopenmpi-dev%20%0A%24%20pip3%20install%20Cython%0A%24%20pip3%20install%20torch-1.6.0rc2-cp36-cp36m-linux_aarch64.whl” message=”” highlight=”” provider=”manual”/]

將TorchVision更新至對應版本:

[pastacode lang=”python” manual=”%24%20sudo%20apt-get%20install%20libjpeg-dev%20zlib1g-dev%0A%24%20git%20clone%20–branch%20v0.7.0%20https%3A%2F%2Fgithub.com%2Fpytorch%2Fvision%20torchvision%0A%24%20cd%20torchvision%0A%24%20export%20BUILD_VERSION%3D0.7.0%20%20%23%20where%200.x.0%20is%20the%20torchvision%20version%20%20%0A%24%20sudo%20python3%20setup.py%20install%20%20%20%20%20%23%20use%20python3%20if%20installing%20for%20Python%203.6%0A%24%20cd%20..%2F%20%20%23%20attempting%20to%20load%20torchvision%20from%20build%20dir%20will%20result%20in%20import%20error” message=”” highlight=”” provider=”manual”/]

*本文由RS components 贊助發表,轉載自DesignSpark部落格原文連結 

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *