Page 33 -
P. 33

코드 13-3 네트워크(신경망) 생성
                 class Encoder(nn.Module):    인코더 네트워크 생성
                     def __init__(self, encoded_space_dim,fc2_input_dim):
                         super().__init__()

                         self.encoder_cnn = nn.Sequential(
                             nn.Conv2d(1, 8, 3, stride=2, padding=1),
                             nn.ReLU(True),
                             nn.Conv2d(8, 16, 3, stride=2, padding=1),
                             nn.BatchNorm2d(16),
                             nn.ReLU(True),
                             nn.Conv2d(16, 32, 3, stride=2, padding=0),
                             nn.ReLU(True)
                         )    이미지 데이터셋 처리를 위해 합성곱 신경망 이용

                         self.flatten = nn.Flatten(start_dim=1)    완전연결층
                         self.encoder_lin = nn.Sequential(
                             nn.Linear(3 * 3 * 32, 128),
                             nn.ReLU(True),
                             nn.Linear(128, encoded_space_dim)
                         )    출력 계층

                     def forward(self, x):
                         x = self.encoder_cnn(x)
                         x = self.flatten(x)
                         x = self.encoder_lin(x)
                         return x

                 class Decoder(nn.Module):    디코더 네트워크 생성
                     def __init__(self, encoded_space_dim, fc2_input_dim):
                         super().__init__()
                         self.decoder_lin = nn.Sequential(
                             nn.Linear(encoded_space_dim, 128),
                             nn.ReLU(True),
                             nn.Linear(128, 3 * 3 * 32),
                             nn.ReLU(True)
                         )    인코더의 출력을 디코더의 입력으로 사용

                         self.unflatten = nn.Unflatten(dim=1,
                                                       unflattened_size=(32, 3, 3))   인코더의 완전연결층에 대응
                         self.decoder_conv = nn.Sequential(
                             nn.ConvTranspose2d(32, 16, 3,
                                                stride=2, output_padding=0),

         688
   28   29   30   31   32   33   34   35