Maxima's Lab

[Python, Pytorch] DataLoader (데이터 불러오기) 본문

Python/Pytorch

[Python, Pytorch] DataLoader (데이터 불러오기)

Minima 2022. 8. 7. 23:58
728x90
SMALL

안녕하세요, 오늘은 Pytorch 내 DataLoader (데이터 셋 불러오기)에 대해서 알아보겠습니다.

 

torchvision.datasets 내 데이터 셋 예시들 중 CIFAR10를 활용해서 진행해보도록 하겠습니다.

 

 

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

transform = transforms.Compose(
    [transforms.ToTensor()])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)


print(train_dataset.data.shape)
print(np.array(train_dataset.targets).shape)
print(np.unique(train_dataset.targets))

 

(50000, 32, 32, 3) 
(50000,)
[0 1 2 3 4 5 6 7 8 9]

 

torchvision.datasets.CIFAR10을 통해 데이터 셋을 불러올 수있으며, train=True 옵션을 통해 train_dataset을 불러오게 됩니다.

train=False를 사용하여 추가적으로, test_dataset을 불러오면 됩니다. 추가적으로, 불러온 데이터 셋에 대해서 image와 label을 불러오기 위해서는 train_dataset.data, train_dataset.targets를 사용하시면 됩니다. 

 

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


print(test_dataset.data.shape)
print(np.array(test_dataset.targets).shape)
(10000, 32, 32, 3)
(10000,)

 

위의 코드는 test_dataset을 불러온 결과입니다. train_dataset은 총 50000개, test_dataset은 총 10000개 그리고 unique은 label 은 0 ~ 9 (총 10개)로 구성되어있는 것을 확인할 수 있습니다.

 

다음은 각 train_dataset과 test_dataaset에 대해서 torch.utils.data.Dataloder 함수를 사용해보도록 하겠습니다.

 

batch_size = 8
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 
                                          shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset.data), 
                                         shuffle=False)

 

train_dataset은 batch_size = 8을 이용하여 trainloader를 만들었으며, test_dataset은 batch_size = 10000을 이용하여 만들었습니다. train/testloader에 대해 image와 labels를 접근하는 방법은 다음과 같습니다.

 

for _, data in enumerate(trainloader, 0):
  train_images, train_labels = data

for _, data in enumerate(testloader, 0):
  test_images, test_labels = data

 

지금까지, Pytorch 내 torch.utils.data.DataLoader() 함수를 활용하여, DataLoader (데이터 불러오기) 하는 방법에 대해서 알아보았습니다.

728x90
LIST
Comments