본문 바로가기
Boostcamp AI Tech/[week 1-5] LEVEL 1

[Week 3 - Day 3 ] Pytorch

by newnu 2021. 8. 20.
반응형

1. 강의 내용

모델 불러오기

model.save()

  • 학습의 결과를 저장하기 위한 함수
  • 모델 형태와 파라미터 저장
  • 모델 학습 중간 과정의 저장을 통해 최선의 결과 모델 선택
  • 만들어진 모델을 외부 연구자와 공유하여 학습 재연성 향상

checkpoints

  • 학습의 중간 결과를 저장하여 최선의 결과를 선택
  • epoch, loss,metric을 함꼐 저장하여 확인

Transfer learning

  • 다른 데이터셋으로 만든 모델을 현재 데이터에 적용
  • 일반적으로 대용햘 데이터셋으로 만들어진 모델의 성능이 좋음
  • backbone architecture가 잘 학습된 모델에서 일부분만 변경하여 학습을 수행
  • Freezing : pretrained model 활용시 모델의 일부분을 frozen 시킴

Monitoring tools for Pytorch

Tensorboard

  • Tensorflow의 프로젝트로 만들어진 시각화 도구
  • 학습 그래프, metric, 학습 결과의 시각화 지원
  • Pytorch도 연결 가능 -> DL 시각화 핵심 도구
  • scalar : metric 등 상수 값의 연속을 표시
  • graph : 모델의 computation graph 표시
  • histogram : weight 등 값의 분포를 표현
  • Image : 예측값과 실제 값을 비교 표시
  • mesh : 3d 형태의 데이터를 표현하는 도구

Weight & biases

  • 협업, code versioning, 실험 결과 기록 등 제공

 

2. 과제 2 - Custom Dataset , Custom DataLoader

 

Dataset

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self,):
    	pass
    def __len__(self):
    	pass
    def __getitem__(self,idx):
    	pass

__init__ : 데이터의 위치나 파일명 초기화, 데이터 불러오기, transforms Compose

__len__ : 데이터셋의 최대 요소 수 반환

__getitem__ :  데이터셋의 idx번째 데이터 반환, 데이터 전처리, 증강

 

DataLoader

모델 학습을 위해 데이터를 미니 배치 단위로 제공

DataLoader(dataset,batch_size = 1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None)

sampler : 불균형 데이터의 경우 클래스의 비율에 맞게 데이터 제공 필요 -> index 컨트롤

                  SequentialSampler : 항상 같은 순서

                  RandomSampler

                  WeightRandomSampler :  가중치에 따른 확률

                  BatchSampler : batch 단위로 sampling 가능

num_workers : 데이터를 불러올 때 사용하는 서브 프로세스 개수

                         무작정 수를 높이면 CPU와 GPU 사이에 많은 데이터 교류 -> 병목 발생

 

collate_fn : sample list를 배치 단위로 바꾸기 위해 사용 , padding 등 데이터의 사이즈를 맞추기 위해

https://www.coastalcreative.com/wp-content/uploads/2019/10/collated-not-collated-543x600.jpg

drop_last : batch 단위로 데이터를 불러올 때 마지막 batch의 길이가 다른 경우 사용하지 않음

time_out : DataLoader가 data를 불러오는데 제한 시간

worker_init_fn : 어떤 worker를 불러올 것인가 리스트로 전달

 

 

torchvision에서 제공하는 transform 함수

transforms.Resize(size) :  size에 맞춰 확대/축소

transforms.RandomCrop(size) : 사이즈 만큼 랜덤으로 crop

transforms.RandomRotation(degrees) : 주어진 각도만큼 회전

transforms.CenterCrop(size) : size만큼 중심 crop

transforms.RandomHorizontalCrop(0.5) : 수평으로 뒤집기

transforms.RandomVerticalCrop(0.5) : 수직으로 뒤집기

 

transforms.ToTensor() : 이미지 데이터를 tensor로 변환

 

transforms.Compose([transforms.Resize(size)),transforms.CenterCrop(150)])(im) : 여러 transforms 묶어서 처리

 

transformation에 의해서 input이 변하면 ground truth 값이 변하는 경우가 있다.(ex. bounding box)

-> imgaug 라이브러리로 해결

 

torchvison의 Dataset

MNIST 손으로 쓴 숫자들로 이루어진 대형 데이터베이스

torchvision.datasets.MNIST(download_path,train=True,transform=transforms.ToTensor(),download=True)

 

CIFAR-10 클래스당 6000개의 이미지를 포함하여 10개 클래스의 60000 컬러 이미지

 

torchtext의 Dataset

AG_NEWS : 100만 개가 넘는 뉴스 기사 모음

 

반응형