"""Visualise a trained generator's outputs on a few dataloader batches.

Loads generator weights from a checkpoint, then, for the first few batches
yielded by ``get_dataloader()``, runs the generator on the ``damaged`` tensor
and passes the ``damaged`` tensor, the ``ok`` tensor and the generator output
to ``utils.plot_images``.
"""
import itertools

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import utils
from dataset import get_dataloader
from torch.autograd import Variable
from generator import Generator

# Checkpoint file holding the generator's state_dict.
CHECKPOINT_PATH = './gen_10epoch_32f_tanh'
# Maximum number of batches to run through the generator and plot.
NUM_BATCHES = 10

G = Generator()
# The model is never moved off the CPU here, so map the stored tensors to the
# CPU; this lets a checkpoint saved from GPU tensors load on CPU-only hosts.
G.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
# Inference mode: makes layers such as BatchNorm/Dropout (if Generator has
# any) use their inference behaviour instead of training behaviour.
G.eval()

trainloader = get_dataloader()

# Outputs are only plotted, so skip building the autograd graph; this also
# makes the deprecated Variable wrapper and the `.data` detach unnecessary.
with torch.no_grad():
    # islice stops cleanly if the loader has fewer than NUM_BATCHES batches,
    # instead of letting next() raise StopIteration.
    for test_ok, test_damaged in itertools.islice(trainloader, NUM_BATCHES):
        generated_image = G(test_damaged)
        utils.plot_images(test_damaged, test_ok, generated_image)