ProjektSI/dataset.py
2018-05-03 20:10:15 +02:00

31 lines
1.2 KiB
Python

import os
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
class dataset(data.Dataset):
    """Paired dataset of undamaged ("ok") and damaged images.

    Item ``i`` is the ``i``-th file (in sorted filename order) of the ok
    directory paired with the ``i``-th file (in sorted filename order) of the
    damaged directory. Both images are loaded as RGB and passed through
    ``self.transform``.

    Args:
        root_path: Directory containing both image sub-directories.
        ok_path: Name of the sub-directory holding undamaged images.
        damaged_path: Name of the sub-directory holding damaged images.
        transform: Optional callable applied to each RGB PIL image. When
            ``None``, defaults to ``ToTensor`` followed by per-channel
            ``Normalize`` with mean 0.5 and std 0.5.
    """

    def __init__(self, root_path, ok_path="ok_dataset",
                 damaged_path="damaged_dataset", transform=None):
        super(dataset, self).__init__()
        self.ok_images_path = os.path.join(root_path, ok_path)
        self.damaged_images_path = os.path.join(root_path, damaged_path)
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the i-th ok image is paired with the i-th damaged image
        # deterministically across platforms and runs.
        self.ok_images_paths = sorted(os.listdir(self.ok_images_path))
        self.damaged_images_paths = sorted(os.listdir(self.damaged_images_path))
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        self.transform = transform

    def _load(self, directory, filename):
        """Open ``directory/filename`` as RGB, close the file, and apply the transform."""
        # convert() loads the pixel data into a new image, so the underlying
        # file handle can be released before the transform runs.
        with Image.open(os.path.join(directory, filename)) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

    def __getitem__(self, i):
        """Return the ``(ok_image, damaged_image)`` pair at index ``i``."""
        ok_image = self._load(self.ok_images_path, self.ok_images_paths[i])
        damaged_image = self._load(self.damaged_images_path,
                                   self.damaged_images_paths[i])
        return ok_image, damaged_image

    def __len__(self):
        """Return the number of complete pairs.

        Surplus files in the longer directory are ignored, so every index
        below this length is valid for ``__getitem__``.
        """
        return min(len(self.ok_images_paths), len(self.damaged_images_paths))
def get_dataloader(root_path=".", batch_size=1, shuffle=False, num_workers=0):
    """Build a ``DataLoader`` over the paired ok/damaged image ``dataset``.

    Args:
        root_path: Directory passed to ``dataset`` as its root; the previously
            hard-coded value ``"."`` is the default.
        batch_size: Samples per batch (DataLoader's documented default is 1).
        shuffle: Whether to reshuffle every epoch (DataLoader's documented
            default is False).
        num_workers: Loader subprocesses; 0 loads in the main process
            (DataLoader's documented default).

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping ``dataset(root_path)``.
    """
    trainset = dataset(root_path)
    return data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )