piękno i dobro
This commit is contained in:
parent
96f98c38a8
commit
e38d3c31d5
34
dataset.py
Normal file
34
dataset.py
Normal file
@ -0,0 +1,34 @@
|
||||
|
||||
import os
|
||||
from torch.utils import data
|
||||
from torchvision import transforms, utils
|
||||
from PIL import Image
|
||||
|
||||
class dataset(data.Dataset):
|
||||
def __init__(self, root_path, ok_path = "ok_dataset", damaged_path = "damaged_dataset"):
|
||||
super(dataset, self).__init__()
|
||||
self.ok_images_path = os.path.join(root_path, ok_path)
|
||||
self.damaged_images_path = os.path.join(root_path, damaged_path)
|
||||
self.ok_images_paths = [x for x in os.listdir(self.ok_images_path)]
|
||||
self.damaged_images_paths = [x for x in os.listdir(self.damaged_images_path)]
|
||||
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
||||
|
||||
|
||||
def __getitem__(self, i):
|
||||
ok_image = Image.open(os.path.join(self.ok_images_path, self.ok_images_paths[i])).convert("RGB")
|
||||
ok_image = self.transform(ok_image)
|
||||
damaged_image = Image.open(os.path.join(self.damaged_images_path, self.damaged_images_paths[i])).convert("RGB")
|
||||
damaged_image = self.transform(damaged_image)
|
||||
|
||||
return ok_image, damaged_image
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ok_images_paths)
|
||||
|
||||
def get_dataloader():
|
||||
trainset = dataset(".")
|
||||
trainloader = data.DataLoader(trainset)
|
||||
return trainloader
|
||||
|
||||
|
||||
|
137
gan.py
Normal file
137
gan.py
Normal file
@ -0,0 +1,137 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import utils
|
||||
import csv
|
||||
import os
|
||||
from dataset import get_dataloader
|
||||
from torch.autograd import Variable
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self):
|
||||
super(Generator, self).__init__()
|
||||
self.input_width = 178
|
||||
self.input_height = 218
|
||||
self.input_dim = 3
|
||||
self.num_features = 32
|
||||
self.output_dim = 3
|
||||
self.lr = 0.0002
|
||||
self.n = 4
|
||||
|
||||
self.conv = []
|
||||
self.add_conv_layer(self.input_dim, self.num_features, activation=False, batch_norm = False)
|
||||
self.add_conv_layer(self.num_features, self.num_features * 2)
|
||||
self.add_conv_layer(self.num_features * 2, self.num_features * 4)
|
||||
self.add_conv_layer(self.num_features * 4, self.num_features * 8, batch_norm = False)
|
||||
|
||||
self.deconv = []
|
||||
self.add_deconv_layer(self.num_features * 8, self.num_features * 4, dropout = True)
|
||||
self.add_deconv_layer(self.num_features * 4, self.num_features * 2)
|
||||
self.add_deconv_layer(self.num_features * 2, self.num_features)
|
||||
self.add_deconv_layer(self.num_features, self.output_dim, batch_norm = False)
|
||||
|
||||
|
||||
self.conv = nn.Sequential(*self.conv)
|
||||
self.deconv = nn.Sequential(*self.deconv)
|
||||
|
||||
utils.initialize_weights(self)
|
||||
|
||||
def add_conv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
|
||||
layer = []
|
||||
layer.append(nn.Conv2d(input_dim, output_dim, kernel))
|
||||
if batch_norm:
|
||||
layer.append(nn.BatchNorm2d(output_dim))
|
||||
if activation:
|
||||
layer.append(nn.LeakyReLU(0.2))
|
||||
if dropout:
|
||||
layer.append(nn.Dropout2d())
|
||||
self.conv.append(nn.Sequential(*layer))
|
||||
|
||||
def add_deconv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
|
||||
layer = []
|
||||
if activation:
|
||||
layer.append(nn.LeakyReLU(0.2))
|
||||
layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
|
||||
if batch_norm:
|
||||
layer.append(nn.BatchNorm2d(output_dim))
|
||||
if dropout:
|
||||
layer.append(nn.Dropout2d())
|
||||
self.deconv.append(nn.Sequential(*layer))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.deconv(x)
|
||||
""" # To miał← być skip connections, ale nie działa.
|
||||
conv = [self.conv[0](x)]
|
||||
for i in range(1, len(self.conv)):
|
||||
print(i)
|
||||
conv.append(self.conv[i](conv[-1]))
|
||||
deconv = [self.deconv[0](conv[-1])]
|
||||
for i in range(1, len(self.deconv)-1):
|
||||
print(i)
|
||||
deconv.append(self.deconv[i](deconv[-1]))
|
||||
#deconv[-1] = torch.cat((deconv[-1], conv[-1-i]), 1)
|
||||
deconv.append(self.deconv[-1](deconv[-1]))
|
||||
"""
|
||||
x = nn.Tanh()(x) # może i Sigmoid, ale wtedy są dziwne przekolorowana
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
EPOCHS = 3
|
||||
output_path = "output"
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
loss_path = os.path.join(output_path, "loss", "loss_epoch_{}.csv")
|
||||
if not os.path.exists(os.path.dirname(loss_path)):
|
||||
os.makedirs(os.path.dirname(loss_path))
|
||||
model_path = os.path.join(output_path, "model", "gan_epoch_{}.pt")
|
||||
if not os.path.exists(os.path.dirname(model_path)):
|
||||
os.makedirs(os.path.dirname(model_path))
|
||||
|
||||
G = Generator()
|
||||
print(G)
|
||||
|
||||
BCE_loss = nn.BCELoss()
|
||||
L1_loss = nn.L1Loss()
|
||||
L2_loss = nn.KLDivLoss()
|
||||
|
||||
G_optimizer = optim.Adam(G.parameters(), G.lr)
|
||||
trainloader = get_dataloader()
|
||||
test_ok, test_damaged = trainloader.__iter__().__next__()
|
||||
#loss_mean, loss_std = [], []
|
||||
|
||||
for epoch in range(1, EPOCHS+1, 1):
|
||||
losses = []
|
||||
for i, (ok_image, damaged_image) in enumerate(trainloader):
|
||||
ok_image, damaged_image = Variable(ok_image), Variable(damaged_image)
|
||||
G_optimizer.zero_grad()
|
||||
generated_image = G(damaged_image)
|
||||
loss = L1_loss(generated_image, ok_image)
|
||||
loss.backward()
|
||||
losses.append(loss.data[0])
|
||||
G_optimizer.step()
|
||||
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, (i + 1) * len(ok_image), len(trainloader.dataset),
|
||||
100. * i / len(trainloader), loss.data[0]))
|
||||
#loss_mean.append(np.mean(losses))
|
||||
#loss_std.append(np.std(losses))
|
||||
|
||||
with open(loss_path.format(epoch), 'w') as csvfile:
|
||||
fieldnames = ['num_image', 'loss_l1']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for i, loss_l1 in enumerate(losses):
|
||||
writer.writerow({'num_image': i, 'loss_l1': loss_l1})
|
||||
|
||||
if epoch == 1 or epoch % 5 == 0:
|
||||
torch.save(G.state_dict(), model_path.format(epoch))
|
||||
|
||||
generated_image = G(Variable(test_damaged))
|
||||
generated_image = generated_image.data
|
||||
utils.plot_images(test_damaged, test_ok, generated_image)
|
||||
|
||||
|
||||
#utils.plot_loss(loss_mean, loss_std)
|
||||
|
33
prepare_dataset.py
Executable file
33
prepare_dataset.py
Executable file
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from subprocess import call
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Generates dataset")
|
||||
parser.add_argument('dataset', help='Dataset directory path', type=str)
|
||||
parser.add_argument('-d', help='Decals count', type=int, default=8, dest="decals")
|
||||
parser.add_argument('-n', help='Noise level', type=float, default=.4, dest="noise")
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_path = args.dataset
|
||||
ok_path = "ok_" + dataset_path
|
||||
damaged_path = "damaged_" + dataset_path
|
||||
|
||||
if not os.path.exists(dataset_path):
|
||||
raise RuntimeError("fuk")
|
||||
|
||||
if not os.path.exists(damaged_path):
|
||||
os.makedirs(damaged_path)
|
||||
if not os.path.exists(ok_path):
|
||||
os.makedirs(ok_path);
|
||||
|
||||
dataset_images_paths = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path)]
|
||||
ok_images_paths = [os.path.join(ok_path, f) for f in os.listdir(dataset_path)]
|
||||
damaged_images_paths = [os.path.join(damaged_path, f) for f in os.listdir(dataset_path)]
|
||||
|
||||
for dataset_image_path, ok_image_path, damaged_image_path in zip(dataset_images_paths, ok_images_paths, damaged_images_paths):
|
||||
call(['./uss-mariusz.py', dataset_image_path, ok_image_path, damaged_image_path, '-d', str(args.decals), '-n', str(args.noise)])
|
||||
|
19
test_plot.py
Normal file
19
test_plot.py
Normal file
@ -0,0 +1,19 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import utils
|
||||
from dataset import get_dataloader
|
||||
from torch.autograd import Variable
|
||||
from generator import Generator
|
||||
|
||||
G = Generator()
|
||||
G.load_state_dict(torch.load('./gen_10epoch_32f_tanh'))
|
||||
trainloader = get_dataloader()
|
||||
it = trainloader.__iter__()
|
||||
|
||||
for i in range(10):
|
||||
test_ok, test_damaged = it.__next__()
|
||||
generated_image = G(Variable(test_damaged))
|
||||
generated_image = generated_image.data
|
||||
utils.plot_images(test_damaged, test_ok, generated_image)
|
34
utils.py
Normal file
34
utils.py
Normal file
@ -0,0 +1,34 @@
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def initialize_weights(net):
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
m.weight.data.normal_(0, 0.02)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
m.weight.data.normal_(0, 0.02)
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
m.weight.data.normal_(0, 0.02)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def plot_loss(loss_mean, loss_std):
|
||||
plt.errorbar(np.linspace(1, len(loss_mean), num=len(loss_mean)), loss_mean, loss_std, linestyle='None', marker='^', capsize=3)
|
||||
plt.ylim(ymin = 0.0, ymax = 1.0)
|
||||
plt.show()
|
||||
|
||||
def plot_images(input_image, target_image, generated_image):
|
||||
fig_size = (input_image.size(2) * 3 / 100, input_image.size(3)/100)
|
||||
|
||||
_, axes = plt.subplots(1, 3, figsize=fig_size)
|
||||
imgs = [input_image, generated_image, target_image]
|
||||
for ax, img in zip(axes.flatten(), imgs):
|
||||
ax.axis('off')
|
||||
# Scale to 0-255
|
||||
img = (((img[0] - img[0].min()) * 255) / (img[0].max() - img[0].min())).numpy().transpose(1, 2, 0).astype(np.uint8)
|
||||
ax.imshow(img, cmap=None, aspect='equal')
|
||||
plt.subplots_adjust(wspace=0, hspace=0)
|
||||
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user