본문 바로가기

Paper Review/Generative Model

[Code Review] StyleGAN : A Style-Based Generator Architecture for Generative Adversarial Networks (2)

 

앞서 리뷰했던 StyleGAN 논문에 대한 Pytorch 구현체를 통해 어떻게 구현되어있는지를 세세히 살펴보도록 하겠습니다 리뷰하게 된 코드는 star도 많고 다른 repo들에서도 많이 재사용되는 구현체인 https://github.com/rosinality/style-based-gan-pytorch입니다. 라인별 코드를 더 자세히 이해하시고 싶다면 코드 블록에 주석을 달아놓은 영문 comment를 참고해주세요.

전체적인 코드 리뷰 방식은 논문의 아이디어가 구현체 상에서 어떤 방식으로 구현되어 있는지를 확인하고 추가적으로 설명이 필요한 부분을 보충하는 방식입니다. 또한 논문의 아키텍처 혹은 수식에 대응되는 부분은 기호 A, B, C 등을 통해 주석 앞에 [A] 기호를 붙였으니 잘 대응시키면서 확인하시면 도움이 될 것 같습니다.

그럼 시작하겠습니다. :)

 

GitHub - rosinality/style-based-gan-pytorch: Implementation A Style-Based Generator Architecture for Generative Adversarial Netw

Implementation A Style-Based Generator Architecture for Generative Adversarial Networks in PyTorch - GitHub - rosinality/style-based-gan-pytorch: Implementation A Style-Based Generator Architecture...

github.com

 

 

 

StyleGAN Overview

 
1
2
3
4
5
│...
├── dataset.py      # multi-resolution dataset
├── generate.py     # style-mixing
├── model.py        # style-based generator
└── train.py        # progressive growing & mixing regularization
cs

전체적인 stylegan architecture의 overview는 다음과 같고, 이를 구현체 상에서 살펴볼 때 주의 깊게 봐야할 파일들과 살펴봐야 할 내용들은 다음과 같습니다. 그럼 차근차근 살펴보도록 하겠습니다.

 

dataset.py

[ic]dataset.py[/ic]에서는 Progressive Growing 방식의 학습을 위해 다양한 resolution에 해당하는 데이터들을 상황에 맞게 load할 수 있어야 합니다. 이를 위해서 [ic]MultiResolutionDataset[/ic]을 정의하고 있습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class MultiResolutionDataset(Dataset):
    # we can set resolution
    def __init__(self, path, transform, resolution=8):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
 
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)
 
        with self.env.begin(write=Falseas txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
 
        self.resolution = resolution
        self.transform = transform
 
    def __len__(self):
        return self.length
 
    def __getitem__(self, index):
        # load specific resolution dataset corresponding to lmdb key(resolution)
        with self.env.begin(write=Falseas txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)
 
        # For decoding to image, we need to decode bytes to image and apply transformation
        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)
 
        return img
cs

LMDB ( Lightning Memory-Mapped Database )

이 때 현 구현체에서는 미리 lmdb 포맷으로 데이터를 저장해 두고 있다가, data를 불러올 때 image로 decode해서 사용합니다. 이때 LMDB가 무엇인지 간단히 살펴보도록 하겠습니다. 

Lightning Memory-Mapped Database (LMDB) is a software library that provides an embedded transactional database in the form of a key-value store. LMDB is written in C with API bindings for several programming languages. LMDB stores arbitrary key/data pairs as byte arrays, has a range-based search capability, supports multiple data items for a single key and has a special mode for appending records (MDB_APPEND) without checking for consistency.[1] LMDB is not a relational database, it is strictly a key-value store like Berkeley DB and dbm.

LMDB는 이름에서도 볼 수 있듯이, 메모리를 적게 사용하기 위한 방식으로 Key-data database 형태로 key-data 쌍을 byte배열로 저장하고 있습니다. 따라서 다음과 같이 특정 resolution에 맞는 데이터에 접근하고자 할 때 resolution을 key로 저장해두었기에 특정 resolution에 해당하는 key에 접근해 byte로 저장된 데이터들을 image로 변환해서 가져올 수 있습니다.

 

 

model.py

