ProjektSI/utils.py
2018-05-20 09:23:22 +02:00

39 lines
1.4 KiB
Python

import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
def initialize_weights(net, mean=0.0, std=0.02):
    """Initialise the Conv2d, ConvTranspose2d and Linear layers of *net* in place.

    Every such layer found by ``net.modules()`` (recursively) gets its weight
    drawn from a normal distribution N(mean, std**2) and its bias set to zero.
    Other module types are left untouched.

    Args:
        net: An ``nn.Module`` whose sub-modules are initialised.
        mean: Mean of the normal distribution for weights (default 0.0).
        std: Standard deviation of the normal distribution for weights
            (default 0.02, the value previously hard-coded here).
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            # nn.init.* run under torch.no_grad(), so autograd does not track this.
            nn.init.normal_(m.weight, mean, std)
            # Layers constructed with bias=False have m.bias == None.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
def plot_loss(loss_mean, loss_std, ylim=(0.0, 1.0)):
    """Plot mean loss values with standard-deviation error bars and show the figure.

    Point i (0-based) of ``loss_mean`` is drawn at x = i + 1, with a
    symmetric error bar of ``loss_std[i]``. Markers only; no connecting line.

    Args:
        loss_mean: Sequence of mean loss values, one per x position.
        loss_std: Sequence of error-bar half-heights, same length as
            ``loss_mean``.
        ylim: ``(bottom, top)`` limits for the y axis. Defaults to
            ``(0.0, 1.0)`` (the previous fixed range). Pass ``None`` to let
            matplotlib autoscale, e.g. when losses can exceed 1.
    """
    # 1-based x positions: 1, 2, ..., len(loss_mean).
    positions = np.arange(1, len(loss_mean) + 1)
    plt.errorbar(positions, loss_mean, loss_std, linestyle='None', marker='^', capsize=3)
    if ylim is not None:
        plt.ylim(bottom=ylim[0], top=ylim[1])
    plt.show()
def plot_images(input_image, target_image, generated_image):
    """Show input, generated and target images side by side in one figure.

    Only the first item of each batch is shown. Each tensor is converted for
    display with ``normalize_image``, which takes ``img[0]`` and moves its
    first axis last. The tensors are therefore expected as (N, C, H, W).

    Note the display order is input | generated | target, which differs
    from the argument order.

    Args:
        input_image: Batch tensor whose first item is shown on the left.
            Its spatial size also sets the figure size.
        target_image: Batch tensor whose first item is shown on the right.
        generated_image: Batch tensor whose first item is shown in the middle.
    """
    height, width = input_image.size(2), input_image.size(3)
    # figsize is (width, height) in inches. Three images sit side by side.
    # Dividing by 100 maps pixels to inches at matplotlib's default 100 dpi.
    fig_size = (width * 3 / 100, height / 100)
    _, axes = plt.subplots(1, 3, figsize=fig_size)
    imgs = [input_image, generated_image, target_image]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        pixels = normalize_image(img)  # (H, W, C) uint8 scaled to 0-255
        if pixels.shape[2] == 1:
            # imshow rejects (H, W, 1) arrays; show single-channel as greyscale.
            ax.imshow(pixels[:, :, 0], cmap='gray', aspect='equal')
        else:
            ax.imshow(pixels, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def normalize_image(img):
    """Convert the first image of a batch tensor to a displayable uint8 array.

    Takes ``img[0]``, min-max scales it to the range 0-255, moves its first
    axis to the end (channels-first -> channels-last, e.g. (C, H, W) ->
    (H, W, C)) and casts to ``np.uint8``. Fractional values are truncated by
    the cast.

    Args:
        img: A torch tensor indexed as a batch; ``img[0]`` must be 3-D.
            It may require grad, live on any device, and have any numeric
            dtype.

    Returns:
        A numpy ``uint8`` array of shape ``img[0].shape`` permuted by
        ``(1, 2, 0)``. A constant image maps to all zeros instead of NaN.
    """
    # Detach and copy to CPU so .numpy() works for grad-tracking or GPU tensors.
    x = img[0].detach().cpu()
    if not x.is_floating_point():
        # Integer dtypes (e.g. uint8) would overflow in `* 255`.
        x = x.float()
    shifted = x - x.min()
    value_range = shifted.max()
    # Skip scaling for a constant image (range 0) to avoid 0/0 -> NaN.
    if value_range > 0:
        shifted = shifted * 255 / value_range
    return shifted.numpy().transpose(1, 2, 0).astype(np.uint8)