39 lines
1.4 KiB
Python
39 lines
1.4 KiB
Python
import torch.nn as nn
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
def initialize_weights(net):
    """Initialize conv / deconv / linear layers of *net* in place.

    Weights are drawn from N(0, 0.02) (the DCGAN initialization scheme);
    biases are zeroed. Layers constructed with ``bias=False`` have no bias
    tensor (``m.bias is None``) and are skipped for the bias step instead
    of raising AttributeError, which the previous version did.

    Args:
        net: any ``nn.Module``; all submodules are visited via ``modules()``.
    """
    for m in net.modules():
        # One tuple check replaces three identical isinstance branches.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # bias is None when the layer was built with bias=False.
            if m.bias is not None:
                m.bias.data.zero_()
|
|
|
|
def plot_loss(loss_mean, loss_std):
    """Plot per-epoch mean loss with standard-deviation error bars.

    Args:
        loss_mean: sequence of mean loss values, one per epoch.
        loss_std: sequence of matching standard deviations.

    Blocks until the figure window is closed (``plt.show``).
    """
    epochs = np.arange(1, len(loss_mean) + 1)
    plt.errorbar(epochs, loss_mean, loss_std,
                 linestyle='None', marker='^', capsize=3)
    # The ymin/ymax keyword names were removed in matplotlib 3.0
    # (renamed bottom/top); positional limits work on every version.
    plt.ylim(0.0, 1.0)
    plt.show()
|
|
|
|
def plot_images(input_image, target_image, generated_image):
    """Display input, generated, and target images side by side.

    Args:
        input_image: batched image tensor; assumed (N, C, H, W) since only
            element 0 is shown and dims 2/3 size the figure -- TODO confirm.
        target_image: ground-truth batch, same layout.
        generated_image: model-output batch, same layout.

    Blocks until the figure window is closed (``plt.show``).
    """
    # Figure size in inches at matplotlib's default 100 dpi, one panel per
    # image. NOTE(review): width uses size(2) and height size(3) -- if the
    # layout is (N, C, H, W) these look swapped; preserved as-is, confirm.
    fig_size = (input_image.size(2) * 3 / 100, input_image.size(3) / 100)

    _, axes = plt.subplots(1, 3, figsize=fig_size)
    imgs = [input_image, generated_image, target_image]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        # Reuse the shared helper instead of duplicating its 0-255
        # scaling expression inline (it was copied verbatim before).
        ax.imshow(normalize_image(img), cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    plt.show()
|
|
|
|
|
|
def normalize_image(img):
    """Scale the first image of a batch to uint8 [0, 255] in HWC layout.

    Args:
        img: batched image tensor; assumed (N, C, H, W) since element 0 is
            taken and then transposed (1, 2, 0) -- TODO confirm with callers.

    Returns:
        numpy uint8 array of shape (H, W, C). A constant image (max == min)
        maps to all zeros instead of dividing by zero (which previously
        produced NaN/inf and garbage uint8 values).
    """
    frame = img[0]
    value_range = frame.max() - frame.min()
    if value_range == 0:
        # Degenerate case: avoid 0/0; return an all-black image.
        return np.zeros_like(frame.numpy().transpose(1, 2, 0), dtype=np.uint8)
    scaled = ((frame - frame.min()) * 255) / value_range
    return scaled.numpy().transpose(1, 2, 0).astype(np.uint8)
|