결국 stylegan 논문의 핵심 아이디어는 혁신적인 new generator architecture에 있습니다. 따라서  network 구조를 잘 살펴봐야 하고 이를 위해서 stylegan의 architecture가 구현되어 있는 [ic]model.py[/ic]를 살펴보도록 하겠습니다. 논문에서 제시된 style-based generator는 StyledGenerator로 구현되어 있습니다. 이를 찬찬히 따라가면서 이해해봅시다.

Style-based generator Overview

style-based generator의 전체적인 구조를 살펴보면 크게 다음과 같이 synthesis network g, mapping network, style applying & noise injection으로 확인할 수 있습니다. 논문의 핵심 아이디어인 다음과 같은 세 가지 구조를 살펴봅시다.

이때 코드 구현체에서는 각 구조 A, B, C가 추상화되어 있으므로 이를 통해 큰 흐름을 먼저 살펴봅시다. 해당 클래스를 살펴보면 Synthesis Network g가 돌아가기에 앞서 필요한 style vector, noise 등을 준비하고 g network에 태우는 부분입니다. 

아래 클래스에서는 이러한 전체적인 흐름이 어떻게 구현되어 있는지를 확인하고, mixing regularization을 위해서 다양한 latent vector를 한 번에 입력으로 받을 수 있기에 다음과 같이 input list를 만들고 각 style을 가져온다는 점과 noise가 명시적으로 사용하기 위해서 미리 정의해두고 사용한다는 점, mapping network가 정의된 부분 등을 살펴볼 수 있습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Style-based Generator
class StyledGenerator(nn.Module):
    def __init__(self, code_dim=512, n_mlp=8):
        super().__init__()
        
        # [B] Definition of synthesis network g
        self.generator = Generator(code_dim)
        
        # [A.1] PixelNorm is inital normalization about latent vector
        layers = [PixelNorm()]
        
        # [A.2] Definition of Mapping Network, n_mlp same to mappint network depth
        for i in range(n_mlp):
            layers.append(EqualLinear(code_dim, code_dim))
            layers.append(nn.LeakyReLU(0.2))
        self.style = nn.Sequential(*layers)
        
    def forward(
        self,
        input,
        noise=None,
        step=0,
        alpha=-1,
        mean_style=None,
        style_weight=0,
        mixing_range=(-1-1),
    ):
        styles = []
    
           
        if type(input) not in (list, tuple):
            input = [input]
            
        # style-based generator can input multiple latent vector for style mixing
        for i in input:
        
        # [A]
        # latent vector z is mapped to intermediate latent vector w by mapping network
        # self.style mean mapping network
        # here styles mean w list (not style vector which is applied affine transformation to w)
            styles.append(self.style(i))
        batch = input[0].shape[0]
        
        # [C] Explicit Noise definition for noise injection
        # Prepare Noise Map for each step(resolution) shape
        if noise is None:
            noise = []
            for i in range(step + 1):
                size = 4 * 2 ** i
                noise.append(torch.randn(batch, 1, size, size, device=input[0].device))
        
        if mean_style is not None:
            styles_norm = []
            for style in styles:
                styles_norm.append(mean_style + style_weight * (style - mean_style))
            styles = styles_norm
            
        # style vector is applied to synthesis network g by AdaIN
        # this content is explained below
        return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)
    def mean_style(self, input):
        style = self.style(input).mean(0, keepdim=True)
        return style
cs

따라서 mapping network를 통해 W를 setting 했고, noise injection을 위해 explicit noise를 미리 세팅했습니다. 또한, g network또한 구성해놨으니 이제 Synthesis Network g가 돌아가기에 앞서 필요한 것들을 모두 세팅했습니다.

 

Synthesis Network g 

synthesis network g의 전체적인 흐름과 적용되는 method들은 다음과 같습니다. synthesis network는 styled conv block들로 이루어져 있으며 각 block은 progressive growing 방식으로 점점 resolution을 늘려갑니다. 그렇다 보니 특정 resolution에서의 학습 후 이를 평가하기 위한 toRGB layer가 필요하기에 다음과 같이 PGGAN의 아이디어를 그대로 가져와서 적용됩니다. 또한, 이러한 방식이기에 특정 level을 기준으로 다른 latent vector의 style을 반영할 수 있는 mixing regularizatoin이 다음과 같이 가능합니다.

