138 lines
5.2 KiB
Python
138 lines
5.2 KiB
Python
import torch
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import utils
|
|
import csv
|
|
import os
|
|
from dataset import get_dataloader
|
|
from torch.autograd import Variable
|
|
|
|
class Generator(nn.Module):
    """Fully convolutional encoder-decoder mapping a damaged image to a restored one.

    The encoder is four ``Conv2d`` blocks and the decoder four
    ``ConvTranspose2d`` blocks. All use kernel 4 with the default stride 1 and
    padding 0, so each encoder block shrinks height and width by 3 and each
    decoder block grows them back by 3. The output therefore has the same
    spatial size as the input (the input must be larger than 12 px per side).
    The result is passed through ``tanh`` and so lies in [-1, 1].

    Args:
        input_dim: number of channels of the input image.
        output_dim: number of channels of the produced image.
        num_features: channel width of the first encoder block; the deeper
            blocks use 2x, 4x and 8x this value.
        lr: learning rate stored on the model (read by the training code
            when building its optimizer).
    """

    def __init__(self, input_dim=3, output_dim=3, num_features=32, lr=0.0002):
        super().__init__()
        # Not used to build the layers (the network is fully convolutional);
        # kept as attributes for existing readers.
        self.input_width = 178
        self.input_height = 218
        self.input_dim = input_dim
        self.num_features = num_features
        self.output_dim = output_dim
        self.lr = lr
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.n = 4

        nf = num_features

        # add_conv_layer / add_deconv_layer append to these plain lists; the
        # lists are wrapped in nn.Sequential below so the layers get registered.
        self.conv = []
        self.add_conv_layer(input_dim, nf, activation=False, batch_norm=False)
        self.add_conv_layer(nf, nf * 2)
        self.add_conv_layer(nf * 2, nf * 4)
        # NOTE(review): this block ends with LeakyReLU and the first decoder
        # block starts with one, so two LeakyReLUs are stacked at the
        # bottleneck -- confirm this is intended.
        self.add_conv_layer(nf * 4, nf * 8, batch_norm=False)

        self.deconv = []
        self.add_deconv_layer(nf * 8, nf * 4, dropout=True)
        self.add_deconv_layer(nf * 4, nf * 2)
        self.add_deconv_layer(nf * 2, nf)
        self.add_deconv_layer(nf, output_dim, batch_norm=False)

        self.conv = nn.Sequential(*self.conv)
        self.deconv = nn.Sequential(*self.deconv)

        utils.initialize_weights(self)

    def add_conv_layer(self, input_dim, output_dim, kernel=4, activation=True,
                       batch_norm=True, dropout=False):
        """Append an encoder block to ``self.conv``.

        The block is ``Conv2d -> [BatchNorm2d] -> [LeakyReLU(0.2)] -> [Dropout2d]``,
        with the bracketed stages controlled by the boolean flags. It is
        intended to be called from ``__init__`` while ``self.conv`` is a list.
        """
        layer = [nn.Conv2d(input_dim, output_dim, kernel)]
        if batch_norm:
            layer.append(nn.BatchNorm2d(output_dim))
        if activation:
            layer.append(nn.LeakyReLU(0.2))
        if dropout:
            layer.append(nn.Dropout2d())
        self.conv.append(nn.Sequential(*layer))

    def add_deconv_layer(self, input_dim, output_dim, kernel=4, activation=True,
                         batch_norm=True, dropout=False):
        """Append a decoder block to ``self.deconv``.

        The block is ``[LeakyReLU(0.2)] -> ConvTranspose2d -> [BatchNorm2d] -> [Dropout2d]``.
        The activation is applied *before* the transposed convolution. It is
        intended to be called from ``__init__`` while ``self.deconv`` is a list.
        """
        layer = []
        if activation:
            layer.append(nn.LeakyReLU(0.2))
        layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
        if batch_norm:
            layer.append(nn.BatchNorm2d(output_dim))
        if dropout:
            layer.append(nn.Dropout2d())
        self.deconv.append(nn.Sequential(*layer))

    def forward(self, x):
        """Encode then decode ``x``; return the result squashed into [-1, 1]."""
        x = self.conv(x)
        x = self.deconv(x)
        # TODO: U-Net style skip connections (concatenating encoder outputs
        # into the decoder) were attempted but did not work; the decoder
        # blocks would presumably need wider input channels to accept them.
        # tanh rather than sigmoid: per the original author, sigmoid produced
        # odd over-saturated colours.
        return torch.tanh(x)
|
|
|
|
|
|
def _write_loss_csv(path, losses):
    """Write one CSV row per training batch: its batch index and its L1 loss."""
    # newline='' is required by the csv module to avoid blank rows on Windows.
    with open(path, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=['num_image', 'loss_l1'])
        writer.writeheader()
        writer.writerows({'num_image': i, 'loss_l1': loss_l1}
                         for i, loss_l1 in enumerate(losses))


def main():
    """Train the Generator with an L1 reconstruction loss and save artifacts.

    Per epoch, per-batch losses are written to
    ``output/loss/loss_epoch_{epoch}.csv``. The model ``state_dict`` is saved
    to ``output/model/gan_epoch_{epoch}.pt`` after the first epoch, every 5th
    epoch and the final epoch. After training, the generator's output for the
    first batch the loader yielded is plotted with ``utils.plot_images``.
    """
    EPOCHS = 3

    output_path = "output"
    loss_path = os.path.join(output_path, "loss", "loss_epoch_{}.csv")
    model_path = os.path.join(output_path, "model", "gan_epoch_{}.pt")
    # makedirs also creates output_path; exist_ok avoids a check-then-create race.
    os.makedirs(os.path.dirname(loss_path), exist_ok=True)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    G = Generator()
    print(G)

    L1_loss = nn.L1Loss()
    G_optimizer = optim.Adam(G.parameters(), G.lr)

    trainloader = get_dataloader()
    # A fixed batch kept aside for visualising the result after training.
    test_ok, test_damaged = next(iter(trainloader))

    num_samples = len(trainloader.dataset)
    num_batches = len(trainloader)

    for epoch in range(1, EPOCHS + 1):
        losses = []
        seen = 0
        for i, (ok_image, damaged_image) in enumerate(trainloader):
            G_optimizer.zero_grad()
            generated_image = G(damaged_image)
            loss = L1_loss(generated_image, ok_image)
            loss.backward()
            # .item() replaces loss.data[0], which raises on 0-dim tensors
            # in PyTorch >= 0.4.
            loss_value = loss.item()
            losses.append(loss_value)
            G_optimizer.step()

            # Count real samples so a partial last batch is not overcounted.
            seen += len(ok_image)
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, num_samples,
                100. * (i + 1) / num_batches, loss_value))

        _write_loss_csv(loss_path.format(epoch), losses)

        # Also save on the last epoch so the final model is never lost.
        if epoch == 1 or epoch % 5 == 0 or epoch == EPOCHS:
            torch.save(G.state_dict(), model_path.format(epoch))

    # TODO: per-epoch mean/std of `losses` could be collected and passed to
    # utils.plot_loss (previously commented out).

    # No gradients are needed for visualisation. NOTE: G is still in training
    # mode here (no .eval() call), so dropout and batch statistics apply.
    with torch.no_grad():
        generated_image = G(test_damaged)
    utils.plot_images(test_damaged, test_ok, generated_image)


if __name__ == '__main__':
    main()
|
|
|