42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
import torch
|
|
from torchvision import transforms
|
|
import torchvision.transforms.functional as F
|
|
import utils
|
|
from torch.autograd import Variable
|
|
from gan import Generator
|
|
from PIL import Image
|
|
import argparse
|
|
|
|
# Base channel multiplier for the generator's conv layers.
NUM_FEATURES = 16

# Generator instantiated at import time so the checkpoint can be loaded into it
# below. 178x218 is presumably the CelebA image size — TODO confirm against the
# training pipeline. input_dim/output_dim = 3 for RGB in / RGB out.
G = Generator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=3, lr=0.0002)
|
|
if __name__ == '__main__':
    # CLI: run the generator on one damaged image; either save the result
    # (-o FILE) or plot damaged / reference / generated side by side.
    parser = argparse.ArgumentParser(description="Shows output")
    parser.add_argument('image', help='Image to transform', type=str)
    parser.add_argument('network', help='Network to use', type=str)
    parser.add_argument('ok', help='Comparison result', nargs='?', type=str)
    parser.add_argument('-o', help='Output file', nargs='?', type=str, dest='output')

    args = parser.parse_args()

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    G.load_state_dict(torch.load(args.network, map_location='cpu'))
    # Switch batch-norm/dropout layers to inference behaviour.
    G.eval()

    # Scale pixels to [-1, 1], matching the tanh-style GAN training range.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # unsqueeze(0) adds the batch dimension the network expects.
    test_damaged = transform(Image.open(args.image).convert("RGB")).unsqueeze(0)
    if args.ok:
        test_ok = transform(Image.open(args.ok).convert("RGB")).unsqueeze(0)
    else:
        # No reference image given: compare the damaged input against itself.
        test_ok = test_damaged

    # Variable() is deprecated since PyTorch 0.4; no_grad() skips building the
    # autograd graph entirely during inference.
    with torch.no_grad():
        generated_image = G(test_damaged)

    if args.output:
        generated_image = utils.normalize_image(generated_image)
        im = F.to_pil_image(generated_image)
        im.convert('RGB').save(args.output)
    else:
        utils.plot_images(test_damaged, test_ok, generated_image)
|