본문 바로가기

딥러닝/GAN

[DCGAN] pytorch DCGAN 튜토리얼 설명

출처 : https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

DCGAN Tutorial — PyTorch Tutorials 1.4.0 documentation

Note Click here to download the full example code DCGAN Tutorial Author: Nathan Inkawhich Introduction This tutorial will give an introduction to DCGANs through an example. We will train a generative adversarial network (GAN) to generate new celebrities af

pytorch.org

위 pytorch 공식 홈페이지의 튜토리얼을 참고하였으며

 

https://github.com/pytorch/tutorials/blob/master/beginner_source/dcgan_faces_tutorial.py

 

설명에 나와있는 위 git 코드의 582-652 라인에 있는 학습 과정이 이해하기 힘들어 정리해보았습니다.

 

공부를 위한 정리이며, 틀린 것이 있으면 댓글 바랍니다 :)

시작!

 


1. Discriminator 업데이트하기


 

 

discriminator 네트워크는 크게 아래 두 개의 loss로부터 gradient를 계산하여 업데이트된다.

1) real 이미지를 real로 판단하면서 얻은 loss

2) fake 이미지를 fake로 판단하면서 얻은 loss

 

 

먼저 1) real 이미지를 real로 판단하는 과정을 살펴보자.

 

 

 

 

real_cpu는 dataloader로부터 받은 실제 이미지이고,

label에는 real에 해당하는 label을 넣었다.

실제 이미지를 netD에 넣어서 real/fake를 판단한 후, 정답과 비교하여 loss(errD_real)가 구해졌다.

 

이 loss로부터 gradient를 계산하는 과정이 errD_real.backward()에 해당한다.

 

 

 

이제 2) fake 이미지를 판단하는 과정을 보자.

 

 

 

 

일단 noise(z)를 만들어준 뒤, netG에 넣어 fake 이미지를 만들어준다.

그리고 netD가 fake 이미지는 fake로 판단해야 되기 때문에 정답 label에는 fake_label을 넣어준다.

 

이제 fake 이미지를 netD에 넣어 real/fake를 판단하는데, 이 과정에서 위의 real 이미지와 달리 detach()가 붙어있다.

.detach()는 원래 requires_grad = False로 만드는 역할을 한다. 여기서 fake 이미지는 netG로부터 만들어지기 때문에 fake에 .detach()를 붙이지 않을 경우, netG의 weight들의 grad까지 계산하게 된다. 하여 netD만 업데이트하고 netG에까지 backpropagation을 할 수 없도록 하기 위해 .detach()를 붙여야 한다.

자세한 내용은 여기를 참고하자.

 

마찬가지로 netD의 판단 결과와 정답을 비교하여 loss(errD_fake)를 구하고, 이 loss로부터 gradient를 계산해준다.

 

 

3) 최종적으로 netD를 업데이트해준다.

 

 

 

 

최종적으로 계산된 loss(errD)는 실제 이미지로부터 구한 loss와 fake 이미지로부터 구한 loss의 합이며,

 

1), 2)에서 계산한 gradient를 활용하여 netD를 업데이트해준다(optimizer.step()).

 

 

 

 

 

2. Generator 업데이트하기


 

 

위에서 noise를 netG에 넣어 만든 fake 이미지를 netD에 넣어 real/fake를 구분한다. 이때 fake는 .detach()를 붙이지 않아 netG의 weight의 gradient까지 계산되도록 한다. (혹시나 netD의 gradient도 이 과정에서 계산되어 업데이트되면 어쩌나 싶을 수 있지만,  이후 코드에서 optimizerD.step()은 없으니 안심해도 좋다)

 

netG는 fake 이미지를 real 이미지처럼 만들어야 하므로 정답 label을 real_label로 준다.

 

netD의 판단 결과와 label을 비교하여 얻어진 loss(errG)로 gradient를 구하고 optimizerG.step()을 통해 generator 네트워크를 업데이트해준다.

 

 

 

 

 


이렇게 netD와 netG의 업데이트를 num_epochs 만큼 반복하면서 서로 능력을 키워나가게 됩니다.

 

 

끝이 어설프지만 끝!입니다 :)