nn.Module 클래스
import torch.nn as nn
torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super(MyModel,self).__init__()
self.conv1 = nn.Conv2d(1,20,5)
self.conv2 = nn.Conv2d(20,20,5)
def forward(self,x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
forward : 이 모델이 호출 되었을 때 실행 되는 함수
모든 nn.Module 은 forward() 함수를 가진다
state_dict() : 모듈이 가지고 있는 계산에 쓰일 Parameter
각 모델 파라미터들은 data, grad, requires_grad 변수 등을 가지고 있다
Pretrained Model
미리 학습된 좋은 성능이 검증되어 있는 모델 사용
- torchvision.models
import torchvision.models as models
resnet18 = model.resnet18(pretrained=True)
- timm (pyTorch IMage Models)
PyTorch Image Models (timm) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.
import timm
m = timm.create_model('mobilenetv3_large_100', pretrained=True)
Transfer Learning
CNN base 모델 구조
Input + CNN Backbone + Classifier -> Output
fc = fully connected layer == classifier
학습 데이터 충분한 경우
- high similarity : feature extraction
- Input -> CNN Backbone(freeze) -> Classifier -> Output
- low similarity : fine tuning
- Input -> CNN Backbone -> Classifier -> Output
학습 데이터가 충분하지 않은 경우
-high similarity
- Input -> CNN Backbone(freeze) -> Classifier -> Output
'Boostcamp AI Tech > [week 1-5] LEVEL 1' 카테고리의 다른 글
[Week 4 - Day 5] Pstage 이미지 분류 - Ensemble (0) | 2021.08.28 |
---|---|
[Week 4 - Day 4] Pstage 이미지 분류 - Training & Infernece (0) | 2021.08.27 |
[Week 4 - Day 2] Pstage 이미지 분류 - Data Generation (0) | 2021.08.24 |
[Week 4 - Day 1] Pstage 이미지 분류 - EDA (0) | 2021.08.24 |
[Week 3 - Day 4 ] Pytorch (0) | 2021.08.20 |