이번에는 ResNet에 대해 알아보고 만들어 보겠습니다.
ResNet은 앞서 살펴봤던 VGG보다 Layer가 많이 깊습니다.
ResNet은 잔차학습이라고도 하는데 아래의 그림을 보면 왜 이렇게 불리는지 짐작할 수 있습니다.
ResNet의 특징으로는 layer에 들어가기 전의 값들을 추후 layer들을 통과한 후의 값들과 더해줍니다.
잔차학습을 진행함으로써 Gradient Descent를 예방할 수 있습니다.
개인적인 생각으로는 layer가 반복될수록 큰 특징들보다는 작고 자세한 특징들을 잡아내는데
전, 후의 값들을 더해줌으로써 큰 특징들도 유지하며 작고 자세한 특징들도 어느정도 유지해주는것 같습니다.
이 부분은 저의 개인적인 생각이므로 댓글로 알려주시면 감사하겠습니다.
ResNet에는 BasicBlock과 bottleneck으로 구성되어 있다고 볼 수 있습니다.
Code
colab의 기본 설정을 해줍니다.
model_urls는 pre-trained된 가중치 값들을 다운받는 코드입니다.
일단 ResNet에서 가장 많이 사용될 3*3, 1*1 Conv Layer를 만들어 줍니다.
다음은 위에서 봤던 BasicBlock입니다.
여기에서 조심해야 할것은 downsample입니다.
stride가 2라고 한다면, out += identity 에서 out과 identity의 형태가 달라지므로 downsample을 통해 형태를 맞춰줍니다.
※ 연산 없이 downsample(x) 이렇게 끝내줬는데, 자동으로 1*1 Conv를 통해 downsample을 해준다고 한다.
https://stackoverflow.com/questions/55688645/how-downsample-work-in-resnet-in-pytorch-code
Bottleneck입니다.
여기까지는 위에서 봤던 그림을 그대로 구현한것이라 크게 어려운점은 없습니다.
expansion = 4로 정의해준 이유는 3x3x64 → 1x1x256 으로 바뀌기 때문에 맞춰주기 위해 4로 정의 (64 x 4 = 256 )
이제는 진짜 ResNet을 구현해 볼것인데 어떻게 보면 복잡하고 어떻게 보면 간단합니다.
밑에 zero_init_residual: 이부분은 사용하면 성능이 0.2% ~ 0.3% 정도 오른다고 하는데 논문을 읽어봐야 할것 같다.
여기까지가 ResNet class 완료입니다.
여기서 하나하나 대입을 해가며 맞게 작성이 되었는지 확인을 해보는것이 공부하는데 도움이 많이 될것 같습니다.
[8]을 보겠습니다.
ResNet Class를 선언할때,
def __init__(self, block, layers, num_classes = 1000, zero_init_residual=False)
이곳에 ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 를 그대로 대입해주면 됩니다.
- block = BasicBlock
- layers = [2, 2, 2, 2]
- **kwargs: keyword argument (키워드 = 값)의 줄임말로써 Default 값이 들어간다고 보면 됩니다.
이런식으로 차근차근 대입해보면 ResNet이 완성되는것을 볼 수 있습니다.
또, resnet 18, 50, 152 이렇게 이름이 붙는데
뒤의 숫자들은 Layer의 개수라고 보시면 됩니다.
resnet18을 예로 들면
x = self.conv1(x) ~ 1개
BasicBlock 안에 2개의 Layer 즉, 2 * (2 + 2 + 2 + 2)
마지막 self.fc(x) Layer
총 1 + 2 * (2 + 2 + 2 + 2) +1 = 18
출처:
https://www.youtube.com/watch?v=KbNbWTnlYXs&list=PLIMkM4tgfjnLSOjrEJN31gZATbcj_MpUm&index=38
https://www.youtube.com/watch?v=Qb_bYWcQXqY&list=PLQ28Nx3M4JrhkqBVIXg-i5_CVVoS1UzAv&index=25
'Deep Learning > Pytorch' 카테고리의 다른 글
22_Pytorch_RNN_Practice (0) | 2021.10.07 |
---|---|
21_Pytorch_RNN (0) | 2021.10.06 |
19_Pytorch_VGG16 (0) | 2021.10.05 |
18_Pytorch_CNN_MNIST (0) | 2021.09.30 |
17_Pytorch_CNN (Convolutional_Neural_Network) (0) | 2021.09.30 |