이러한 g의 구현은 pggan의 아이디어를 그대로 가져와 Progressive Growing한 방식의 네트워크가 구현되게 되어있으며 이때 각 module하나하나의 단위를 stylegan의 styled conv block을 통해 stylegan의 아이디어와 pggan의 아이디어를 합쳤습니다. 또한, 학습 과정에서의 Mixing Regualrization과 테스트 과정의 Style Mixing을 위해서 crossover point를 기준으로 latent vector들의 다양한 style을 반영하도록 작성되었습니다.

특히  설정된 [ic]step[/ic]만큼의 progrssion layer만 가져와지고( 모두 있는 리스트에서 for문으로 가져오면서) 현재의 resolution step에 도달했을 때 [ic]torgb[/ic]를 거친 후 [ic]break[/ic]됨에 따라서, 처음에는 4x4를 학습하고 다음에는 8x8을 학습하는 progressive growing이 적용됩니다. 또한, 4x4를 학습한 다음에 8x8을 학습할 때는 딱히 앞부분을 freeze하지는 않습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class Generator(nn.Module):
    def __init__(self, code_dim, fused=True):
        super().__init__()
 
    # [B.1] Setting Styled Conv Block at each resolution 
        self.progression = nn.ModuleList(
            [
                StyledConvBlock(51251231, initial=True),  # 4
                StyledConvBlock(51251231, upsample=True),  # 8
                StyledConvBlock(51251231, upsample=True),  # 16
                StyledConvBlock(51251231, upsample=True),  # 32
                StyledConvBlock(51225631, upsample=True),  # 64
                StyledConvBlock(25612831, upsample=True, fused=fused),  # 128
                StyledConvBlock(1286431, upsample=True, fused=fused),  # 256
                StyledConvBlock(643231, upsample=True, fused=fused),  # 512
                StyledConvBlock(321631, upsample=True, fused=fused),  # 1024
            ]
        )
        
    # [B.2] Setting ToRGB layer at each resolution
        self.to_rgb = nn.ModuleList(
            [
                EqualConv2d(51231),
                EqualConv2d(51231),
                EqualConv2d(51231),
                EqualConv2d(51231),
                EqualConv2d(25631),
                EqualConv2d(12831),
                EqualConv2d(6431),
                EqualConv2d(3231),
                EqualConv2d(1631),
            ]
        )
 
 
    def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1-1)):
    
        # First noise injection ( noise is list of all step explicit noise )
        out = noise[0]
 
    # [B.4] Mixing Regularization(else) or not(if)
        # When Mixing Regularization is applied, inject index mean crossover point
        # len(style) mean the number of using Intermediate Latent vector ( W1,W2, ... )
        if len(style) < 2:
            inject_index = [len(self.progression) + 1]
 
        else:
            inject_index = sorted(random.sample(list(range(step)), len(style) - 1))
 
        
        # loop all resolution blocks
        # [B.4] First, crossover mean which crossover point index in inject index list
        # so when now step i reach crossover point, next style is applying
        # you can adjust mixing range for style mixing with explicit style mixing level range
        crossover = 0
 
        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            
            if mixing_range == (-1-1):
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1len(style))
 
                style_step = style[crossover]
 
            else:
                if mixing_range[0<= i <= mixing_range[1]:
                    style_step = style[1]
 
                else:
                    style_step = style[0]
 
            if i > 0 and step > 0:
                out_prev = out
          
            # [B.1] [B.2] Here is same with progressive gan idea
            # loop all resolution blocks
            # Second, pass styled conv block and to rgb (for measure loss in now resolution )        
            out = conv(out, style_step, noise[i])
 
    # bring proper network untill now resolution
            if i == step:
                out = to_rgb(out)
 
                if i > 0 and 0 <= alpha < 1:
                    skip_rgb = self.to_rgb[i - 1](out_prev)
                    skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
                    out = (1 - alpha) * skip_rgb + alpha * out
 
                break
 
        return out
cs

 

Styled Conv Block (1) - overview

synthesis network g는 세부적인 styled conv block들로 구성되어 있습니다. figure의 각 하나의 붉은색(회색) 블록인 styled conv block이 어떤 모듈인지 살펴봅시다.

이 각 styled conv block은 구조와 같이 initail block인지 아닌지에 따라 [ic]constant input[/ic]으로 시작하는지 [ic]upsample[/ic]로 시작하는지가 정해지고 그 이후 [ic]conv[/ic] -> [ic]noise injectoin[/ic] -> [ic]activation function[/ic] -> [ic]AdaIN[/ic] -> [ic]conv[/ic] -> [ic]noise injectoin[/ic] -> [ic]activation function[/ic] -> [ic]AdaIN[/ic] 방식으로 구현되어 있습니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class StyledConvBlock(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        padding=1,
        style_dim=512,
        initial=False,
        upsample=False,
        fused=False,
    ):
        super().__init__()
 
    # initial styled conv block starts with constant input
        if initial:
            self.conv1 = ConstantInput(in_channel)
 
    # every styled conv block which is not initial starts with upsample
        # pytorch implementation upsample is coded with conv transpose
        # fused and not-fused, Blur, EqualConv and EqualLinear are explained below article
        else:
            if upsample:
                if fused:
                    self.conv1 = nn.Sequential(
                        FusedUpsample(
                            in_channel, out_channel, kernel_size, padding=padding
                        ),
                        Blur(out_channel),
                    )
 
                else:
                    self.conv1 = nn.Sequential(
                        nn.Upsample(scale_factor=2, mode='nearest'),
                        EqualConv2d(
                            in_channel, out_channel, kernel_size, padding=padding
                        ),
                        Blur(out_channel),
                    )
 
            else:
                self.conv1 = EqualConv2d(
                    in_channel, out_channel, kernel_size, padding=padding
                )
 
        self.noise1 = equal_lr(NoiseInjection(out_channel))
        self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu1 = nn.LeakyReLU(0.2)
 
        self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
        self.noise2 = equal_lr(NoiseInjection(out_channel))
        self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
        self.lrelu2 = nn.LeakyReLU(0.2)
 
    def forward(self, input, style, noise):
        # Every styled conv block is forwarded by below sequence
        # look code with corresponding styled conv block architecture
        out = self.conv1(input)
        out = self.noise1(out, noise)
        out = self.lrelu1(out)
        out = self.adain1(out, style)
 
        out = self.conv2(out)
        out = self.noise2(out, noise)
        out = self.lrelu2(out)
        out = self.adain2(out, style)
 
        return out
cs

다만, 이때 각 layer가 어떻게 구현되어있는지와, 구현체에 등장하는 1.fused and not-fused가 뭐가 다른지 2.Blur를 쓰는 이유가 무엇인지  3.EqualConv and EqualLinear와 같은 layer의 equal이 의미하는 바는 무엇인지가 모호하게 느껴질 수 있습니다. 지금부터 차근차근 살펴보도록 하겠습니다. 

Styled Conv Block (2) - detail

우선 간단히 styled conv block내의 Constant Input , Noise Injection, AdaIN 레이어가 어떻게 구성되어 있는지 살펴봅시다.

 

Constant Input : [ic]nn.Parameter[/ic]를 통해서 learnable한 constant 값을 가집니다.(learnable하다는 부분에 초점을 맞춰 살펴봅시다.)

Noise Injection : [ic]nn.Parameter[/ic]를 통해서 explicit한 noise가 injection될 때 어느정도 learning이 있다는 점을 고려하고 이러한 noise map이 weight를 가지고 injection될 때 feature map에 add된다는 것을 확인합시다.

AdaIN  : Adaptiva Instance Noramlization의 구현이 잘 반영되고 있는지와 함께 intermediate latent vector w를 learned affine transformation을 거쳐서 style vector로 만드는 부분에서 linear layer를 통해 기존 vector크기의 2배만큼 키운 후 반은 scaling factor로 반은 bias factor로 이용한다는 것을 확인할 수 있습니다. ( 특히 위 figure의 AdaIN의 과정을 코드의 각 부분 부분에 잘 대응시켜서 이해합시다.)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
 # [B.3.1]
class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()
 
        # learnable constant input!
        self.input = nn.Parameter(torch.randn(1, channel, size, size))
 
    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 111)
 
        return out
        
 # [B.3.2]       
