From e38d3c31d5999769b9f15397d7cc9019c3a0cc39 Mon Sep 17 00:00:00 2001
From: simonrgk
Date: Sat, 28 Apr 2018 23:30:20 +0200
Subject: [PATCH] =?UTF-8?q?pi=C4=99kno=20i=20dobro?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Revised in review: line structure restored and the files below amended.
The blob ids on the index lines still name the originally submitted contents.

 dataset.py         |  53 +++++++++++++++++
 gan.py             | 142 +++++++++++++++++++++++++++++++++++++++++++++
 prepare_dataset.py |  44 ++++++++++++++
 test_plot.py       |  23 +++++++
 utils.py           |  52 ++++++++++++++++
 5 files changed, 314 insertions(+)
 create mode 100644 dataset.py
 create mode 100644 gan.py
 create mode 100755 prepare_dataset.py
 create mode 100644 test_plot.py
 create mode 100644 utils.py

diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..f7b545a
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,53 @@
+"""Paired intact/damaged image dataset used to train the restoration generator."""
+
+import os
+from torch.utils import data
+from torchvision import transforms, utils
+from PIL import Image
+
+
+class dataset(data.Dataset):
+    """Yield (ok_image, damaged_image) tensor pairs from two sibling directories.
+
+    Every image is converted to RGB, turned into a tensor and normalised with
+    mean 0.5 and std 0.5 on each of the three channels.
+    """
+
+    def __init__(self, root_path, ok_path="ok_dataset", damaged_path="damaged_dataset"):
+        """Index the images found in root_path/ok_path and root_path/damaged_path."""
+        super(dataset, self).__init__()
+        self.ok_images_path = os.path.join(root_path, ok_path)
+        self.damaged_images_path = os.path.join(root_path, damaged_path)
+        # os.listdir() gives no ordering guarantee, so two independent listings
+        # need not line up.  Sorting both makes index i address the same position
+        # on each side (assumes matching file names in both directories, which
+        # is what prepare_dataset.py asks its generator script to produce).
+        self.ok_images_paths = sorted(os.listdir(self.ok_images_path))
+        self.damaged_images_paths = sorted(os.listdir(self.damaged_images_path))
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ])
+
+    def _load(self, directory, name):
+        """Open directory/name, convert it to RGB and apply the tensor transform."""
+        # The context manager closes the underlying file once the tensor is built.
+        with Image.open(os.path.join(directory, name)) as image:
+            return self.transform(image.convert("RGB"))
+
+    def __getitem__(self, i):
+        """Return the i-th (ok_image, damaged_image) pair as normalised tensors."""
+        ok_image = self._load(self.ok_images_path, self.ok_images_paths[i])
+        damaged_image = self._load(self.damaged_images_path, self.damaged_images_paths[i])
+        return ok_image, damaged_image
+
+    def __len__(self):
+        """Number of complete pairs; the shorter listing bounds the valid indices."""
+        return min(len(self.ok_images_paths), len(self.damaged_images_paths))
+
+
+def get_dataloader():
+    """Build a DataLoader with default settings over the dataset rooted at '.'."""
+    trainset = dataset(".")
+    trainloader = data.DataLoader(trainset)
+    return trainloader
diff --git a/gan.py b/gan.py
new file mode 100644
index 0000000..346f628
--- /dev/null
+++ b/gan.py
@@ -0,0 +1,142 @@
+"""Encoder/decoder image-restoration generator and its L1 training script."""
+
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.optim as optim
+import utils
+import csv
+import os
+from dataset import get_dataloader
+from torch.autograd import Variable
+
+
+class Generator(nn.Module):
+    """Four 4x4 convolutions followed by four 4x4 transposed convolutions.
+
+    No stride or padding is set, so every convolution trims 3 pixels from each
+    spatial dimension and every transposed convolution adds 3 back: the output
+    has the same height and width as the input.
+    """
+
+    def __init__(self):
+        super(Generator, self).__init__()
+        self.input_width = 178
+        self.input_height = 218
+        self.input_dim = 3
+        self.num_features = 32
+        self.output_dim = 3
+        self.lr = 0.0002  # learning rate, read by the training script below
+        self.n = 4
+
+        # Encoder: 3 -> 32 -> 64 -> 128 -> 256 channels.
+        self.conv = []
+        self.add_conv_layer(self.input_dim, self.num_features, activation=False, batch_norm=False)
+        self.add_conv_layer(self.num_features, self.num_features * 2)
+        self.add_conv_layer(self.num_features * 2, self.num_features * 4)
+        self.add_conv_layer(self.num_features * 4, self.num_features * 8, batch_norm=False)
+
+        # Decoder: 256 -> 128 -> 64 -> 32 -> 3 channels.
+        self.deconv = []
+        self.add_deconv_layer(self.num_features * 8, self.num_features * 4, dropout=True)
+        self.add_deconv_layer(self.num_features * 4, self.num_features * 2)
+        self.add_deconv_layer(self.num_features * 2, self.num_features)
+        self.add_deconv_layer(self.num_features, self.output_dim, batch_norm=False)
+
+        # The lists above only collect the stages; wrapping them in Sequential
+        # registers their parameters with this module.
+        self.conv = nn.Sequential(*self.conv)
+        self.deconv = nn.Sequential(*self.deconv)
+
+        utils.initialize_weights(self)
+
+    def add_conv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
+        """Append Conv2d [+ BatchNorm2d] [+ LeakyReLU(0.2)] [+ Dropout2d] to the encoder list."""
+        layer = [nn.Conv2d(input_dim, output_dim, kernel)]
+        if batch_norm:
+            layer.append(nn.BatchNorm2d(output_dim))
+        if activation:
+            layer.append(nn.LeakyReLU(0.2))
+        if dropout:
+            layer.append(nn.Dropout2d())
+        self.conv.append(nn.Sequential(*layer))
+
+    def add_deconv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
+        """Append [LeakyReLU(0.2) +] ConvTranspose2d [+ BatchNorm2d] [+ Dropout2d] to the decoder list."""
+        layer = []
+        if activation:
+            layer.append(nn.LeakyReLU(0.2))
+        layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
+        if batch_norm:
+            layer.append(nn.BatchNorm2d(output_dim))
+        if dropout:
+            layer.append(nn.Dropout2d())
+        self.deconv.append(nn.Sequential(*layer))
+
+    def forward(self, x):
+        """Encode then decode x and squash the result into [-1, 1] with tanh."""
+        x = self.conv(x)
+        x = self.deconv(x)
+        # TODO: skip connections between matching encoder/decoder stages were
+        # attempted here (torch.cat along dim 1) but did not work.
+        # Sigmoid might also do, but it produced oddly over-saturated colours.
+        return torch.tanh(x)
+
+
+def _scalar(loss):
+    """Return loss as a Python number on both pre-0.4 and current PyTorch."""
+    # 0-dim tensors (PyTorch >= 0.4) cannot be indexed with [0]; .item() is the
+    # supported accessor there, while older releases only offer .data[0].
+    return loss.item() if hasattr(loss, "item") else loss.data[0]
+
+
+if __name__ == '__main__':
+    EPOCHS = 3
+    output_path = "output"
+    loss_path = os.path.join(output_path, "loss", "loss_epoch_{}.csv")
+    model_path = os.path.join(output_path, "model", "gan_epoch_{}.pt")
+    # makedirs creates output_path itself along the way.
+    for directory in (os.path.dirname(loss_path), os.path.dirname(model_path)):
+        os.makedirs(directory, exist_ok=True)
+
+    G = Generator()
+    print(G)
+
+    L1_loss = nn.L1Loss()
+
+    G_optimizer = optim.Adam(G.parameters(), G.lr)
+    trainloader = get_dataloader()
+    # One fixed pair, plotted again after every epoch to visualise progress.
+    test_ok, test_damaged = next(iter(trainloader))
+
+    for epoch in range(1, EPOCHS + 1):
+        losses = []
+        for i, (ok_image, damaged_image) in enumerate(trainloader):
+            ok_image, damaged_image = Variable(ok_image), Variable(damaged_image)
+            G_optimizer.zero_grad()
+            generated_image = G(damaged_image)
+            loss = L1_loss(generated_image, ok_image)
+            loss.backward()
+            G_optimizer.step()
+            loss_value = _scalar(loss)
+            losses.append(loss_value)
+            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, (i + 1) * len(ok_image), len(trainloader.dataset),
+                100. * (i + 1) / len(trainloader), loss_value))
+
+        # newline='' is what the csv module expects of files it writes to.
+        with open(loss_path.format(epoch), 'w', newline='') as csvfile:
+            fieldnames = ['num_image', 'loss_l1']
+            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+            writer.writeheader()
+            for i, loss_l1 in enumerate(losses):
+                writer.writerow({'num_image': i, 'loss_l1': loss_l1})
+
+        # Checkpoint the first epoch, every fifth one, and always the last so a
+        # short run does not end without its final weights on disk.
+        if epoch == 1 or epoch % 5 == 0 or epoch == EPOCHS:
+            torch.save(G.state_dict(), model_path.format(epoch))
+
+        generated_image = G(Variable(test_damaged))
+        generated_image = generated_image.data
+        utils.plot_images(test_damaged, test_ok, generated_image)
diff --git a/prepare_dataset.py b/prepare_dataset.py
new file mode 100755
index 0000000..fb965bf
--- /dev/null
+++ b/prepare_dataset.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+"""Generate paired ok_/damaged_ image directories from a directory of source images."""
+
+import argparse
+import os
+import sys
+from subprocess import call
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Generates dataset")
+    parser.add_argument('dataset', help='Dataset directory path', type=str)
+    parser.add_argument('-d', help='Decals count', type=int, default=8, dest="decals")
+    parser.add_argument('-n', help='Noise level', type=float, default=.4, dest="noise")
+    args = parser.parse_args()
+
+    dataset_path = args.dataset
+    # Prefix the last path component only, so "data/faces" or "faces/" yield
+    # "data/ok_faces" / "ok_faces" rather than a path rooted at "ok_data".
+    parent, name = os.path.split(os.path.normpath(dataset_path))
+    ok_path = os.path.join(parent, "ok_" + name)
+    damaged_path = os.path.join(parent, "damaged_" + name)
+
+    if not os.path.isdir(dataset_path):
+        raise RuntimeError("Dataset directory not found: {}".format(dataset_path))
+
+    os.makedirs(damaged_path, exist_ok=True)
+    os.makedirs(ok_path, exist_ok=True)
+
+    failures = 0
+    # List the directory once so all three paths are derived from the same names.
+    for file_name in os.listdir(dataset_path):
+        status = call(['./uss-mariusz.py',
+                       os.path.join(dataset_path, file_name),
+                       os.path.join(ok_path, file_name),
+                       os.path.join(damaged_path, file_name),
+                       '-d', str(args.decals), '-n', str(args.noise)])
+        if status != 0:
+            # Keep going, but do not let a failed image pass unnoticed.
+            failures += 1
+            print("uss-mariusz.py exited with {} for {}".format(status, file_name), file=sys.stderr)
+
+    if failures:
+        sys.exit(1)
diff --git a/test_plot.py b/test_plot.py
new file mode 100644
index 0000000..14de609
--- /dev/null
+++ b/test_plot.py
@@ -0,0 +1,23 @@
+"""Plot (damaged, generated, ok) triples for a saved generator checkpoint."""
+
+import itertools
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.optim as optim
+import utils
+from dataset import get_dataloader
+from torch.autograd import Variable
+# NOTE(review): this patch defines Generator in gan.py and adds no generator.py;
+# confirm that a generator module exists elsewhere, else import from gan.
+from generator import Generator
+
+G = Generator()
+G.load_state_dict(torch.load('./gen_10epoch_32f_tanh'))
+trainloader = get_dataloader()
+
+# islice stops quietly when the dataset holds fewer than 10 pairs.
+for test_ok, test_damaged in itertools.islice(trainloader, 10):
+    generated_image = G(Variable(test_damaged))
+    generated_image = generated_image.data
+    utils.plot_images(test_damaged, test_ok, generated_image)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..aff25fc
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,52 @@
+"""Weight initialisation and plotting helpers shared by the training scripts."""
+
+import torch.nn as nn
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def initialize_weights(net):
+    """Reset conv / transposed-conv / linear weights to normal(mean 0, std 0.02), biases to 0."""
+    for m in net.modules():
+        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
+            m.weight.data.normal_(0, 0.02)
+            # Layers built with bias=False have no bias tensor to reset.
+            if m.bias is not None:
+                m.bias.data.zero_()
+
+
+def plot_loss(loss_mean, loss_std):
+    """Show loss_mean with loss_std error bars against 1..len(loss_mean), y axis fixed to [0, 1]."""
+    positions = np.linspace(1, len(loss_mean), num=len(loss_mean))
+    plt.errorbar(positions, loss_mean, loss_std, linestyle='None', marker='^', capsize=3)
+    plt.ylim(ymin = 0.0, ymax = 1.0)
+    plt.show()
+
+
+def _to_uint8_image(batch):
+    """Min-max scale the first image of a batch to 0-255 and return it channel-last as uint8."""
+    img = batch[0]
+    low = img.min()
+    span = img.max() - low
+    if float(span) == 0.0:
+        # A constant image would otherwise divide by zero; show it as black.
+        scaled = img * 0
+    else:
+        scaled = (img - low) * 255 / span
+    return scaled.numpy().transpose(1, 2, 0).astype(np.uint8)
+
+
+def plot_images(input_image, target_image, generated_image):
+    """Show input, generated and target images (first item of each batch) side by side."""
+    # Assumes N x C x H x W batches.  figsize is (width, height) in inches:
+    # three images wide and one image tall, taking 100 pixels per inch.
+    fig_size = (input_image.size(3) * 3 / 100, input_image.size(2) / 100)
+
+    _, axes = plt.subplots(1, 3, figsize=fig_size)
+    imgs = [input_image, generated_image, target_image]
+    for ax, img in zip(axes.flatten(), imgs):
+        ax.axis('off')
+        ax.imshow(_to_uint8_image(img), cmap=None, aspect='equal')
+    plt.subplots_adjust(wspace=0, hspace=0)
+
+    plt.show()