MNIST Autoencoder exercise - reconstruction of noisy digits
Goal: Introduction to autoencoders (encoder/decoder networks)
Exercise:
Run the notebook and observe the reconstruction of the noisy images
Try changing the amount of noise
Try adding a skip connection and see how it affects the reconstruction (look at the forward method of the MyDecoder subnetwork; a sketch of the change appears after the cell that defines the network classes)
In [ ]:
# ----------- IMPORT STATEMENTS ---------------
import argparse
import torch
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms.v2 as transforms
import os
import sys
import numpy as np
import time
#---------------------------------------------
print('import done')
In [ ]:
# -----------------------------------
#Parameters for training
# -----------------------------------
num_worker2use = 0      #for parallel reading/prefetching of (bigger) data
batch_size = 128        #1024
max_numtrain = 1024     #4096  #for this exercise, train on a limited number of inputs to save time
max_numtest = batch_size        #and test on a limited number of inputs
epochs = 10
lrate = 0.01
numfilt = 16
num_xtra_noise_steps = 15  #<<<<<<<<<<------------ (15 to 25 is large-ish, 0 to 5 is smallish)
torch.manual_seed(776)
In [ ]:
#We'll load the data into NumPy arrays directly
X_train = np.load("./X_train1k.npy")
Y_train = np.load("./Y_train1k.npy")
X_test = np.load("./X_test.npy")[0:max_numtrain,]   #take 1k out of the 10k test images
Y_test = np.load("./Y_test.npy")[0:max_numtrain,]   #keep the labels aligned with X_test
#Scale pixel values to the range [0, 1]
X_train = X_train/255.0
X_test = X_test/255.0
#Add a channel dimension: (N, 1, 28, 28)
X_train = X_train[:, np.newaxis, :, :]
X_test = X_test[:, np.newaxis, :, :]
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)
In [ ]:
#print(np.max(X_train))
In [ ]:
#Add some uniform noise to make the task harder
def addnoise(X):
    #X = X + np.round(np.random.uniform(-1, 1, size=X.shape), 1)  #using round adds a full pixel of noise or nothing
    X = X + np.random.uniform(-0.2, 0.2, size=X.shape)
    X[np.where(X > 1)] = 1   #clip back into the valid [0, 1] range
    X[np.where(X < 0)] = 0
    return X

X_train_wnoise = addnoise(X_train)
X_test_wnoise = addnoise(X_test)
for i in range(num_xtra_noise_steps):
    print('adding more noise, step', i)
    X_train_wnoise = addnoise(X_train_wnoise)
    X_test_wnoise = addnoise(X_test_wnoise)
print('noise added')
print(np.max(X_train_wnoise))
print(np.max(X_test_wnoise))
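The exercise asks you to vary the amount of noise. One convenient way to experiment is sketched below; the parameterized helper and its amplitude argument are an assumption added here for illustration (the original addnoise uses a fixed +/-0.2 range).
In [ ]:
#Hedged sketch: a parameterized variant of addnoise for experimenting with the noise level
def addnoise_amp(X, amplitude=0.2):
    X = X + np.random.uniform(-amplitude, amplitude, size=X.shape)
    return np.clip(X, 0.0, 1.0)   #np.clip does the same job as the two np.where lines above

#Example (commented out so the original noisy arrays are not overwritten):
#X_train_wnoise = addnoise_amp(X_train, amplitude=0.4)
#X_test_wnoise = addnoise_amp(X_test, amplitude=0.4)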
In [ ]:
#Set up arrays as 'tensor datasets'
from torch.utils.data import TensorDataset, DataLoader
X_train_wnoise_tensor = torch.from_numpy(X_train_wnoise).float() # Use .float() for float data
X_train_tensor = torch.from_numpy(X_train).float()
Y_train_tensor = torch.from_numpy(Y_train).long()
X_test_wnoise_tensor = torch.from_numpy(X_test_wnoise).float()
X_test_tensor = torch.from_numpy(X_test).float()
Y_test_tensor = torch.from_numpy(Y_test).long() # Use .long() for integer labels
# Combine input and target tensors into a TensorDataset object
my_train_dataset = TensorDataset(X_train_wnoise_tensor, X_train_tensor)
my_test_dataset = TensorDataset(X_test_wnoise_tensor, X_test_tensor)
print('train,test tensor datasets set up')
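Each item of these TensorDatasets is a (noisy input, clean target) pair; a quick check (illustration only, not needed for training):
In [ ]:
#Peek at one (noisy, clean) pair from the training dataset
noisy0, clean0 = my_train_dataset[0]
print(noisy0.shape, clean0.shape)   #both should be torch.Size([1, 28, 28])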
In [ ]:
In [ ]:
# -------------------------------------------
# prepare images for the network as they are loaded
# -------------------------------------------
train_loader = torch.utils.data.DataLoader(my_train_dataset,
                batch_size=batch_size, sampler=None,
                num_workers=num_worker2use, pin_memory=False, drop_last=True)
test_loader = torch.utils.data.DataLoader(my_test_dataset,
                batch_size=batch_size, sampler=None,
                num_workers=num_worker2use, pin_memory=False, drop_last=True)
In [ ]:
#Sample of how to access data
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(train_loader):
        #output = data
        break
print(data.shape)
print(target.shape)
#print(torch.max(data))
In [ ]:
# -------------------------------------------------------------
# Define the network class objects and their
# initialization and forward functions
# (other methods are inherited from torch.nn.Module)
# -------------------------------------------------------------
class MyEncoder(torch.nn.Module):
    def __init__(self):
        super(MyEncoder, self).__init__()
        #convolution layer, then max pool to downsize
        self.conv1 = torch.nn.Conv2d(1, numfilt, 3, 1, padding='same')
        self.max_pool_1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        #repeat the block but double the filters
        self.conv2 = torch.nn.Conv2d(numfilt, numfilt*2, 3, 1, padding='same')
        self.max_pool_2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        #print('MYINFO enc fwd, x shp:',x.shape)
        x1 = self.conv1(x)
        x1 = F.relu(x1)
        x1 = self.max_pool_1(x1)
        #print('MYINFO enc fwd, after max1, x shape:',x1.shape)
        x2 = self.conv2(x1)
        x2 = F.relu(x2)
        x2 = self.max_pool_2(x2)
        return x1, x2   #x2 is the bottleneck; x1 is also returned so it can be used as a skip connection
# ---------------------------------------------------------------------
# decoder
# ---------------------------------------------------------------------
class MyDecoder(torch.nn.Module):
    def __init__(self):
        super(MyDecoder, self).__init__()
        #convolution layer; upsampling happens in forward via interpolate
        self.conv1 = torch.nn.Conv2d(numfilt*2, numfilt, 3, 1, padding='same')
        #if no skip connection, use in channels = numfilt for conv2
        self.conv2 = torch.nn.Conv2d(numfilt, numfilt, 3, 1, padding='same')
        #for a skip connection going into conv2, use numfilt*2 in channels:
        #comment out the line above and uncomment this one  <<<<<<<<<--------------
        #self.conv2 = torch.nn.Conv2d(numfilt*2, numfilt, 3, 1, padding='same')
        #last conv has 1 filter and uses a sigmoid activation because this is the output layer
        self.conv3 = torch.nn.Conv2d(numfilt, 1, 3, 1, padding='same')

    def forward(self, encx1, x):   #encx1 is the encoder's first-block output, x is the bottleneck
        x1 = self.conv1(x)
        x1 = F.relu(x1)
        x1 = torch.nn.functional.interpolate(x1, size=(14, 14), mode='nearest')
        #print('MYINFO dec fwd, after inter1, x shape:',x1.shape, 'encx1shp',encx1.shape)
        skip_concat_1 = torch.cat((x1, encx1), dim=1)
        #print('MYINFO, dec fwd, after concat1',skip_concat_1.shape)
        #<<<<<<---------- choose whether conv2 should see x1 alone, or x1 concatenated with the skip connection
        x2 = self.conv2(x1)
        #x2 = self.conv2(skip_concat_1)
        x2 = F.relu(x2)
        x2 = torch.nn.functional.interpolate(x2, size=(28, 28), mode='nearest')
        #print('MYINFO dec fwd, after inter2, x shape:',x2.shape)
        x3 = self.conv3(x2)
        x3 = torch.sigmoid(x3)
        return x3
class MyAENet(torch.nn.Module):
    def __init__(self):
        super(MyAENet, self).__init__()
        self.encoder = MyEncoder()
        self.decoder = MyDecoder()
        #self.bottleneck = torch.nn.Conv2d(numfilt, numfilt, 3, 1)

    def forward(self, x):
        encx1, x2 = self.encoder(x)
        output = self.decoder(encx1, x2)   #output is between 0 and 1
        #print('MYINFO fwd, after max, x shape:',x.shape)
        return output

print('Net class defined')
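For the skip-connection part of the exercise, the change is confined to MyDecoder: conv2 must accept numfilt*2 input channels (the commented line already provided in __init__), and conv2 should be fed skip_concat_1 instead of x1 in forward. The cell below is a hedged, illustration-only shape check of that concatenation; the temporary tensors are assumptions added here, not part of the exercise.
In [ ]:
#Illustration only: shape of the tensor that conv2 would receive with a skip connection.
#The decoder's upsampled features and the encoder's first-block output are both
#(N, numfilt, 14, 14), so their concatenation along dim=1 has numfilt*2 channels.
_x1 = torch.zeros(1, numfilt, 14, 14)      #stand-in for the decoder's upsampled x1
_encx1 = torch.zeros(1, numfilt, 14, 14)   #stand-in for the encoder's x1 output
print(torch.cat((_x1, _encx1), dim=1).shape)   #-> torch.Size([1, 32, 14, 14]) for numfilt=16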
In [ ]:
# --------------------------------------------------------
# Define training function
# --------------------------------------------------------
def train(model, device, train_loader, optimizer, epoch):
    ''' This is called for each epoch.
    Arguments: the model, the device to run on, data loader, optimizer, and current epoch
    '''
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx*batch_size >= max_numtrain:
            break
        else:
            if batch_idx == 0:
                print('INFO train, ep:', epoch, ' batidx:', batch_idx, ' batch size:', target.shape[0])
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()   #reset optimizer state
            output = model(data)    #get predictions
            loss = torch.nn.functional.binary_cross_entropy(output, target)
            loss.backward()         #backprop loss
            optimizer.step()        #update weights
# -------------------------------------------------------------
# Define test function
# -------------------------------------------------------------
def test(model, device, test_loader):
    ''' This is called after training each epoch.
    Arguments: the model, the device to run on, test data loader
    '''
    model.eval()
    #accumulate loss info
    total_loss = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            if batch_idx*batch_size >= max_numtest:
                break
            else:
                data, target = data.to(device), target.to(device)
                output = model(data)
                total_loss += torch.nn.functional.binary_cross_entropy(output, target)
                total += data.shape[0]
    test_loss = total_loss/total
    print('INFO test loss:', f'{test_loss:.4}', 'tot:', total)
    return test_loss

print('Train, test, and support functions defined')
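Binary cross-entropy is a reasonable reconstruction loss here because both the sigmoid output and the clean target images lie in [0, 1]. If you want to compare against a plain mean-squared-error reconstruction loss (an alternative suggested here, not part of the original exercise), the swap in train() and test() is a one-liner.
In [ ]:
#Hedged alternative (for comparison only): replace binary_cross_entropy with
#    loss = torch.nn.functional.mse_loss(output, target)
#in train(), and make the corresponding change to the total_loss line in test().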
In [ ]:
# -------------------------------------------------
# Get device
# (note, this is set up for 1 GPU device;
# if this were to run on a full GPU node with >1 GPU device, you would
# want to get rank and world size info and set the device id,
# as in: torch.cuda.set_device(local_rank),
# and then also run distributed initialization)
# -------------------------------------------------
use_cuda = torch.cuda.is_available()
if use_cuda:
    num_gpu = torch.cuda.device_count()
    print('INFO, cuda, num gpu:', num_gpu)
    device = torch.cuda.current_device()
    print('environ visdevs:', os.environ.get("CUDA_VISIBLE_DEVICES"))
else:
    num_gpu = 0
    print('INFO, cuda not available')
    device = torch.device("cpu")
print('INFO, device is:', device)
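The note above mentions what a multi-GPU run would need. A minimal sketch of that setup with torch.distributed is given below; it assumes a launcher such as torchrun sets the RANK, WORLD_SIZE and LOCAL_RANK environment variables, and it is not needed (and not run) for this single-device exercise.
In [ ]:
#Hedged sketch of multi-GPU (distributed) initialization; left commented out on purpose.
#Assumes the launcher (e.g. torchrun) sets RANK, WORLD_SIZE and LOCAL_RANK.
#import torch.distributed as dist
#local_rank = int(os.environ["LOCAL_RANK"])
#torch.cuda.set_device(local_rank)
#dist.init_process_group(backend="nccl")
#Then wrap the model created below:
#mymodel = torch.nn.parallel.DistributedDataParallel(mymodel, device_ids=[local_rank])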
In [ ]:
In [ ]:
# -------------------------------------------
# Set up model and do training loop
# -------------------------------------------
mymodel = MyAENet().to(device)
#summary(mymodel,input_size=(1, 1, 28, 28))
In [ ]:
#Do training loop
# ----------------------------------------
optimizer = torch.optim.Adam(mymodel.parameters(), lr=lrate)
train_results = []
test_results = []
for epoch in range(epochs):
    print('INFO about to train epoch:', epoch)
    start_time = time.time()
    train(mymodel, device, train_loader, optimizer, epoch)
    print('INFO training time:', str.format('{0:.5f}', time.time()-start_time))
    print('INFO about to test epoch:', epoch)
    test(mymodel, device, train_loader)   #reconstruction loss on the (limited) training set
    test(mymodel, device, test_loader)    #reconstruction loss on the test set
print('INFO done')
In [ ]:
#=====================
In [ ]:
#To view sample images
import matplotlib.pyplot as plt   #These provide MATLAB-like plotting functions
import matplotlib.image as mpimg

def display_one_row(disp_images, offset, shape=(28, 28)):
    '''Display sample outputs in one row.'''
    for idx, test_image in enumerate(disp_images):
        plt.subplot(3, 10, offset + idx + 1)
        plt.xticks([])
        plt.yticks([])
        test_image = np.reshape(test_image, shape)
        plt.imshow(test_image, cmap='gray')
In [ ]:
#get sample outputs and show noisy, original, and reconstructed images
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        output = mymodel(data.to(device)).detach().cpu()
        break
num2do = 10
print(' display noisy images --------------------')
display_one_row(data[0:num2do,], 0, shape=(28, 28))
print(' display target images --------------------')
display_one_row(target[0:num2do,], 10, shape=(28, 28))
print(' display output reconstruction ---------------')
display_one_row(output[0:num2do,], 20, shape=(28, 28))
In [ ]:
# Set the condition below to 0 so that rerunning all cells
# does not redraw the images
#get sample outputs and show noisy, original, and reconstructed images
if 1:
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            output = mymodel(data.to(device)).detach().cpu()
            break
    num2do = 10
    print(' display noisy images --------------------')
    display_one_row(data[0:num2do,], 0, shape=(28, 28))
    print(' display target images --------------------')
    display_one_row(target[0:num2do,], 10, shape=(28, 28))
    print(' display output reconstruction ---------------')
    display_one_row(output[0:num2do,], 20, shape=(28, 28))
In [ ]: