import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import utils
import csv
import os
from dataset import get_dataloader
from torch.autograd import Variable


def add_conv_layer(conv, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
    """Append one encoder stage, wrapped in ``nn.Sequential``, to the list ``conv``.

    The stage is ``Conv2d`` followed, in this order, by the optional
    ``BatchNorm2d``, ``LeakyReLU(0.2)`` and ``Dropout2d`` modules. The
    convolution uses PyTorch's default stride and padding.

    Args:
        conv: List that the new stage is appended to (mutated in place).
        input_dim: Input channel count of the convolution.
        output_dim: Output channel count of the convolution.
        kernel: Convolution kernel size.
        activation: Whether to add a ``LeakyReLU(0.2)``.
        batch_norm: Whether to add a ``BatchNorm2d``.
        dropout: Whether to add a ``Dropout2d``.
    """
    stage = [nn.Conv2d(input_dim, output_dim, kernel)]
    # (enabled?, factory) pairs, in the order the modules must appear.
    optional_parts = (
        (batch_norm, lambda: nn.BatchNorm2d(output_dim)),
        (activation, lambda: nn.LeakyReLU(0.2)),
        (dropout, nn.Dropout2d),
    )
    stage.extend(make() for enabled, make in optional_parts if enabled)
    conv.append(nn.Sequential(*stage))
            
def add_deconv_layer(deconv, input_dim, output_dim, kernel=4, activation=True, batch_norm=True, dropout=False):
    """Append one decoder stage, wrapped in ``nn.Sequential``, to the list ``deconv``.

    Unlike the encoder stage, the optional ``LeakyReLU(0.2)`` comes *before*
    the ``ConvTranspose2d``. The optional ``BatchNorm2d`` and ``Dropout2d``
    follow it. The transposed convolution uses PyTorch's default stride and
    padding.

    Args:
        deconv: List that the new stage is appended to (mutated in place).
        input_dim: Input channel count of the transposed convolution.
        output_dim: Output channel count of the transposed convolution.
        kernel: Transposed-convolution kernel size.
        activation: Whether to prepend a ``LeakyReLU(0.2)``.
        batch_norm: Whether to add a ``BatchNorm2d``.
        dropout: Whether to add a ``Dropout2d``.
    """
    head = [nn.LeakyReLU(0.2)] if activation else []
    core = [nn.ConvTranspose2d(input_dim, output_dim, kernel)]
    tail = [nn.BatchNorm2d(output_dim)] if batch_norm else []
    if dropout:
        tail.append(nn.Dropout2d())
    deconv.append(nn.Sequential(*head, *core, *tail))
    

class Generator(nn.Module):
    """Encoder-decoder image generator.

    The encoder is four ``Conv2d`` stages with channel widths
    ``input_dim -> nf -> 2nf -> 4nf -> 8nf``, where ``nf = num_features``.
    The decoder is four ``ConvTranspose2d`` stages mirroring it back down to
    ``output_dim`` channels. The result is squashed with ``tanh``.

    Args:
        input_width: Stored on the instance; not used to build the layers.
        input_height: Stored on the instance; not used to build the layers.
        input_dim: Number of channels in the input image.
        num_features: Base channel width of the network.
        output_dim: Number of channels in the generated image.
        lr: Learning rate, stored so the training script can build its optimizer.
    """

    def __init__(self, input_width=178, input_height=218, input_dim=3, num_features=32, output_dim=3, lr=0.0002):
        super(Generator, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_dim = input_dim
        self.num_features = num_features
        self.output_dim = output_dim
        self.lr = lr
        self.n = 4  # unused in this class; kept in case external code reads it

        nf = num_features

        # Encoder. The first stage is a bare convolution (no BN, no activation);
        # the last stage has no batch norm.
        conv = []
        add_conv_layer(conv, input_dim, nf, activation=False, batch_norm=False)
        add_conv_layer(conv, nf, nf * 2)
        add_conv_layer(conv, nf * 2, nf * 4)
        add_conv_layer(conv, nf * 4, nf * 8, batch_norm=False)

        # Decoder. These stages must go into their own list. Earlier versions
        # appended them to the encoder list by mistake, which left
        # ``self.deconv`` empty. The forward result was the same, but
        # checkpoints saved by those versions store the decoder weights under
        # ``conv.4``-``conv.7`` instead of ``deconv.0``-``deconv.3``.
        deconv = []
        add_deconv_layer(deconv, nf * 8, nf * 4, dropout=True)
        add_deconv_layer(deconv, nf * 4, nf * 2)
        add_deconv_layer(deconv, nf * 2, nf)
        add_deconv_layer(deconv, nf, output_dim, batch_norm=False)

        self.conv = nn.Sequential(*conv)
        self.deconv = nn.Sequential(*deconv)

        utils.initialize_weights(self)

    def forward(self, x):
        """Run the encoder then the decoder and return ``tanh`` of the result (values in [-1, 1])."""
        x = self.conv(x)
        x = self.deconv(x)
        # TODO: U-Net style skip connections (concatenating encoder outputs
        # into the decoder) were attempted here but did not work.
        # Tanh rather than Sigmoid: Sigmoid produced odd color artifacts.
        return torch.tanh(x)


class Discriminator(nn.Module):
    """Convolutional discriminator whose output map is squashed into [0, 1].

    Five ``Conv2d`` stages map ``input_dim`` channels through
    ``num_features * (1, 2, 4, 8)`` channels to ``output_dim`` channels. There
    is no flatten or linear head, so the output stays a feature map. A sigmoid
    is applied to every element so the result can be fed to ``nn.BCELoss``.

    Args:
        input_width: Stored on the instance; not used to build the layers.
        input_height: Stored on the instance; not used to build the layers.
        input_dim: Number of channels in the input image.
        num_features: Base channel width of the network.
        output_dim: Number of channels in the output map.
        lr: Learning rate, stored so the training script can build its optimizer.
    """

    def __init__(self, input_width=178, input_height=218, input_dim=1, num_features=32, output_dim=3, lr=0.0002):
        super(Discriminator, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_dim = input_dim
        self.num_features = num_features
        self.output_dim = output_dim
        self.lr = lr

        # Channel widths along the network: the input, then 1x/2x/4x/8x the base width.
        widths = [input_dim] + [num_features * factor for factor in (1, 2, 4, 8)]
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # The first stage is a bare convolution; the rest use BN + LeakyReLU.
            is_first = idx == 0
            add_conv_layer(stages, c_in, c_out, activation=not is_first, batch_norm=not is_first)
        # Final projection to output_dim: activation but no batch norm.
        # (A Linear head num_features*8 -> output_dim was considered instead.)
        add_conv_layer(stages, widths[-1], output_dim, batch_norm=False)
        self.conv = nn.Sequential(*stages)

        utils.initialize_weights(self)

    def forward(self, x):
        """Return the discriminator output for ``x``; the sigmoid keeps it in [0, 1] for BCE."""
        return torch.sigmoid(self.conv(x))
    
    

if __name__ == '__main__':
    EPOCHS = 5
    CUDA = False
    NUM_FEATURES = 16
    LAMBDA = 100  # weight of the L1 reconstruction term in the generator loss
    SAVE_EVERY = 1  # save a checkpoint and a preview every N epochs (and after epoch 1)

    # Output layout: output/loss/loss_epoch_{n}.csv and output/model/gan_epoch_{n}.pt
    output_path = "output"
    loss_path = os.path.join(output_path, "loss", "loss_epoch_{}.csv")
    model_path = os.path.join(output_path, "model", "gan_epoch_{}.pt")
    for directory in (output_path, os.path.dirname(loss_path), os.path.dirname(model_path)):
        os.makedirs(directory, exist_ok=True)

    torch.manual_seed(1)

    # All tensors (models, inputs, targets) live on this device. Previously
    # only the models were moved, so CUDA=True failed on the CPU inputs.
    device = torch.device("cuda" if CUDA else "cpu")

    G = Generator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=3, lr=0.0002)
    D = Discriminator(input_width=178, input_height=218, input_dim=3, num_features=NUM_FEATURES, output_dim=1, lr=0.0002)
    print("Generator:\n", G)
    print("Discriminator:\n", D)
    G.to(device)
    D.to(device)

    BCE_loss = nn.BCELoss()  # needs inputs in [0, 1]; D ends with a sigmoid, so this holds
    L1_loss = nn.L1Loss()  # works, but results are blurry
    L2_loss = nn.KLDivLoss()  # works, but slow (note: this is KL divergence, not an L2 loss)
    MSE_loss = nn.MSELoss()  # works, but gives odd color shifts (which seem to fade over time)

    G_optimizer = optim.Adam(G.parameters(), G.lr)
    D_optimizer = optim.Adam(D.parameters(), D.lr)

    trainloader = get_dataloader()
    # A fixed batch used to preview the generator's progress after each saved epoch.
    test_ok, test_damaged = next(iter(trainloader))

    def get_real_estimate(size):
        return torch.ones(size, device=device)

    def get_fake_estimate(size):
        return torch.zeros(size, device=device)

    # https://github.com/soumith/ganhacks
    # Label smoothing: strong smoothing slows D_loss from collapsing to zero.
    # One random factor is drawn per batch: real targets in [0.7, 1.2), fake in [0, 0.3).
    def get_smooth_real_estimate(size):
        return torch.ones(size, device=device) * (0.7 + torch.rand(1).item() * 0.5)

    def get_smooth_fake_estimate(size):
        return torch.ones(size, device=device) * (torch.rand(1).item() * 0.3)

    for epoch in range(1, EPOCHS + 1):
        losses = []
        for i, (ok_image, damaged_image) in enumerate(trainloader):
            ok_image = ok_image.to(device)
            damaged_image = damaged_image.to(device)

            # TODO: detect the GAN collapsing and, when it does, train only D or only G?

            # --- Discriminator step ---
            D_optimizer.zero_grad()
            D_real_estimate = D(ok_image)
            real_estimate = get_smooth_real_estimate(D_real_estimate.size())
            D_real_loss = BCE_loss(D_real_estimate, real_estimate)

            # detach() so the discriminator update does not backpropagate into G.
            generated_image = G(damaged_image).detach()
            D_fake_estimate = D(generated_image)
            fake_estimate = get_smooth_fake_estimate(D_fake_estimate.size())
            D_fake_loss = BCE_loss(D_fake_estimate, fake_estimate)

            D_loss = D_real_loss + D_fake_loss
            D_loss.backward()
            D_optimizer.step()

            # --- Generator step ---
            G_optimizer.zero_grad()
            generated_image = G(damaged_image)
            estimate = D(generated_image)
            G_fake_loss = BCE_loss(estimate, get_real_estimate(estimate.size()))
            # The larger LAMBDA, the more G behaves like a plain CNN and the
            # better it preserves the overall color palette.
            G_L1_loss = LAMBDA * L1_loss(generated_image, ok_image)
            G_loss = G_fake_loss + G_L1_loss
            G_loss.backward()
            G_optimizer.step()

            # .item() instead of .data[0]: indexing a 0-dim loss tensor fails on modern PyTorch.
            d_loss_value, g_loss_value = D_loss.item(), G_loss.item()
            losses.append((d_loss_value, g_loss_value))
            print('Epoch: {} [{}/{} ({:.0f}%)]\tD_loss: {:.6f}\tG_loss: {:.6f}'.format(
                    epoch, (i + 1) * len(ok_image), len(trainloader.dataset),
                           100. * i / len(trainloader), d_loss_value, g_loss_value))

        # newline='' as required by the csv module to avoid blank rows on Windows.
        with open(loss_path.format(epoch), 'w', newline='') as csvfile:
            fieldnames = ['num_image', 'd_loss', 'g_loss']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for i, (d_loss, g_loss) in enumerate(losses):
                writer.writerow({'num_image': i, 'd_loss': d_loss, 'g_loss': g_loss})

        if epoch == 1 or epoch % SAVE_EVERY == 0:
            torch.save(G.state_dict(), model_path.format(epoch))

            # The preview needs no autograd graph. G stays in train mode, as before.
            with torch.no_grad():
                generated_image = G(test_damaged.to(device)).cpu()
            utils.plot_images(test_damaged, test_ok, generated_image)