import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import utils
import csv
import os
from dataset import get_dataloader
from torch.autograd import Variable


def add_conv_layer(conv, input_dim, output_dim, kernel=4, activation=True,
                   batch_norm=True, dropout=False):
    """Append a Conv2d block (conv [+ batch norm] [+ LeakyReLU] [+ dropout])."""
    layer = [nn.Conv2d(input_dim, output_dim, kernel)]
    if batch_norm:
        layer.append(nn.BatchNorm2d(output_dim))
    if activation:
        layer.append(nn.LeakyReLU(0.2))
    if dropout:
        layer.append(nn.Dropout2d())
    conv.append(nn.Sequential(*layer))


def add_deconv_layer(deconv, input_dim, output_dim, kernel=4, activation=True,
                     batch_norm=True, dropout=False):
    """Append a ConvTranspose2d block ([LeakyReLU +] deconv [+ batch norm] [+ dropout])."""
    layer = []
    if activation:
        layer.append(nn.LeakyReLU(0.2))
    layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
    if batch_norm:
        layer.append(nn.BatchNorm2d(output_dim))
    if dropout:
        layer.append(nn.Dropout2d())
    deconv.append(nn.Sequential(*layer))


class Generator(nn.Module):
    def __init__(self, input_width=178, input_height=218, input_dim=3,
                 num_features=32, output_dim=3, lr=0.0002):
        super(Generator, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_dim = input_dim
        self.num_features = num_features
        self.output_dim = output_dim
        self.lr = lr
        # arbitrary choice
        self.n = 4

        self.conv = []
        add_conv_layer(self.conv, self.input_dim, self.num_features,
                       activation=False, batch_norm=False)
        add_conv_layer(self.conv, self.num_features, self.num_features * 2)
        add_conv_layer(self.conv, self.num_features * 2, self.num_features * 4)
        add_conv_layer(self.conv, self.num_features * 4, self.num_features * 8,
                       batch_norm=False)

        self.deconv = []
        # Bug fix: these four layers were originally appended to self.conv,
        # which left self.deconv empty; they belong in self.deconv.
        add_deconv_layer(self.deconv, self.num_features * 8, self.num_features * 4,
                         dropout=True)
        add_deconv_layer(self.deconv, self.num_features * 4, self.num_features * 2)
        add_deconv_layer(self.deconv, self.num_features * 2, self.num_features)
        add_deconv_layer(self.deconv, self.num_features, self.output_dim,
                         batch_norm=False)

        self.conv = nn.Sequential(*self.conv)
        self.deconv = nn.Sequential(*self.deconv)
        utils.initialize_weights(self)

    def forward(self, x):
        x = self.conv(x)
        x = self.deconv(x)
        """
        # These were supposed to be skip connections, but they don't work.
        conv = [self.conv[0](x)]
        for i in range(1, len(self.conv)):
            conv.append(self.conv[i](conv[-1]))
        deconv = [self.deconv[0](conv[-1])]
        for i in range(1, len(self.deconv) - 1):
            deconv.append(self.deconv[i](deconv[-1]))
            #deconv[-1] = torch.cat((deconv[-1], conv[-1-i]), 1)
        deconv.append(self.deconv[-1](deconv[-1]))
        """
        x = nn.Tanh()(x)  # Sigmoid might work too, but it gives strange color shifts
        return x
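
# A minimal sketch of how the intended skip connections above could be wired,
# assuming the usual U-Net-style concatenation; this is an illustration, not
# the author's working code, and the class name `UNetGenerator` is hypothetical.
# Once decoder inputs are concatenated with the mirrored encoder outputs their
# channel counts double, which is likely why the attempt above failed with the
# original layer widths.
class UNetGenerator(nn.Module):
    def __init__(self, input_dim=3, num_features=32, output_dim=3):
        super(UNetGenerator, self).__init__()
        self.conv = []
        add_conv_layer(self.conv, input_dim, num_features,
                       activation=False, batch_norm=False)
        add_conv_layer(self.conv, num_features, num_features * 2)
        add_conv_layer(self.conv, num_features * 2, num_features * 4)
        add_conv_layer(self.conv, num_features * 4, num_features * 8, batch_norm=False)

        self.deconv = []
        add_deconv_layer(self.deconv, num_features * 8, num_features * 4, dropout=True)
        # From here on each input is the previous decoder output concatenated
        # with an encoder feature map, so the input channel count doubles.
        add_deconv_layer(self.deconv, num_features * 4 * 2, num_features * 2)
        add_deconv_layer(self.deconv, num_features * 2 * 2, num_features)
        add_deconv_layer(self.deconv, num_features * 2, output_dim, batch_norm=False)

        # ModuleList (not Sequential) so the layers can be indexed in forward().
        self.conv = nn.ModuleList(self.conv)
        self.deconv = nn.ModuleList(self.deconv)
        utils.initialize_weights(self)

    def forward(self, x):
        skips = []
        for layer in self.conv:
            x = layer(x)
            skips.append(x)
        for i, layer in enumerate(self.deconv):
            if i > 0:
                # Concatenate along the channel axis with the mirrored encoder
                # output; spatial sizes match because every conv shrinks and
                # every deconv grows the feature map by the same 3 pixels.
                x = torch.cat((x, skips[-1 - i]), 1)
            x = layer(x)
        return nn.Tanh()(x)
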
class Discriminator(nn.Module):
    def __init__(self, input_width=178, input_height=218, input_dim=1,
                 num_features=32, output_dim=3, lr=0.0002):
        super(Discriminator, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_dim = input_dim
        self.num_features = num_features
        self.output_dim = output_dim
        self.lr = lr

        self.conv = []
        add_conv_layer(self.conv, self.input_dim, self.num_features,
                       activation=False, batch_norm=False)
        add_conv_layer(self.conv, self.num_features, self.num_features * 2)
        add_conv_layer(self.conv, self.num_features * 2, self.num_features * 4)
        add_conv_layer(self.conv, self.num_features * 4, self.num_features * 8)
        add_conv_layer(self.conv, self.num_features * 8, self.output_dim,
                       batch_norm=False)
        #self.conv.append(nn.Linear(self.num_features * 8, self.output_dim))
        self.conv = nn.Sequential(*self.conv)
        utils.initialize_weights(self)

    def forward(self, x):
        x = self.conv(x)
        x = nn.Sigmoid()(x)  # the estimate has to lie in [0, 1]
        return x


if __name__ == '__main__':
    EPOCHS = 5
    CUDA = False
    NUM_FEATURES = 16
    LAMBDA = 100

    output_path = "output"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    loss_path = os.path.join(output_path, "loss", "loss_epoch_{}.csv")
    if not os.path.exists(os.path.dirname(loss_path)):
        os.makedirs(os.path.dirname(loss_path))
    model_path = os.path.join(output_path, "model", "gan_epoch_{}.pt")
    if not os.path.exists(os.path.dirname(model_path)):
        os.makedirs(os.path.dirname(model_path))

    torch.manual_seed(1)

    G = Generator(input_width=178, input_height=218, input_dim=3,
                  num_features=NUM_FEATURES, output_dim=3, lr=0.0002)
    D = Discriminator(input_width=178, input_height=218, input_dim=3,
                      num_features=NUM_FEATURES, output_dim=1, lr=0.0002)
    print("Generator:\n", G)
    print("Discriminator:\n", D)

    BCE_loss = nn.BCELoss()   # doesn't work as a reconstruction loss: it only accepts inputs from [0, 1]
    L1_loss = nn.L1Loss()     # works, but the results are blurry
    KL_loss = nn.KLDivLoss()  # works, but slow (KL divergence, not an L2 loss)
    MSE_loss = nn.MSELoss()   # works, but causes strange color shifts (although they seem to fade)

    G_optimizer = optim.Adam(G.parameters(), G.lr)
    D_optimizer = optim.Adam(D.parameters(), D.lr)

    if CUDA:  # doesn't work on my machine
        G.cuda()
        D.cuda()
        BCE_loss.cuda()
        L1_loss.cuda()

    trainloader = get_dataloader()
    test_ok, test_damaged = next(iter(trainloader))
    #loss_mean, loss_std = [], []

    def get_real_estimate(size):
        return Variable(torch.ones(size))

    def get_fake_estimate(size):
        return Variable(torch.zeros(size))

    # https://github.com/soumith/ganhacks
    # Label smoothing: heavy smoothing slows down the convergence of D_loss to zero.
    def get_smooth_real_estimate(size):
        return Variable(torch.ones(size)) * (0.7 + torch.rand(1)[0] * 0.5)

    def get_smooth_fake_estimate(size):
        return Variable(torch.ones(size)) * (torch.rand(1)[0] * 0.3)
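
    # A hypothetical sketch for the TODO in the training loop below: skip the
    # discriminator update once it is winning by a wide margin, so that the
    # generator can catch up. The name `should_update_D` and both thresholds
    # are assumptions for illustration, not part of the original procedure.
    def should_update_D(d_loss, g_loss, d_threshold=0.1, g_threshold=1.0):
        # When D's loss is near zero while G's is still large, D has all but
        # won and further D updates only starve G of useful gradients.
        return not (d_loss < d_threshold and g_loss > g_threshold)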
    for epoch in range(1, EPOCHS + 1):
        losses = []
        for i, (ok_image, damaged_image) in enumerate(trainloader):
            ok_image, damaged_image = Variable(ok_image), Variable(damaged_image)

            # TODO: add some checks that detect the GAN collapsing and then
            # allow only D or only G to train?

            # discriminator step
            D_optimizer.zero_grad()
            D_real_estimate = D(ok_image)
            real_estimate = get_smooth_real_estimate(D_real_estimate.size())
            D_real_loss = BCE_loss(D_real_estimate, real_estimate)

            # detach() so the generator gets no gradients from the D update
            generated_image = G(damaged_image).detach()
            D_fake_estimate = D(generated_image)
            fake_estimate = get_smooth_fake_estimate(D_fake_estimate.size())
            D_fake_loss = BCE_loss(D_fake_estimate, fake_estimate)

            D_loss = D_real_loss + D_fake_loss
            D_loss.backward()
            D_optimizer.step()

            # generator step
            G_optimizer.zero_grad()
            generated_image = G(damaged_image)
            estimate = D(generated_image)
            G_fake_loss = BCE_loss(estimate, get_real_estimate(estimate.size()))
            # In general the larger LAMBDA is, the more the model behaves like
            # a plain DCNN and the better it preserves the overall colors.
            G_L1_loss = LAMBDA * L1_loss(generated_image, ok_image)
            G_loss = G_fake_loss + G_L1_loss
            G_loss.backward()
            G_optimizer.step()

            # Bug fix: this append used to be commented out, which left
            # `losses` empty and produced loss CSVs with only a header row.
            losses.append((D_loss.data[0], G_loss.data[0]))

            print('Epoch: {} [{}/{} ({:.0f}%)]\tD_loss: {:.6f}\tG_loss: {:.6f}'.format(
                epoch, (i + 1) * len(ok_image), len(trainloader.dataset),
                100. * i / len(trainloader), D_loss.data[0], G_loss.data[0]))

        #loss_mean.append(np.mean(losses))
        #loss_std.append(np.std(losses))

        with open(loss_path.format(epoch), 'w') as csvfile:
            fieldnames = ['num_image', 'd_loss', 'g_loss']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for i, (d_loss, g_loss) in enumerate(losses):
                writer.writerow({'num_image': i, 'd_loss': d_loss, 'g_loss': g_loss})

        if True or epoch == 1 or epoch % 5 == 0:  # always true for now: save every epoch
            torch.save(G.state_dict(), model_path.format(epoch))
            generated_image = G(Variable(test_damaged))
            generated_image = generated_image.data
            utils.plot_images(test_damaged, test_ok, generated_image)

    #utils.plot_loss(loss_mean, loss_std)
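
    # A minimal sketch of reloading a saved checkpoint for inference, assuming
    # the same constructor arguments as above; `restored_G` and this whole
    # block are illustrative additions, not part of the original script.
    restored_G = Generator(input_width=178, input_height=218, input_dim=3,
                           num_features=NUM_FEATURES, output_dim=3, lr=0.0002)
    restored_G.load_state_dict(torch.load(model_path.format(EPOCHS)))
    restored_G.eval()  # switch dropout/batch norm layers to inference mode
    restored_output = restored_G(Variable(test_damaged)).data
    utils.plot_images(test_damaged, test_ok, restored_output)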