"""Utility helpers: weight initialization for torch modules and matplotlib
plots of training losses and image triplets."""
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn


def initialize_weights(net, std=0.02):
    """Initialize the conv / transposed-conv / linear layers of *net* in place.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear`` reachable via
    ``net.modules()`` (which includes *net* itself) gets its weight drawn from
    a normal distribution with mean 0 and standard deviation *std*, and its
    bias set to zero. Layers created with ``bias=False`` have ``bias is None``
    and only their weight is initialized.

    Args:
        net: The module whose sub-modules are initialized.
        std: Standard deviation of the normal weight initialization.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            # nn.init functions run under no_grad, so there is no need to
            # reach into `.data` directly.
            nn.init.normal_(m.weight, mean=0.0, std=std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


def plot_loss(loss_mean, loss_std, ylim=(0.0, 1.0)):
    """Plot mean loss values with symmetric std error bars and show the figure.

    Points are placed at x = 1, 2, ..., len(loss_mean) and drawn as
    unconnected triangle markers with capped error bars.

    Args:
        loss_mean: Sequence of mean loss values, one per point.
        loss_std: Matching standard deviations, passed as ``yerr``.
        ylim: ``(bottom, top)`` limits for the y axis, or ``None`` to let
            matplotlib autoscale. The default ``(0.0, 1.0)`` keeps the
            original behavior; values outside that range are clipped from view.
    """
    xs = np.arange(1, len(loss_mean) + 1)
    plt.errorbar(xs, loss_mean, loss_std, linestyle='None', marker='^', capsize=3)
    if ylim is not None:
        # Use bottom/top: the old ymin/ymax keywords are not accepted by
        # current Matplotlib.
        plt.ylim(bottom=ylim[0], top=ylim[1])
    plt.show()


def normalize_image(img):
    """Min-max scale the first image of a batch to an ``(H, W, C)`` uint8 array.

    Args:
        img: Batched image tensor; ``img[0]`` must be 3-D and is treated as
            ``(C, H, W)`` (the layout implied by the final transpose).

    Returns:
        ``numpy.ndarray`` of dtype uint8 and shape ``(H, W, C)``. The minimum
        maps to 0 and the maximum to 255 (fractional values are truncated).
        A constant image has no range to stretch and becomes all zeros
        instead of NaN-derived garbage.
    """
    # detach/cpu so tensors that require grad or live on a GPU can be
    # converted; float64 avoids integer overflow in `* 255` for integer inputs.
    arr = img[0].detach().cpu().numpy().astype(np.float64)
    lo = arr.min()
    span = arr.max() - lo
    if span > 0:
        arr = (arr - lo) * 255 / span
    else:
        arr = np.zeros_like(arr)
    return arr.transpose(1, 2, 0).astype(np.uint8)


def plot_images(input_image, target_image, generated_image):
    """Show input, generated and target images side by side in one row.

    Only the first sample of each batch is shown, in the order
    input | generated | target. Each panel is min-max normalized on its own
    (via ``normalize_image``), so brightness is not comparable across panels.

    Args:
        input_image: Batched tensor, assumed ``(N, C, H, W)``; its H and W set
            the figure size.
        target_image: Batched tensor shown in the right-most panel.
        generated_image: Batched tensor shown in the middle panel.
    """
    height, width = input_image.size(2), input_image.size(3)
    # figsize is (width, height) in inches; dividing pixel counts by 100 makes
    # each panel roughly native size on a 100-dpi figure.
    fig_size = (width * 3 / 100, height / 100)
    _, axes = plt.subplots(1, 3, figsize=fig_size)
    imgs = [input_image, generated_image, target_image]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        pixels = normalize_image(img)
        cmap = None
        if pixels.shape[2] == 1:
            # imshow rejects (H, W, 1) arrays; show single-channel images
            # as 2-D grayscale instead.
            pixels = pixels[:, :, 0]
            cmap = 'gray'
        ax.imshow(pixels, cmap=cmap, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()