ProjektSI/test_plot.py
2018-04-28 23:30:20 +02:00

19 lines
545 B
Python

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import utils
from dataset import get_dataloader
from torch.autograd import Variable
from generator import Generator
# Visual sanity check for a trained Generator: load a saved checkpoint, run a
# few batches from the dataloader through it, and plot the damaged input, the
# reference image and the generated output side by side via utils.plot_images.

# Saved state_dict of the trained generator to visualise.
CHECKPOINT_PATH = './gen_10epoch_32f_tanh'
# Number of batches to push through the generator and plot.
NUM_BATCHES = 10

G = Generator()
# map_location keeps every storage on the CPU, so a checkpoint saved during a
# GPU run can still be loaded on a CPU-only machine; load_state_dict then
# copies the weights into G's (CPU) parameters exactly as before.
G.load_state_dict(torch.load(CHECKPOINT_PATH,
                             map_location=lambda storage, loc: storage))
# Put layers such as dropout / batch norm (if Generator has any) into
# inference mode before generating samples.
G.eval()

trainloader = get_dataloader()
# Each batch unpacks into (test_ok, test_damaged) -- presumably the intact and
# damaged versions of the same images; confirm against dataset.get_dataloader.
# zip() with range() stops after NUM_BATCHES batches, or earlier (without
# raising StopIteration) if the loader yields fewer batches than that.
for _, (test_ok, test_damaged) in zip(range(NUM_BATCHES), trainloader):
    generated_image = G(Variable(test_damaged))
    # .data unwraps the autograd output into a plain tensor for plotting.
    generated_image = generated_image.data
    utils.plot_images(test_damaged, test_ok, generated_image)