From c498c700c404724fafca9f753ca69cdfe10ad70b Mon Sep 17 00:00:00 2001
From: simonrgk
Date: Fri, 4 May 2018 21:59:56 +0200
Subject: [PATCH] =?UTF-8?q?trwoga=20i=20dr=C5=BCenie?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (text between the `---` line and the diff is ignored by `git am`):

* Line breaks, indentation and blank context lines were reconstructed from a
  whitespace-collapsed copy of this patch. Check them against gan.py before
  applying; the `index` blob hash below is the original one and does not
  describe the post-image of this revised patch.
* Fixed: Generator.__init__ appended the transposed-conv blocks to self.conv,
  leaving self.deconv empty; they now go to self.deconv as before the refactor.
* Fixed: `losses` was never appended to, so the per-epoch CSV only ever held
  its header row; the (D_loss, G_loss) pair is now recorded for every batch.
* Removed the debug prints (`print(1)` .. `print(5)` and the two size prints).
* Added one-line docstrings; translated code comments to English.
* Not changed, please confirm: Discriminator defaults (input_dim=1,
  output_dim=3) look swapped relative to the call in __main__; the CUDA branch
  moves the modules but not the batches or label tensors built in the loop;
  `if True or ...` makes the checkpoint condition always true.

 gan.py | 206 ++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 153 insertions(+), 53 deletions(-)

diff --git a/gan.py b/gan.py
index 346f628..312d9df 100644
--- a/gan.py
+++ b/gan.py
@@ -8,56 +8,61 @@ import os
 from dataset import get_dataloader
 from torch.autograd import Variable
 
