trwoga i drżenie
This commit is contained in:
parent
a8979d81b0
commit
c498c700c4
209
gan.py
209
gan.py
@ -8,56 +8,58 @@ import os
|
||||
from dataset import get_dataloader
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def add_conv_layer(conv, input_dim, output_dim, kernel=4, activation=True,
                   batch_norm=True, dropout=False, stride=1, padding=0):
    """Build one convolutional stage and append it to ``conv``.

    The stage is an ``nn.Sequential`` containing, in this order:
    ``nn.Conv2d`` -> optional ``nn.BatchNorm2d`` -> optional
    ``nn.LeakyReLU(0.2)`` -> optional ``nn.Dropout2d``.

    Args:
        conv: List that collects the stages; it is mutated in place.
        input_dim: Number of input channels of the convolution.
        output_dim: Number of output channels of the convolution.
        kernel: Kernel size passed to ``nn.Conv2d``.
        activation: Append ``nn.LeakyReLU(0.2)`` when true.
        batch_norm: Append ``nn.BatchNorm2d(output_dim)`` when true.
        dropout: Append ``nn.Dropout2d()`` when true.
        stride: Stride passed to ``nn.Conv2d``; the default of 1 keeps the
            previous behaviour.
        padding: Padding passed to ``nn.Conv2d``; the default of 0 keeps the
            previous behaviour.

    Returns:
        None. The new stage is only reachable through ``conv``.
    """
    layer = [nn.Conv2d(input_dim, output_dim, kernel,
                       stride=stride, padding=padding)]
    if batch_norm:
        layer.append(nn.BatchNorm2d(output_dim))
    if activation:
        layer.append(nn.LeakyReLU(0.2))
    if dropout:
        layer.append(nn.Dropout2d())
    conv.append(nn.Sequential(*layer))
|
||||
|
||||
def add_deconv_layer(deconv, input_dim, output_dim, kernel=4, activation=True,
                     batch_norm=True, dropout=False, stride=1, padding=0):
    """Build one transposed-convolution stage and append it to ``deconv``.

    The stage is an ``nn.Sequential`` containing, in this order:
    optional ``nn.LeakyReLU(0.2)`` -> ``nn.ConvTranspose2d`` -> optional
    ``nn.BatchNorm2d`` -> optional ``nn.Dropout2d``. Note that, unlike
    ``add_conv_layer``, the activation is applied *before* the convolution.

    Args:
        deconv: List that collects the stages; it is mutated in place.
        input_dim: Number of input channels of the transposed convolution.
        output_dim: Number of output channels of the transposed convolution.
        kernel: Kernel size passed to ``nn.ConvTranspose2d``.
        activation: Prepend ``nn.LeakyReLU(0.2)`` when true.
        batch_norm: Append ``nn.BatchNorm2d(output_dim)`` when true.
        dropout: Append ``nn.Dropout2d()`` when true.
        stride: Stride passed to ``nn.ConvTranspose2d``; the default of 1
            keeps the previous behaviour.
        padding: Padding passed to ``nn.ConvTranspose2d``; the default of 0
            keeps the previous behaviour.

    Returns:
        None. The new stage is only reachable through ``deconv``.
    """
    layer = []
    if activation:
        layer.append(nn.LeakyReLU(0.2))
    layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel,
                                    stride=stride, padding=padding))
    if batch_norm:
        layer.append(nn.BatchNorm2d(output_dim))
    if dropout:
        layer.append(nn.Dropout2d())
    deconv.append(nn.Sequential(*layer))
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, input_width=178, input_height=218, input_dim=3, num_features=32, output_dim=3, lr=0.0002):
    """Build the encoder (``self.conv``) and decoder (``self.deconv``) stacks.

    Args:
        input_width: Stored on the instance; not otherwise used here.
        input_height: Stored on the instance; not otherwise used here.
        input_dim: Number of channels of the input image.
        num_features: Channel count of the first encoder stage; deeper
            stages use multiples of it (x2, x4, x8).
        output_dim: Number of channels produced by the last decoder stage.
        lr: Learning rate stored on the module; the training script hands
            it to the optimizer.
    """
    super(Generator, self).__init__()
    self.input_width = input_width
    self.input_height = input_height
    self.input_dim = input_dim
    self.num_features = num_features
    self.output_dim = output_dim
    self.lr = lr
    # NOTE(review): the purpose of `n` is not visible in this excerpt; kept as is.
    self.n = 4

    # Encoder: input_dim -> f -> 2f -> 4f -> 8f channels.
    encoder = []
    add_conv_layer(encoder, self.input_dim, self.num_features, activation=False, batch_norm=False)
    add_conv_layer(encoder, self.num_features, self.num_features * 2)
    add_conv_layer(encoder, self.num_features * 2, self.num_features * 4)
    add_conv_layer(encoder, self.num_features * 4, self.num_features * 8, batch_norm=False)

    # Decoder: 8f -> 4f -> 2f -> f -> output_dim channels.
    # The stages are collected in their own list. Previously they were
    # appended to the encoder list, which left `self.deconv` empty.
    # NOTE(review): this moves the decoder parameters from `conv.4..7` to
    # `deconv.0..3` in the state dict; checkpoints saved by the previous
    # layout need their keys remapped.
    decoder = []
    add_deconv_layer(decoder, self.num_features * 8, self.num_features * 4, dropout=True)
    add_deconv_layer(decoder, self.num_features * 4, self.num_features * 2)
    add_deconv_layer(decoder, self.num_features * 2, self.num_features)
    add_deconv_layer(decoder, self.num_features, self.output_dim, batch_norm=False)

    self.conv = nn.Sequential(*encoder)
    self.deconv = nn.Sequential(*decoder)

    utils.initialize_weights(self)
