31 lines
1.2 KiB
Python
31 lines
1.2 KiB
Python
import os
|
|
from torch.utils import data
|
|
from torchvision import transforms, utils
|
|
from PIL import Image
|
|
|
|
class dataset(data.Dataset):
    """Paired image dataset yielding ``(ok_image, damaged_image)`` samples.

    Images are read from two sub-directories of ``root_path``. Sample ``i``
    pairs the ``i``-th file (in sorted filename order) of the "ok" directory
    with the ``i``-th file of the "damaged" directory. Pairing is by position
    only, so the two directories are expected to hold corresponding files
    whose names sort in the same order.

    Args:
        root_path: Directory containing the two image sub-directories.
        ok_path: Name of the sub-directory holding the undamaged images.
        damaged_path: Name of the sub-directory holding the damaged images.
        transform: Optional callable applied to each RGB PIL image. When
            ``None`` (the default), ``ToTensor`` followed by
            ``Normalize((0.5,) * 3, (0.5,) * 3)`` is used.
    """

    def __init__(self, root_path, ok_path="ok_dataset",
                 damaged_path="damaged_dataset", transform=None):
        super().__init__()
        self.ok_images_path = os.path.join(root_path, ok_path)
        self.damaged_images_path = os.path.join(root_path, damaged_path)
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the index-based pairing in __getitem__ is deterministic.
        self.ok_images_paths = sorted(os.listdir(self.ok_images_path))
        self.damaged_images_paths = sorted(os.listdir(self.damaged_images_path))
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        self.transform = transform

    def _load(self, directory, filename):
        """Open ``directory/filename``, convert it to RGB and apply the transform."""
        # The context manager closes the file handle Image.open keeps;
        # convert() returns a new, fully loaded image usable after the block.
        with Image.open(os.path.join(directory, filename)) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

    def __getitem__(self, i):
        """Return the transformed ``(ok_image, damaged_image)`` pair at index ``i``."""
        ok_image = self._load(self.ok_images_path, self.ok_images_paths[i])
        damaged_image = self._load(self.damaged_images_path,
                                   self.damaged_images_paths[i])
        return ok_image, damaged_image

    def __len__(self):
        """Return the number of complete pairs.

        Surplus files in the longer directory are ignored. Indexing past the
        shorter listing would otherwise raise IndexError.
        """
        return min(len(self.ok_images_paths), len(self.damaged_images_paths))
|
|
|
|
|
|
def get_dataloader(root_path=".", **loader_kwargs):
    """Build a DataLoader over the paired ok/damaged image dataset.

    Args:
        root_path: Directory that holds the dataset's image sub-directories.
            Defaults to ``"."``, the value that was previously hard-coded.
        **loader_kwargs: Forwarded unchanged to ``data.DataLoader`` (e.g.
            ``batch_size``, ``shuffle``, ``num_workers``). With none given,
            DataLoader's own defaults apply, as before.

    Returns:
        A ``data.DataLoader`` wrapping ``dataset(root_path)``.
    """
    trainset = dataset(root_path)
    return data.DataLoader(trainset, **loader_kwargs)
|