class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()
        
        # learnable
        self.weight = nn.Parameter(torch.zeros(1, channel, 11))
 
    def forward(self, image, noise):
        # noise map is added to featrue map
        return image + self.weight * noise
        
 # [B.3.3]
class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()
 
        self.norm = nn.InstanceNorm2d(in_channel)
        
        # learned affine transformation
        self.style = EqualLinear(style_dim, in_channel * 2)
 
        self.style.linear.bias.data[:in_channel] = 1
        self.style.linear.bias.data[in_channel:] = 0
 
    def forward(self, input, style):
    
        # learned affine transformation
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        
        # scaling factor and bias factor
        gamma, beta = style.chunk(21)
 
        out = self.norm(input)
        out = gamma * out + beta
 
        return out
cs

 

Styled Conv Block (3) - Additional Details

앞서 언급했듯 styled conv block의 구현체에 등장하지만 논문에는 명시적으로 등장하지 않고 추후 stylegan2, stylegan2-ada등에서도 계속 사용되는 아이디어인  1.fused and not-fused가 뭐가 다른지 2.Blur를 쓰는 이유가 무엇인지  3.EqualConv and EqualLinear와 같은 layer의 equal이 의미하는 바는 무엇인지 와 같은 추가적으로 디테일한 내용들을 살펴보도록 하겠습니다. 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
if upsample:
 
   # fused upsample
   if fused:
       self.conv1 = nn.Sequential(
                      FusedUpsample( in_channel, out_channel, kernel_size, padding=padding),
                      Blur(out_channel),
                    )
                    
   # not-fused upsample
   else:
       self.conv1 = nn.Sequential(
                       nn.Upsample(scale_factor=2, mode='nearest'),
                       
                       # equal?
                       EqualConv2d(in_channel, out_channel, kernel_size, padding=padding),
                       
                       # why uew blur?
                       Blur(out_channel),
                    )
cs

 

1. FusedUpsample  vs Non-Fused Upsample

앞으로 살펴볼 내용들은 현재 리뷰 중인 구현체뿐 아니라 stylegan에 대한 official tensorflow implementation에도 구현되어 있습니다. 

내용은 간단합니다. upsample을 할 때 그냥 [ic]nn.Upsample[/ic]을 함으로써 생기는 artifact 등이나 노이즈 등을 최소화하기 위해서 learnable한 conv를 한 번 더 통과시키게 됩니다. 즉, 일반적으로 upsample되는 과정은 [ic]rule based upsample[/ic] -> [ic]conv[/ic] 로 이루어지는데 이는 non-fused upsample입니다. fused upsample은 이러한 과정을 합쳐서 [ic]conv-transpose[/ic]하나로 구현하는 방식입니다. 그렇다면 굳이 이 두 가지 방식 중 하나를 채택해서 사용하는 이유는 무엇일까요?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class FusedUpsample(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding=0):
        super().__init__()
 
        weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
        bias = torch.zeros(out_channel)
    
        # weight gain -> now we use He initialization
        fan_in = in_channel * kernel_size * kernel_size
        self.multiplier = sqrt(2 / fan_in)
 
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)
 
        self.pad = padding
 
    def forward(self, input):
        weight = F.pad(self.weight * self.multiplier, [1111])
        weight = (
            weight[:, :, 1:, 1:]
            + weight[:, :, :-11:]
            + weight[:, :, 1:, :-1]
            + weight[:, :, :-1, :-1]
        ) / 4
 
        # fused
        out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)
 
        return out
cs