+
+def add_conv_layer(conv, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
+    """Append a Conv2d [+ BatchNorm2d] [+ LeakyReLU(0.2)] [+ Dropout2d] block to the list `conv`."""
+    layer = []
+    layer.append(nn.Conv2d(input_dim, output_dim, kernel))
+    if batch_norm:
+        layer.append(nn.BatchNorm2d(output_dim))
+    if activation:
+        layer.append(nn.LeakyReLU(0.2))
+    if dropout:
+        layer.append(nn.Dropout2d())
+    conv.append(nn.Sequential(*layer))
+
+def add_deconv_layer(deconv, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
+    """Append a [LeakyReLU(0.2) +] ConvTranspose2d [+ BatchNorm2d] [+ Dropout2d] block to the list `deconv`."""
+    layer = []
+    if activation:
+        layer.append(nn.LeakyReLU(0.2))
+    layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
+    if batch_norm:
+        layer.append(nn.BatchNorm2d(output_dim))
+    if dropout:
+        layer.append(nn.Dropout2d())
+    deconv.append(nn.Sequential(*layer))
+
+
 
 class Generator(nn.Module):
+    """Generator built from four conv blocks followed by four transposed-conv blocks."""
-    def __init__(self):
+    def __init__(self, input_width=178, input_height=218, input_dim=3, num_features=32, output_dim=3, lr=0.0002):
         super(Generator, self).__init__()
-        self.input_width = 178
-        self.input_height = 218
-        self.input_dim = 3
-        self.num_features = 32
-        self.output_dim = 3
-        self.lr = 0.0002
-        self.n = 4
+        self.input_width = input_width
+        self.input_height = input_height
+        self.input_dim = input_dim
+        self.num_features = num_features
+        self.output_dim = output_dim
+        self.lr = lr # arbitrary choice
+        self.n = 4
 
         self.conv = []
-        self.add_conv_layer(self.input_dim, self.num_features, activation=False, batch_norm = False)
-        self.add_conv_layer(self.num_features, self.num_features * 2)
-        self.add_conv_layer(self.num_features * 2, self.num_features * 4)
-        self.add_conv_layer(self.num_features * 4, self.num_features * 8, batch_norm = False)
+        add_conv_layer(self.conv, self.input_dim, self.num_features, activation=False, batch_norm = False)
+        add_conv_layer(self.conv, self.num_features, self.num_features * 2)
+        add_conv_layer(self.conv, self.num_features * 2, self.num_features * 4)
+        add_conv_layer(self.conv, self.num_features * 4, self.num_features * 8, batch_norm = False)
 
         self.deconv = []
-        self.add_deconv_layer(self.num_features * 8, self.num_features * 4, dropout = True)
-        self.add_deconv_layer(self.num_features * 4, self.num_features * 2)
-        self.add_deconv_layer(self.num_features * 2, self.num_features)
-        self.add_deconv_layer(self.num_features, self.output_dim, batch_norm = False)
+        add_deconv_layer(self.deconv, self.num_features * 8, self.num_features * 4, dropout = True)
+        add_deconv_layer(self.deconv, self.num_features * 4, self.num_features * 2)
+        add_deconv_layer(self.deconv, self.num_features * 2, self.num_features)
+        add_deconv_layer(self.deconv, self.num_features, self.output_dim, batch_norm = False)
 
         self.conv = nn.Sequential(*self.conv)
         self.deconv = nn.Sequential(*self.deconv)
 
         utils.initialize_weights(self)
-
-    def add_conv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
-        layer = []
-        layer.append(nn.Conv2d(input_dim, output_dim, kernel))
-        if batch_norm:
-            layer.append(nn.BatchNorm2d(output_dim))
-        if activation:
-            layer.append(nn.LeakyReLU(0.2))
-        if dropout:
-            layer.append(nn.Dropout2d())
-        self.conv.append(nn.Sequential(*layer))
-
-    def add_deconv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
-        layer = []
-        if activation:
-            layer.append(nn.LeakyReLU(0.2))
-        layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
-        if batch_norm:
-            layer.append(nn.BatchNorm2d(output_dim))
-        if dropout:
-            layer.append(nn.Dropout2d())
-        self.deconv.append(nn.Sequential(*layer))
 
     def forward(self, x):
         x = self.conv(x)
@@ -78,8 +83,43 @@ class Generator(nn.Module):
         return x
 
 
+class Discriminator(nn.Module):
+    """Discriminator: five conv blocks whose output is squashed by a sigmoid."""
+    def __init__(self, input_width=178, input_height=218, input_dim=1, num_features=32, output_dim=3, lr=0.0002):
+        super(Discriminator, self).__init__()
+        self.input_width = input_width
+        self.input_height = input_height
+        self.input_dim = input_dim
+        self.num_features = num_features
+        self.output_dim = output_dim
+        self.lr = lr
+
+        self.conv = []
+        add_conv_layer(self.conv, self.input_dim, self.num_features, activation=False, batch_norm = False)
+        add_conv_layer(self.conv, self.num_features, self.num_features * 2)
+        add_conv_layer(self.conv, self.num_features * 2, self.num_features * 4)
+        add_conv_layer(self.conv, self.num_features * 4, self.num_features * 8)
+        add_conv_layer(self.conv, self.num_features * 8, self.output_dim, batch_norm = False)
+        #self.conv.append(nn.Linear(self.num_features * 8, self.output_dim))
+
+        self.conv = nn.Sequential(*self.conv)
+
+        utils.initialize_weights(self)
+
+    def forward(self, x):
+        x = self.conv(x)
+
+        x = nn.Sigmoid()(x) # output must lie in [0, 1]
+        return x
+
+
+
 if __name__ == '__main__':
-    EPOCHS = 3
+    EPOCHS = 5
+    CUDA = False
+    NUM_FEATURES = 16
+    LAMBDA = 100
+
     output_path = "output"
     if not os.path.exists(output_path):
         os.makedirs(output_path)
@@ -90,42 +130,102 @@ if __name__ == '__main__':
     if not os.path.exists(os.path.dirname(model_path)):
         os.makedirs(os.path.dirname(model_path))
 
-    G = Generator()
-    print(G)
+    torch.manual_seed(1)
 
-    BCE_loss = nn.BCELoss()
-    L1_loss = nn.L1Loss()
-    L2_loss = nn.KLDivLoss()
+    G = Generator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=3, lr=0.0002)
+    D = Discriminator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=1, lr=0.0002)
+    print("Generator:\n", G)
+    print("Discriminator:\n", D)
+
+    BCE_loss = nn.BCELoss() # does not work: only accepts inputs in the range [0, 1]
+    L1_loss = nn.L1Loss() # works, but the results are blurry
+    L2_loss = nn.KLDivLoss() # works, but slow (NOTE(review): this is KL divergence, not L2)
+    MSE_loss = nn.MSELoss() # works, but with some odd colour artifacts (which seem to fade away)
 
     G_optimizer = optim.Adam(G.parameters(), G.lr)
