import os
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image

class dataset(data.Dataset):
    """Paired image dataset yielding ``(ok_image, damaged_image)`` tensors.

    File names are read from two sub-directories of ``root_path``. Both
    listings are sorted, because ``os.listdir`` returns entries in arbitrary
    order and the pairing between the two directories is purely positional.

    NOTE(review): positional pairing presumes the two directories contain
    corresponding files whose names sort into the same order -- confirm
    against how the data is produced.
    """

    def __init__(self, root_path, ok_path="ok_dataset", damaged_path="damaged_dataset"):
        """Index both image directories and build the tensor transform.

        Args:
            root_path: Directory containing the two image sub-directories.
            ok_path: Name of the sub-directory with the intact images.
            damaged_path: Name of the sub-directory with the damaged images.

        Raises:
            OSError: If either directory cannot be listed.
        """
        super().__init__()
        self.ok_images_path = os.path.join(root_path, ok_path)
        self.damaged_images_path = os.path.join(root_path, damaged_path)
        self.ok_images_paths = self._sorted_listing(self.ok_images_path)
        self.damaged_images_paths = self._sorted_listing(self.damaged_images_path)
        # Convert to a tensor, then apply (x - 0.5) / 0.5 on each of the
        # three channels.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _sorted_listing(directory):
        """Return the entry names of ``directory`` in sorted order."""
        return sorted(os.listdir(directory))

    def _load_image(self, directory, file_name):
        """Open one image, convert it to RGB and apply ``self.transform``."""
        # The context manager closes the underlying file as soon as the
        # converted copy exists, instead of leaving it to garbage collection.
        with Image.open(os.path.join(directory, file_name)) as raw_image:
            rgb_image = raw_image.convert("RGB")
        return self.transform(rgb_image)

    def __getitem__(self, i):
        """Return the transformed ``(ok_image, damaged_image)`` pair at index ``i``."""
        ok_image = self._load_image(self.ok_images_path, self.ok_images_paths[i])
        damaged_image = self._load_image(
            self.damaged_images_path, self.damaged_images_paths[i]
        )

        return ok_image, damaged_image

    def __len__(self):
        """Return the number of index positions present in both directories."""
        # Using the smaller count keeps every valid index resolvable in both
        # listings even when the directories hold different numbers of files.
        return min(len(self.ok_images_paths), len(self.damaged_images_paths))


def get_dataloader(root_path=".", **loader_kwargs):
    """Build a ``DataLoader`` over the paired image dataset.

    Args:
        root_path: Directory handed to ``dataset``; defaults to the current
            working directory, as before.
        **loader_kwargs: Extra keyword arguments forwarded unchanged to
            ``torch.utils.data.DataLoader`` (for example ``batch_size`` or
            ``shuffle``). With none given, the loader's own defaults apply.

    Returns:
        A ``DataLoader`` wrapping a ``dataset`` rooted at ``root_path``.
    """
    trainset = dataset(root_path)
    trainloader = data.DataLoader(trainset, **loader_kwargs)
    return trainloader