이유 또한 간단합니다. 앞으로 고해상도의 이미지를 처리하게 될 때 non-fused upsample을 사용하게 될 경우 각 operation을 개별적으로 수행하기에 fused upsample인 [ic]conv-transpose[/ic]하나를 수행하는 것에 비해서 느리고 memory를 많이 먹게 됩니다.( 성능은 미세하게 non-fused가 좋을 수 있다고 합니다. ) 즉, less memory와 fast를 위해서 fused upsample을 사용하는 것입니다. 따라서 앞서 살펴봤던 synthesis network의 구현에서도 고해상도 이미지에서는 fused upsample을 사용하고 저해상도에서는 non-fused를 사용하고 있는 것을 확인할 수 있습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Generator(nn.Module):
    def __init__(self, code_dim, fused=True):
        super().__init__()
 
        self.progression = nn.ModuleList(
            [
                # low resolution -> non-fused 
                StyledConvBlock(51251231, initial=True),  # 4
                StyledConvBlock(51251231, upsample=True),  # 8
                StyledConvBlock(51251231, upsample=True),  # 16
                StyledConvBlock(51251231, upsample=True),  # 32
                StyledConvBlock(51225631, upsample=True),  # 64
                
                # high resolution -> fused
                StyledConvBlock(25612831, upsample=True, fused=fused),  # 128
                StyledConvBlock(1286431, upsample=True, fused=fused),  # 256
                StyledConvBlock(643231, upsample=True, fused=fused),  # 512
                StyledConvBlock(321631, upsample=True, fused=fused),  # 1024
            ]
        )
 
...
cs

 

* gain

추가적으로 위의 주석에서 weight에 어떤 값을 multiply해주는 부분이 존재하는데 왜 저런 값을 곱해주는지 간단히 살펴보도로 하겠습니다. 해당 이슈를 참고하면 좋은데, non-linear function을 사용할 때 더 안정적인 학습을 위해 일종의 규칙처럼 weight에 특정 값을 multiply해주게 됩니다. 이때 곱해지는 값을 gain이라고 합니다. https://pytorch.org/docs/stable/nn.init.html 여기서 다양한 non-linear function에 맞는 gain값을 확인할 수 있다.

그렇다면 우리가 사용하는 He Initalization의 경우에는 어떤 값이 곱해져야 할까요? 논문을 직접 확인해본 것은 아니지만, 파이토치 Docs를 확인하면 다음과 같이 적혀있습니다.  (He et al., 2015)의 방법은  per-layer normalization을 수행하는데 이는 각 레이어 별로 다음과 같이 미리 정의한 gain값에 [ic]fan_mode (fan_in) = in_channel * kernel_size * kernel_size[/ic] 에 루트 씌운 값으로 나눈 것입니다. gain의 경우 앞서 살펴본 바처럼 안정적인 학습을 위해 weight에 곱해지는 값인데 이를 per layer 즉 레이어별로 normalizatoin을 한 후에 곱해주는 방식을 적용하기 위해서 다음과 같이 per layer normalization constant c를 구해서 나누고 gain을 곱해서 weight에 연산을 해줌으로써 더 안정적인 학습을 하고자 한다는 것입니다. 이를 통해 앞선 식에서 weight에 multiply되는 값이 어떤 의미인지 확인할 수 있을 것 입니다.

2. Blur

blur가 적용되는 이유에 대한 설명을 위해서는 해당 리뷰 논문을 참고해야 합니다. 신호 처리에 대한 이야기도 들어가고 쉽지는 않은 이야기다 보니 reference에 있는 저자의 talk이나 pr12에서 리뷰해주신 영상 혹은 제 블로그의 리뷰글을 참고해주시기 바랍니다. 내용이 많다 보니 따로 적게 되었습니다. 만약 깊게 알고 싶지 않으시다면 blur를 통해서 aliasing을 많이 줄일 수 있다 정도로 넘어가셔도 괜찮을 것 같습니다. 

 

3. Eqaul LR

본 구현체를 살펴보다 보면 Linear layer와 Conv layer가 전부 Eqaul Linear, Equal Conv로 적혀있습니다. 이는 pggan에서 나왔던 equal iearning rate의 아이디어를 적용시켰다는 뜻으로 기존의 conv에 이와 같은 equal lr을 wrapper처럼 씌워줌으로써 정의됩니다. 

