2022년 1월 24일(월)부터 28일(금)까지 네이버 부스트캠프(boostcamp) AI Tech 강의를 들으면서 개인적으로 중요하다고 생각되거나 짚고 넘어가야 할 핵심 내용들만 간단하게 메모한 내용입니다. 틀리거나 설명이 부족한 내용이 있을 수 있으며, 이는 학습을 진행하면서 꾸준히 내용을 수정하거나 추가해 나갈 예정입니다.
Dataset & DataLoader

[출처] BITAmin 연합동아리 PyTorch 실습 세션에서 발표용으로 직접 제작한 자료
Model에 데이터를 학습시키기 전에 우선 훈련용, 검증용 데이터에 관한 Dataset과 DataLoader를 지정해줘야 한다.
Dataset
데이터를 모델에 feeding할 때 사용하는 API이다.
또한 모델에 입력으로 주어지는 데이터의 형태와 방식을 정의하는 클래스이다.
이 API의 역할을 세 가지로 정리하면 collecting & cleaning & preprocessing라고 할 수 있다.
- 데이터 입력의 형태를 정의
- 데이터 입력 방식의 표준화
- 데이터의 종류(Image, Audio, Text, ...etc)에 따라 다른 입력을 정의
- __init__()
- 초기 데이터를 어떻게 불러오고 생성하는지를 지정한다.
- __len__()
- __getitem__()
- 하나의 데이터를 불러올 때 어떻게 반환을 할 것인지를 정의한다.
- index 값을 주었을 때 반환되는 데이터의 형태를 설정한다.
Dataset 클래스 생성 시 유의할 점
- 데이터 형태에 따라 각 함수를 다르게 정의한다.
- 모든 것을 데이터 생성 시점에 처리할 필요는 없다.
- 예: image의 Tensor로의 변화는 학습에 필요한 시점(주로 Transform 함수를 통해 정의)에 변환한다.
- 데이터 셋에 대한 표준화된 처리방법을 제공할 필요가 있다.
DataLoader
데이터를 모델의 입력으로 모델에 넣을 때 Batch로 생성해주는 등 어떻게 만들어서 넣어줄지를 결정하는 클래스이다.
DataLoader의 역할을 정리하면 구체적으로 다음과 같다.
- Data의 Batch를 생성해주는 클래스이다.
- 학습 직전(GPU에 feed하기 전) 데이터의 변환을 책임진다.
- Tensor로 변환하고 Bacth 처리하는 것이 중요 업무이다.
- transform에 정의한 전처리, tensor로의 변환 등이 적용된다.
쉽게 말하면 데이터를 모델에 입력으로 넣을 때 batch 단위로 가공하여 입력으로 넣어주는 클래스라고 볼 수 있다.
즉, 데이터를 batch 단위로 잘게 썰어서 모델에게 학습시키는 것이다.
한 번에 모든 데이터를 학습하는 것보다는 batch 단위로 학습을 진행하는 것이 리소스를 절약할 수 있어서 학습 시간을 줄일 수 있어서이다.
그래서 전체 데이터를 입력으로 한 번에 넣지 않고 batch 단위로 잘게 썰어 주는 것이다.
음식도 한 번에 모두 먹으면 체할 수 있듯이 말이다. 🤮
마치 Model에세 데이터를 먹이로 주는 숟가락 🥄역할을 한다고 이해하면 쉽다.
이 DataLoader
은 개발자를 위한 다양한 옵션들을 제공해주고 있다.
여기서 shuffle
은 데이터를 배치 단위로 만들어서 무작위로 섞을지 아닐지를 결정하는 옵션이다.
False
로 지정하면 데이터를 무작위로 섞지 않고, 순차적으로 배치 단위로 만들어서 모델에 먹이게 된다.
DataLoader의 옵션에 대한 자세한 내용은 링크를 참고하면 된다.
collate_fn이란?
우리가 글을 읽고 쓸 때도 알 수 있듯이 어떤 데이터의 길이가 항상 일정하지는 않다는 걸 알 수 있다.
그런데 모델에게 학습 데이터를 입력하려면 일정한 길이로 먹여줘야 된다.
그래서 collate_fn
파라미터에 원하는 함수를 지정해서 batch로 묶일 데이터를 일정한 길이로 묶어줄 수 있도록 해야 한다.
주로 데이터의 길이가 가변적일 때 collate_fn
파라미터를 사용한다고 한다.
데이터의 길이가 가변적일 때 이를 해결하기 위한 방법에는 두 가지가 있다.
첫 번째 방법은 패딩(padding)을 주어서 부족한 길이를 일정한 값(예: 0)으로 채워서 batch를 일정한 길이로 맞춰주는 방법이다.
두 번째 방법은 batch의 길이를 맞추지는 않지만 offset을 이용해서 데이터가 어떻게 끊기는지 그 위치를 지정해 주는 방법이다.
Offset을 사용하는 경우 나중에 모델에서 EmbeddingBag
를 사용하여 해결할 수 있다.
다음은 offset을 사용하여 collate_fn
의 파라미터로 넘길 함수를 정의한 예시이다.