import os from torch.utils import data from torchvision import transforms, utils from PIL import Image class dataset(data.Dataset): def __init__(self, root_path, ok_path="ok_dataset", damaged_path="damaged_dataset"): super(dataset, self).__init__() self.ok_images_path = os.path.join(root_path, ok_path) self.damaged_images_path = os.path.join(root_path, damaged_path) self.ok_images_paths = [x for x in os.listdir(self.ok_images_path)] self.damaged_images_paths = [x for x in os.listdir(self.damaged_images_path)] self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) def __getitem__(self, i): ok_image = Image.open(os.path.join(self.ok_images_path, self.ok_images_paths[i])).convert("RGB") ok_image = self.transform(ok_image) damaged_image = Image.open(os.path.join(self.damaged_images_path, self.damaged_images_paths[i])).convert("RGB") damaged_image = self.transform(damaged_image) return ok_image, damaged_image def __len__(self): return len(self.ok_images_paths) def get_dataloader(): trainset = dataset(".") trainloader = data.DataLoader(trainset) return trainloader