eqaul lr이 뭔지에 대해서 간단히 언급하자면 Weight initialization의 아이디어로 특정 방식을 통해서 weight를 조정해줄 경우 동일한 변화폭으로 학습이 되게 (weight 변화폭이 동일한) 할 수 있고 이는 학습에 있어서 더 나은 결과를 가져온다는 것입니다. 세부 과정은 다음과 같습니다.

  • 모든 bias 파라미터는 0으로 설정합니다.
  • 모든 weight 값은 normal distribution을 따르도록 한 후, (He et al., 2015)의 방법을 도입해 per-layer normalization을 수행합니다.
    • 각 레이어별로 per-layer normalization constant를 구한 다음, 해당 값으로 weight를 나눕니다.

이런 방식을 통해 모든 weight의 변화폭 (dynamic range)가 동일하게 될 수 있기에 equal learing rate 동일한 학습률이라고 표현합니다. 따라서 equal lr이 각 layer에 적용되는 방식을 구현체를 통해서 살펴보시면 됩니다. ( 더 자세한 내용을 얻고 싶으시다면 pggan 논문을 참고해주세요. )

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# per layer normalization & re-naming weight information
class EqualLR:
    def __init__(self, name):
        self.name = name
 
    # multiply std to weight ( per layer normalizatoin )
    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1* weight.data[0][0].numel()
 
        return weight * sqrt(2 / fan_in)
 
    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)
 
        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)
 
        return fn
 
    def __call__(self, module, input):
        weight = self.compute_weight(module)
        setattr(module, self.name, weight)
 
# per layer normalization wrapper
def equal_lr(module, name='weight'):
    EqualLR.apply(module, name)
 
    return module
 
# layer which is applied equal lr
class EqualConv2d(nn.Module):
    def __init__(self*args, **kwargs):
        super().__init__()
 
        conv = nn.Conv2d(*args, **kwargs)
        
        # weight -> normal distribution
        conv.weight.data.normal_()
        
        # bias set 0
        conv.bias.data.zero_()
        
        # per layer normalization
        self.conv = equal_lr(conv)
 
    def forward(self, input):
        return self.conv(input)
 
# layer which is applied equal lr
class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
 
        linear = nn.Linear(in_dim, out_dim)
        
        # weight -> normal distribution
        linear.weight.data.normal_()
        
        # bias set 0
        linear.bias.data.zero_()
 
        # per layer normalization
        self.linear = equal_lr(linear)
 
    def forward(self, input):
        return self.linear(input)
cs

여기까지 network의 구조를 살펴봤습니다. discriminator의 경우 pggan의 discriminator를 그대로 사용하고 앞서 살펴본 내용들을 discriminator로 바꾼 듯한 모습이기에 따로 적지 않겠습니다.

 

train.py

stylegan의 큰 네트워크 아이디어는 generator에 있기에 학습 과정 자체는 기존의 gan의 학습방식과 동일하고 progressive growing을 고려한다면 pggan의 학습방식과 유사하다고 할 수 있습니다. 물론 loss의 경우 안정적이라고 알려진 WGAN-GP loss나 disentanglement를 measure하는 방식인 PPL과 유사한 식을 가진 R1-loss 등을 사용하고 있습니다.

[ic]train.py[/ic]에서는  progressive growing training이 어떻게 구현되었는지를 살펴보겠습니다. 즉, 어떻게 각 resolution의 data가 load되면서 학습이 되는지를 간단히 살펴보고, 간단히 어떻게 mixing regularization을 하고 있는지를 구현체를 통해 살펴보도록 하겠습니다.

- progressive growing

다음과 같이 각 resolution별로 학습을 하며 샘플 개수([ic]used_sample[/ic])가 일정 수([ic]args.phase*2[/ic])를 넘어가면, [ic]step +=1[/ic] 이 되어 dataloader에서도 다음 resolution의 데이터를 load하는 방식으로 progressive growing한 training이 적용됩니다. 즉, 매 resolution별로 충분히 학습했다고 생각할 값을 미리 정해두고 충분히 많은 이미지 샘플을 학습했다면 다음 resolution의 data를 로드하면서 또 충분한 양의 이미지를 학습하고 이렇게 계속 resolution을 늘려가며 고해상도를 학습하는 것입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
    pbar = tqdm(range(3_000_000))
    
    ...
    
    for i in pbar:
        discriminator.zero_grad()
 
        alpha = min(11 / args.phase * (used_sample + 1))
 
        if (resolution == args.init_size and args.ckpt is Noneor final_progress:
            alpha = 1
 
    # set used sample which how many images learn when specific resolution ( args.phase*2 )
        # if sufficient images are learned, we step to next step
        if used_sample > args.phase * 2:
            used_sample = 0
            step += 1
 
            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1
 
            else:
                alpha = 0
                ckpt_step = step
 
            resolution = 4 * 2 ** step
 
            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default), resolution
            )
            data_loader = iter(loader)