+    D_optimizer = optim.Adam(D.parameters(), D.lr)
+
+    if CUDA: # does not work on my machine
+        G.cuda()
+        D.cuda()
+        BCE_loss.cuda()
+        L1_loss.cuda()
+
     trainloader = get_dataloader()
     test_ok, test_damaged = trainloader.__iter__().__next__()
     #loss_mean, loss_std = [], []
+
+    def get_real_estimate(size):
+        return Variable(torch.ones(size))
+
+    def get_fake_estimate(size):
+        return Variable(torch.zeros(size))
+
+    # https://github.com/soumith/ganhacks
+    # Label smoothing: heavy smoothing slows the convergence of D_loss to zero.
+    def get_smooth_real_estimate(size):
+        return Variable(torch.ones(size)) * (0.7 + torch.rand(1)[0]*0.5)
+
+    def get_smooth_fake_estimate(size):
+        return Variable(torch.ones(size)) * (torch.rand(1)[0]*0.3)
 
     for epoch in range(1, EPOCHS+1, 1):
         losses = []
         for i, (ok_image, damaged_image) in enumerate(trainloader):
             ok_image, damaged_image = Variable(ok_image), Variable(damaged_image)
+
+            # TODO: some checks that detect the GAN collapsing and allow
+            # training only D or only G?
+
+            # train the discriminator
+
+            D_optimizer.zero_grad()
+            D_real_estimate = D(ok_image)
+            real_estimate = get_smooth_real_estimate(D_real_estimate.size())
+            D_real_loss = BCE_loss(D_real_estimate, real_estimate)
+            #D_real_loss.backward()
+
+            generated_image = G(damaged_image).detach()
+            D_fake_estimate = D(generated_image)
+            fake_estimate = get_smooth_fake_estimate(D_fake_estimate.size())
+            D_fake_loss = BCE_loss(D_fake_estimate, fake_estimate)
+            #D_fake_loss.backward()
+            D_loss = D_real_loss + D_fake_loss
+            D_loss.backward()
+            D_optimizer.step()
+
+            # train the generator
             G_optimizer.zero_grad()
             generated_image = G(damaged_image)
-            loss = L1_loss(generated_image, ok_image)
-            loss.backward()
-            losses.append(loss.data[0])
+            estimate = D(generated_image)
+            G_fake_loss = BCE_loss(estimate, get_real_estimate(estimate.size()))
+            G_L1_loss = LAMBDA * L1_loss(generated_image, ok_image)
+            # in general, the larger lambda is, the more this behaves like a DCNN and
+            # the better it preserves the overall colour scheme
+            G_loss = G_fake_loss + G_L1_loss
+            G_loss.backward()
             G_optimizer.step()
-            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+            losses.append((D_loss.data[0], G_loss.data[0]))
+            #loss = L1_loss(generated_image, ok_image)
+            #loss = BCE_loss(generated_image, ok_image)
+            #loss = used_loss(generated_image, ok_image)
+            #loss.backward()
+            #losses.append(loss.data[0])
+            #G_optimizer.step()
+            print('Epoch: {} [{}/{} ({:.0f}%)]\tD_loss: {:.6f}\tG_loss: {:.6f}'.format(
                 epoch, (i + 1) * len(ok_image), len(trainloader.dataset),
-                100. * i / len(trainloader), loss.data[0]))
+                100. * i / len(trainloader), D_loss.data[0], G_loss.data[0]))
 
         #loss_mean.append(np.mean(losses))
         #loss_std.append(np.std(losses))
 
         with open(loss_path.format(epoch), 'w') as csvfile:
-            fieldnames = ['num_image', 'loss_l1']
+            fieldnames = ['num_image', 'd_loss', 'g_loss']
             writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
             writer.writeheader()
-            for i, loss_l1 in enumerate(losses):
-                writer.writerow({'num_image': i, 'loss_l1': loss_l1})
+            for i, (d_loss, g_loss) in enumerate(losses):
+                writer.writerow({'num_image': i, 'd_loss': d_loss, 'g_loss': g_loss})
 
-        if epoch == 1 or epoch % 5 == 0:
+        if True or epoch == 1 or epoch % 5 == 0:
             torch.save(G.state_dict(), model_path.format(epoch))
             generated_image = G(Variable(test_damaged))