|
||||
|
||||
def add_conv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
|
||||
layer = []
|
||||
layer.append(nn.Conv2d(input_dim, output_dim, kernel))
|
||||
if batch_norm:
|
||||
layer.append(nn.BatchNorm2d(output_dim))
|
||||
if activation:
|
||||
layer.append(nn.LeakyReLU(0.2))
|
||||
if dropout:
|
||||
layer.append(nn.Dropout2d())
|
||||
self.conv.append(nn.Sequential(*layer))
|
||||
|
||||
def add_deconv_layer(self, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
|
||||
layer = []
|
||||
if activation:
|
||||
layer.append(nn.LeakyReLU(0.2))
|
||||
layer.append(nn.ConvTranspose2d(input_dim, output_dim, kernel))
|
||||
if batch_norm:
|
||||
layer.append(nn.BatchNorm2d(output_dim))
|
||||
if dropout:
|
||||
layer.append(nn.Dropout2d())
|
||||
self.deconv.append(nn.Sequential(*layer))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
@ -78,8 +80,42 @@ class Generator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing values in [0, 1].

    Five stages built with ``add_conv_layer`` map ``input_dim`` channels to
    ``output_dim`` channels (f -> 2f -> 4f -> 8f -> output_dim, where f is
    ``num_features``); ``forward`` then applies a sigmoid.
    """

    def __init__(self, input_width=178, input_height=218, input_dim=1, num_features=32, output_dim=3, lr=0.0002):
        """Assemble the convolutional stack and initialise its weights.

        Args:
            input_width: Stored on the instance; not otherwise used here.
            input_height: Stored on the instance; not otherwise used here.
            input_dim: Number of channels of the input image.
            num_features: Channel count of the first stage; deeper stages
                use multiples of it (x2, x4, x8).
            output_dim: Number of channels of the last stage.
            lr: Learning rate stored on the module; the training script
                hands it to the optimizer.

        NOTE(review): the defaults ``input_dim=1`` / ``output_dim=3`` look
        swapped relative to the training script, which passes
        ``input_dim=3, output_dim=1`` explicitly. They are left unchanged
        to keep the signature stable.
        """
        super(Discriminator, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_dim = input_dim
        self.num_features = num_features
        self.output_dim = output_dim
        self.lr = lr

        stages = []
        add_conv_layer(stages, self.input_dim, self.num_features, activation=False, batch_norm=False)
        add_conv_layer(stages, self.num_features, self.num_features * 2)
        add_conv_layer(stages, self.num_features * 2, self.num_features * 4)
        add_conv_layer(stages, self.num_features * 4, self.num_features * 8)
        # NOTE(review): this last stage keeps the default LeakyReLU, so the
        # sigmoid in forward() sees LeakyReLU outputs - confirm intended.
        add_conv_layer(stages, self.num_features * 8, self.output_dim, batch_norm=False)
        # A final nn.Linear(num_features * 8, output_dim) head was tried
        # and is currently disabled.

        self.conv = nn.Sequential(*stages)

        utils.initialize_weights(self)

    def forward(self, x):
        """Run ``x`` through the conv stack and squash the result into [0, 1]."""
        x = self.conv(x)
        # The output is fed to nn.BCELoss, which only accepts inputs in [0, 1].
        # torch.sigmoid avoids building a new nn.Sigmoid module on every call.
        return torch.sigmoid(x)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
EPOCHS = 3
|
||||
EPOCHS = 5
|
||||
CUDA = False
|
||||
NUM_FEATURES = 16
|
||||
LAMBDA = 100
|
||||
|
||||
output_path = "output"
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
@ -90,42 +126,109 @@ if __name__ == '__main__':
|
||||
if not os.path.exists(os.path.dirname(model_path)):
|
||||
os.makedirs(os.path.dirname(model_path))
|
||||
|
||||
G = Generator()
|
||||
print(G)
|
||||
torch.manual_seed(1)
|
||||
|
||||
BCE_loss = nn.BCELoss()
|
||||
L1_loss = nn.L1Loss()
|
||||
L2_loss = nn.KLDivLoss()
|
||||
G = Generator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=3, lr=0.0002)
|
||||
D = Discriminator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=1, lr=0.0002)
|
||||
print("Generator:\n", G)
|
||||
print("Discriminator:\n", D)
|
||||
|
||||
BCE_loss = nn.BCELoss() # nie działa, przyjmuje inputy tylko z zakresu <0; 1>
|
||||
L1_loss = nn.L1Loss() # działa, ale rozmyte
|
||||
L2_loss = nn.KLDivLoss() # działa, ale wolne
|
||||
MSE_loss = nn.MSELoss() # działa, ale jakieś dziwne przekolorwania (chociaż zdają się zanikać)
|
||||
|
||||
G_optimizer = optim.Adam(G.parameters(), G.lr)
|
||||
D_optimizer = optim.Adam(D.parameters(), D.lr)
|
||||
|
||||
if CUDA: # u mnie nie działa
|
||||
G.cuda()
|
||||
D.cuda()
|
||||
BCE_loss.cuda()
|
||||
L1_loss.cuda()
|
||||
|
||||
trainloader = get_dataloader()
|
||||
test_ok, test_damaged = trainloader.__iter__().__next__()
|
||||
#loss_mean, loss_std = [], []
|
||||
|
||||
def get_real_estimate(size):
|
||||
return Variable(torch.ones(size))
|
||||
|
||||
def get_fake_estimate(size):
|
||||
return Variable(torch.zeros(size))
|
||||
|
||||
# https://github.com/soumith/ganhacks
|
||||
# Label smoothing: duże rozmycie spowalnia zbieganie D_loss do zera.
|
||||
def get_smooth_real_estimate(size):
|
||||
return Variable(torch.ones(size)) * (0.7 + torch.rand(1)[0]*0.5)
|
||||
|
||||
def get_smooth_fake_estimate(size):
|
||||
return Variable(torch.ones(size)) * (torch.rand(1)[0]*0.3)
|
||||
|
||||
for epoch in range(1, EPOCHS+1, 1):
|
||||
losses = []
|
||||
for i, (ok_image, damaged_image) in enumerate(trainloader):
|
||||
ok_image, damaged_image = Variable(ok_image), Variable(damaged_image)
|
||||
print(ok_image.size())
|
||||
|
||||
# TODO: jakieś if'y wykrywające wyglebywanie się GAN-a i zezwalające
|
||||
# naukę tylko D albo G ?
|
||||
|
||||
# nauka dyskryminacji
|
||||
|
||||
D_optimizer.zero_grad()
|
||||
print(1)
|
||||
D_real_estimate = D(ok_image)
|
||||
real_estimate = get_smooth_real_estimate(D_real_estimate.size())
|
||||
print (real_estimate.size())
|
||||
D_real_loss = BCE_loss(D_real_estimate, real_estimate)
|
||||
#D_real_loss.backward()
|
||||
|
||||
print(2)
|
||||
generated_image = G(damaged_image).detach()
|
||||
D_fake_estimate = D(generated_image)
|
||||
fake_estimate = get_smooth_fake_estimate(D_fake_estimate.size())
|
||||
D_fake_loss = BCE_loss(D_fake_estimate, fake_estimate)
|
||||
#D_fake_loss.backward()
|
||||
print(3)
|
||||
D_loss = D_real_loss + D_fake_loss
|
||||
D_loss.backward()
|
||||
D_optimizer.step()
|
||||
print(4)
|
||||
|
||||
# nauka generacji
|
||||
G_optimizer.zero_grad()
|
||||
generated_image = G(damaged_image)
|
||||
loss = L1_loss(generated_image, ok_image)
|
||||
loss.backward()
|
||||
losses.append(loss.data[0])
|
||||
estimate = D(generated_image)
|
||||
G_fake_loss = BCE_loss(estimate, get_real_estimate(estimate.size()))
|
||||
G_L1_loss = LAMBDA * L1_loss(generated_image, ok_image)
|
||||
# ogólnie im większa lambda, tym zachowuje się bardziej jak DCNN i
|
||||
# bardziej zachowuje ogólną kolorystykę
|
||||
G_loss = G_fake_loss + G_L1_loss
|
||||
G_loss.backward()
|
||||
G_optimizer.step()
|
||||
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
print(5)
|
||||
#losses.append((D_loss.data[0], G_loss.data[0]))
|
||||
#loss = L1_loss(generated_image, ok_image)
|
||||
#loss = BCE_loss(generated_image, ok_image)
|
||||
#loss = used_loss(generated_image, ok_image)
|
||||
#loss.backward()
|
||||
#losses.append(loss.data[0])
|
||||
#G_optimizer.step()
|
||||
print('Epoch: {} [{}/{} ({:.0f}%)]\tD_loss: {:.6f}\tG_loss: {:.6f}'.format(
|
||||
epoch, (i + 1) * len(ok_image), len(trainloader.dataset),
|
||||
100. * i / len(trainloader), loss.data[0]))
|
||||
100. * i / len(trainloader), D_loss.data[0], G_loss.data[0]))
|
||||
#loss_mean.append(np.mean(losses))
|
||||
#loss_std.append(np.std(losses))
|
||||
|
||||
with open(loss_path.format(epoch), 'w') as csvfile:
|
||||
fieldnames = ['num_image', 'loss_l1']
|
||||
fieldnames = ['num_image', 'd_loss', 'g_loss']
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for i, loss_l1 in enumerate(losses):
|
||||
writer.writerow({'num_image': i, 'loss_l1': loss_l1})
|
||||
for i, (d_loss, g_loss) in enumerate(losses):
|
||||
writer.writerow({'num_image': i, 'd_loss': d_loss, 'g_loss': g_loss})
|
||||
|
||||
if epoch == 1 or epoch % 5 == 0:
|
||||
if True or epoch == 1 or epoch % 5 == 0:
|
||||
torch.save(G.state_dict(), model_path.format(epoch))
|
||||
|
||||
generated_image = G(Variable(test_damaged))
|
||||
|
Loading…
Reference in New Issue
Block a user