cs

-mixing regularization

StyleGAN에서 mixing regularization을 이용해 학습하기를 원한다면 다음과 같이 기존의 2배의 latent vector를 만들고 둘을 list로 넘김으로써 학습과정에서 두 개의 latent vector를 받아 mixing regularization이 되도록 학습이 되는 것입니다. 즉, 이런 식으로 다양한 latent vector들의 다양한 level에서의 조합들을 학습하게 함으로써 인접한 두 level의 style간의 correlation이 생기는 것을 막고, regularization을 해줌으로써 local한 왜곡을 개선하고 실제 FID나 diversity관점에서도 성능 향상을 가져옵니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# mixing regularization O
         if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
                4, b_size, code_size, device='cuda'
            ).chunk(40)
            gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
            gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]
            
# mixing regularization X
        else:
            gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk(
                20
            )
            gen_in1 = gen_in1.squeeze(0)
            gen_in2 = gen_in2.squeeze(0)
cs

generate.py

위와 같이 학습이 되었다면 이제는 이렇게 학습된 모델을 이용해서 여러 이미지들을 생성할 수 있어야 합니다. [ic]generate.py[/ic]에서는 이러한 생성과정을 살펴보도록 하겠습니다. 그냥 생성하는 것은 기존의 gan방식들과 크게 다르지 않기에 stylegan의 주요 아이디어 중 하나인 style mixing을 중점적으로 살펴보도록 하겠습니다.

style mixing은 학습 과정에서 style을 추출해 이를 조합하는 방식으로 학습했기에 다양한 latent vector들의 style을 조합해서 다양한 이미지의 style들이 mixing된 이미지를 생성해낼 수 있는 방법입니다. 구현체에서는 [ic]n_source[/ic]개의 [ic]source code[/ic]와 [ic]n_target[/ic]개의 [ic]target_code[/ic]를 사용해서 n_source * n_target 개의 style mixing 조합을 만들게 됩니다.( step은 정해줄 수 있습니다.) 그러면 논문에서 확인할 수 있었던 다음과 같은 결과들이 나오게 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@torch.no_grad()
def style_mixing(generator, step, mean_style, n_source, n_target, device):
 
    # the number of n_source source_codes
    source_code = torch.randn(n_source, 512).to(device)
    
    # the number of n_target target_codes
    target_code = torch.randn(n_target, 512).to(device)
    
    shape = 4 * 2 ** step
    alpha = 1
 
    images = [torch.ones(13, shape, shape).to(device) * -1]
 
    source_image = generator(
        source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
    )
    target_image = generator(
        target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
    )
  
    images.append(source_image)
 
    # loop for each source, target combination images    
    for i in range(n_target):
        image = generator(
            [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
            step=step,
            alpha=alpha,
            mean_style=mean_style,
            style_weight=0.7,
            mixing_range=(01),
        )
        images.append(target_image[i].unsqueeze(0))
        images.append(image)
 
    images = torch.cat(images, 0)
    
    return images
cs

 

Fin.

그럼 지금까지 stylegan pytorch구현체에 대한 리뷰였습니다. 해당 구현체는 동일 코드 저자분께서 stylegan2-pytorch구현체도 만들어주시면서 전체적인 구조가 거의 유지되기에 잘 봐 두면 다른 코드들을 읽음에도 도움이 될 것이라고 생각합니다. 또한, 해당 저자분의 stylegan2 구현체는 stylegan2-ada pytorch official implemenation에 있는 ofiicail한 implementation에( 조금 high level 구현체입니다. ) 비해서 직관적이기에 많이 재사용되고 있습니다. 도움이 되기를 바랍니다. 감사합니다. :)