"""Run a trained generator on one image and show or save the result.

Usage::

    python <this script> IMAGE NETWORK [OK] [-o OUTPUT]

IMAGE is fed through the generator ``G`` after loading the state dict stored
in NETWORK. With ``-o OUTPUT`` the generated image is written to OUTPUT.
Otherwise the input image, the optional comparison image OK (IMAGE itself
when OK is omitted) and the generated image are passed to
``utils.plot_images``.
"""
import argparse

import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

import utils
from gan import Generator

NUM_FEATURES = 16

# Generator for 178x218 RGB input / RGB output. These hyper-parameters
# presumably have to match the ones the loaded checkpoint was trained with.
G = Generator(input_width=178, input_height=218, input_dim=3,
              num_features=NUM_FEATURES, output_dim=3, lr=0.0002)


def _load_tensor(path, transform):
    """Open the image at *path* as RGB, apply *transform*, add a batch axis."""
    # The context manager closes the underlying file once the tensor is built.
    with Image.open(path) as img:
        return transform(img.convert("RGB")).unsqueeze(0)


def main():
    """Parse the command line, run ``G`` on the input and save or plot it."""
    parser = argparse.ArgumentParser(description="Shows output")
    parser.add_argument('image', help='Image to transform', type=str)
    parser.add_argument('network', help='Network to use', type=str)
    parser.add_argument('ok', help='Comparison result', nargs='?', type=str)
    parser.add_argument('-o', help='Output file', nargs='?', type=str,
                        dest='output')
    args = parser.parse_args()

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. map_location='cpu' lets checkpoints saved on a GPU load
    # on CPU-only machines; load_state_dict copies the values into G's own
    # parameters wherever they live.
    G.load_state_dict(torch.load(args.network, map_location='cpu'))

    # ToTensor gives values in [0, 1]; Normalize with mean = std = 0.5
    # rescales each channel to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    test_damaged = _load_tensor(args.image, transform)
    # Without a comparison image, the input itself is shown in its place.
    test_ok = _load_tensor(args.ok, transform) if args.ok else test_damaged

    # Inference only: no autograd graph is needed (the original discarded it
    # via `.data`). NOTE(review): G is left in whatever train/eval mode
    # Generator() creates it in; if it has BatchNorm/Dropout layers, G.eval()
    # may be intended. Confirm before changing.
    with torch.no_grad():
        generated_image = G(test_damaged)

    if args.output:
        image = utils.normalize_image(generated_image)
        # to_pil_image accepts only 2-D/3-D input. Drop the batch axis added
        # by _load_tensor in case normalize_image keeps it.
        if len(image.shape) == 4 and image.shape[0] == 1:
            image = image[0]
        F.to_pil_image(image).convert('RGB').save(args.output)
    else:
        utils.plot_images(test_damaged, test_ok, generated_image)


if __name__ == '__main__':
    main()