# Source listing metadata (from the web viewer this file was copied from):
# 19 lines, 545 B, Python.
"""Plot a trained Generator's outputs next to the inputs it was given.

Loads Generator weights from ``CHECKPOINT_PATH``, pulls up to
``NUM_BATCHES`` batches from ``get_dataloader()``, and for each batch shows
the damaged input, the matching "ok" image and the generator output via
``utils.plot_images``.
"""
import itertools

import numpy as np  # currently unused by this script
import torch
import torch.nn as nn  # currently unused by this script
import torch.optim as optim  # currently unused by this script
from torch.autograd import Variable  # legacy wrapper; no longer needed

import utils
from dataset import get_dataloader
from generator import Generator

# Path of the saved Generator state_dict. torch.load unpickles the file,
# so only point this at checkpoints you trust.
CHECKPOINT_PATH = './gen_10epoch_32f_tanh'
# Maximum number of batches to visualise (fewer if the loader is shorter).
NUM_BATCHES = 10

G = Generator()
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on
# a CPU-only machine; load_state_dict then copies the values into G's own
# parameters wherever they live.
G.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
# Switch layers such as dropout / batch-norm (if Generator contains any) to
# inference behaviour before generating images.
G.eval()

trainloader = get_dataloader()

# No autograd graph is needed just to look at the outputs.
with torch.no_grad():
    # Assumes each batch unpacks as (ok_images, damaged_images), matching the
    # original unpacking order. islice stops cleanly if the loader yields
    # fewer than NUM_BATCHES batches, instead of raising StopIteration.
    for test_ok, test_damaged in itertools.islice(trainloader, NUM_BATCHES):
        generated_image = G(test_damaged)
        utils.plot_images(test_damaged, test_ok, generated_image)