ProjektSI/test_plot.py
2018-05-20 09:23:22 +02:00

42 lines
1.5 KiB
Python

import torch
from torchvision import transforms
import torchvision.transforms.functional as F
import utils
from torch.autograd import Variable
from gan import Generator
from PIL import Image
import argparse
# Passed to the Generator as `num_features`.
# NOTE(review): presumably this sets layer widths, so it must match the value
# the checkpoint loaded below was trained with — confirm against the trainer.
NUM_FEATURES = 16

# Module-level generator instance, built at import time. Its weights are
# replaced from a checkpoint by the command-line entry point.
# input_dim/output_dim of 3 match the RGB conversion applied to input images.
# NOTE(review): `lr` presumably only configures training and is unused for
# inference here — confirm in gan.Generator.
G = Generator(
    input_width=178,
    input_height=218,
    input_dim=3,
    num_features=NUM_FEATURES,
    output_dim=3,
    lr=0.0002,
)
def parse_args(argv=None):
    """Parse this script's command line.

    Args:
        argv: list of argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``image`` and ``network`` (required paths),
        ``ok`` (optional comparison image path, else ``None``) and ``output``
        (optional output file path from ``-o``, else ``None``).
    """
    parser = argparse.ArgumentParser(description="Shows output")
    parser.add_argument('image', help='Image to transform', type=str)
    parser.add_argument('network', help='Network to use', type=str)
    parser.add_argument('ok', help='Comparison result', nargs='?', type=str)
    parser.add_argument('-o', help='Output file', nargs='?', type=str,
                        dest='output')
    return parser.parse_args(argv)


def main(argv=None):
    """Run the module-level generator ``G`` on one image.

    The weights are loaded from ``args.network``, and the image at
    ``args.image`` is fed through the generator. With ``-o`` the generated
    image is written to that file. Otherwise the input, the comparison image
    (or the input again when none is given) and the output are passed to
    ``utils.plot_images``.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # Deserialise every tensor onto the CPU. A checkpoint saved from a GPU
    # run then also loads on CPU-only machines. load_state_dict copies the
    # values into G's parameters wherever they live.
    state = torch.load(args.network, map_location=lambda storage, loc: storage)
    G.load_state_dict(state)
    # Inference: make BatchNorm/Dropout layers (if any) use their eval
    # behaviour instead of batch-of-one training statistics.
    G.eval()

    # ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 maps each
    # channel to [-1, 1] via (x - 0.5) / 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # unsqueeze_(0) adds a leading batch axis of size 1.
    test_damaged = transform(Image.open(args.image).convert("RGB")).unsqueeze_(0)
    if args.ok:
        test_ok = transform(Image.open(args.ok).convert("RGB")).unsqueeze_(0)
    else:
        test_ok = test_damaged

    generated_image = G(Variable(test_damaged)).data

    if args.output:
        # NOTE(review): presumably utils.normalize_image rescales the
        # generator output into a range to_pil_image can convert — confirm.
        out = utils.normalize_image(generated_image)
        # to_pil_image only accepts 2-D/3-D input. Drop the batch axis added
        # above if normalize_image kept it (no-op when it is already gone).
        if len(out.shape) == 4:
            out = out[0]
        F.to_pil_image(out).convert('RGB').save(args.output)
    else:
        utils.plot_images(test_damaged, test_ok, generated_image)


if __name__ == '__main__':
    main()