Basic CNN and MLP implementation using PyTorch
This is a basic implementation of CNN and MLP in the PyTorch framework. We haven't done anything fancy with the NN architecture. However, make sure to check the preprocessing area.
We use the coarse labels of CIFAR-100 and also implement a sub-routine that determines the per-channel mean and standard deviation of the input CIFAR-100 data. This is much better than guessing the values or searching research papers for appropriate numbers.
Other than this, the user is given the option to run either the MLP implementation or the CNN implementation.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR100
from sklearn.metrics import confusion_matrix
class CIFAR100coarse(CIFAR100):
    """CIFAR-100 dataset whose targets are remapped to 20 coarse labels.

    After the parent ``CIFAR100`` constructor has run, ``self.targets`` is
    replaced by a NumPy array obtained by looking every original target up in
    a fixed 100-entry table, and ``self.classes`` is replaced by a list of 20
    groups of five class names each.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        """Load CIFAR-100 through the parent class, then coarsen the labels.

        All arguments are forwarded positionally and unchanged to
        ``CIFAR100.__init__``.
        """
        super().__init__(root, train, transform, target_transform, download)

        # Lookup table: index = original (fine) label, value = coarse label.
        fine_to_coarse = np.array([
            4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
            3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
            6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
            0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
            5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
            16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
            10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
            2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
            16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
            18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
        ])
        # Fancy indexing remaps every target in one step.
        self.targets = fine_to_coarse[self.targets]

        # One group of five fine class names per coarse class.
        # NOTE(review): presumably group k corresponds to coarse label k --
        # confirm against the official CIFAR-100 superclass listing.
        self.classes = [
            ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
            ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
            ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
            ['bottle', 'bowl', 'can', 'cup', 'plate'],
            ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],
            ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
            ['bed', 'chair', 'couch', 'table', 'wardrobe'],
            ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
            ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
            ['bridge', 'castle', 'house', 'road', 'skyscraper'],
            ['cloud', 'forest', 'mountain', 'plain', 'sea'],
            ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
            ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
            ['crab', 'lobster', 'snail', 'spider', 'worm'],
            ['baby', 'boy', 'girl', 'man', 'woman'],
            ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
            ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
            ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
            ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
            ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'],
        ]
def mean_std():
    """Compute the per-channel mean and standard deviation of the training images.

    The training split of ``CIFAR100coarse`` is loaded with only
    ``transforms.ToTensor()`` applied and every image is visited once.
    Per-channel sums and sums of squares are accumulated over all pixels, so
    the result is the exact statistic of the whole training set rather than
    an estimate from a handful of random images.

    Returns:
        tuple: ``(per_channel_mean, per_channel_std)``, two float32 NumPy
        arrays with one entry per image channel.

    Raises:
        ValueError: If the dataset yields no images.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    data_set = CIFAR100coarse(root='./data', train=True, transform=transform,
                              target_transform=None, download=True)
    # Order does not matter for global statistics; large batches keep the
    # single pass over the data fast.
    loader = torch.utils.data.DataLoader(data_set, batch_size=1024, shuffle=False)

    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for images, _ in loader:
        # Accumulate in float64 to avoid round-off over millions of pixels.
        images = images.double()
        batch_sum = images.sum(dim=(0, 2, 3))
        batch_sq_sum = (images * images).sum(dim=(0, 2, 3))
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum = channel_sum + batch_sum
            channel_sq_sum = channel_sq_sum + batch_sq_sum
        pixel_count += images.size(0) * images.size(2) * images.size(3)

    if pixel_count == 0:
        raise ValueError('Cannot compute statistics of an empty dataset')

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
    variance = (channel_sq_sum / pixel_count - mean * mean).clamp(min=0)
    std = variance.sqrt()
    return mean.float().numpy(), std.float().numpy()
def cifar_preprocessor(batch_size, test_shuffle=False):
    """Build normalised train and test loaders for coarse-labelled CIFAR-100.

    Images are converted to tensors and normalised with the per-channel mean
    and standard deviation returned by ``mean_std()``.

    Args:
        batch_size: Batch size used for both loaders.
        test_shuffle: Whether the test loader shuffles its samples. It is off
            by default so predictions line up with ``test_set.targets``.

    Returns:
        tuple: ``(train_loader, test_loader, train_set, test_set)``.
    """
    per_channel_mean, per_channel_std = mean_std()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=per_channel_mean, std=per_channel_std),
    ])

    train_set = CIFAR100coarse(root='./data', train=True, transform=transform,
                               target_transform=None, download=True)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                               shuffle=True)

    # train=False selects the held-out split. Using train=True here would
    # evaluate the model on the very images it was trained on.
    test_set = CIFAR100coarse(root='./data', train=False, transform=transform,
                              target_transform=None, download=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                              shuffle=test_shuffle)

    return train_loader, test_loader, train_set, test_set
class MLP_network(nn.Module):
    """Fully connected classifier: 3072 inputs, six hidden layers, 20 outputs.

    Layer widths are 3072 -> 2000 -> 1500 -> 1000 -> 800 -> 500 -> 200 -> 20.
    Every hidden layer is followed by ReLU; the output layer returns raw
    scores with no activation.

    nn.Linear reference:
    https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts remain loadable.
        self.fc1 = nn.Linear(3072, 2000)
        self.fc2 = nn.Linear(2000, 1500)
        self.fc3 = nn.Linear(1500, 1000)
        self.fc4 = nn.Linear(1000, 800)
        self.fc5 = nn.Linear(800, 500)
        self.fc6 = nn.Linear(500, 200)
        self.out = nn.Linear(200, 20)

    def forward(self, t):
        """Map a batch of flattened inputs (last dimension 3072) to 20 scores."""
        hidden_layers = (self.fc1, self.fc2, self.fc3,
                         self.fc4, self.fc5, self.fc6)
        for layer in hidden_layers:
            t = F.relu(layer(t))
        # Output layer: no activation, raw scores are returned.
        return self.out(t)
class CNN_network(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The flattening step assumes 16 feature maps of size 5x5 after the second
    pooling stage, which is what a 3x32x32 input produces.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.out = nn.Linear(in_features=84, out_features=20)

    def forward(self, t):
        """Return 20 raw class scores for each image in the batch ``t``."""
        # Stage 1: conv -> ReLU -> 2x2 max-pool; (6, 14, 14) per 32x32 image.
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)

        # Stage 2: conv -> ReLU -> 2x2 max-pool; (16, 5, 5) per 32x32 image.
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)

        # Flatten each image's feature maps into a 400-element row.
        t = t.reshape(-1, 16 * 5 * 5)

        # Fully connected layers: 400 -> 120 -> 84, each followed by ReLU.
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))

        # Output layer: 84 -> 20, no activation.
        return self.out(t)
def train_nn(model, train_loader, optimizer, i, num_epochs=30):
    """Train ``model`` with cross-entropy loss and report per-epoch statistics.

    Args:
        model: Network to optimise; called as ``model(images)``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimiser already bound to ``model``'s parameters.
        i: Model selector. When ``i == 1`` (the MLP) every batch is flattened
            to rows of ``32*32*3`` values before the forward pass.
        num_epochs: Number of passes over ``train_loader``. Defaults to 30,
            the value that was previously hard-coded.

    Returns:
        None. Progress is printed once per epoch: the epoch index, the number
        of correctly classified training samples and the summed batch loss.
    """
    print('>>> Training Start >>>')
    # Make sure the model is in training mode, even if it was switched to
    # evaluation mode earlier.
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        total_correct = 0
        for images, labels in train_loader:
            if i == 1:
                # The MLP expects flat vectors instead of image tensors.
                images = images.reshape(-1, 32 * 32 * 3)
            predictions = model(images)
            loss = F.cross_entropy(predictions, labels)

            # Clear stale gradients before back-propagating this batch.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_correct += predictions.argmax(dim=1).eq(labels).sum().item()
        print('epoch:', epoch, "total_correct:", total_correct, "loss:", total_loss)
    print('>>> Training Complete >>>')
@torch.no_grad()
def get_all_preds(model, loader, i):
    """Run ``model`` over every batch of ``loader`` and stack the outputs.

    Gradient tracking is disabled for the whole call.

    Args:
        model: Callable network; invoked as ``model(images)``.
        loader: Iterable yielding ``(images, labels)`` batches. The labels
            are ignored.
        i: Model selector. When ``i == 1`` (the MLP) every batch is flattened
            to rows of ``32*32*3`` values before the forward pass.

    Returns:
        torch.Tensor: All batch outputs concatenated along dimension 0, in
        loader order. An empty 1-D tensor is returned when ``loader`` yields
        no batches.
    """
    batch_outputs = []
    for images, _ in loader:
        if i == 1:
            images = images.reshape(-1, 32 * 32 * 3)
        batch_outputs.append(model(images))

    if not batch_outputs:
        # torch.cat rejects an empty list, so return an empty tensor instead.
        return torch.tensor([])
    # A single concatenation avoids re-copying the accumulated predictions
    # on every batch.
    return torch.cat(batch_outputs, dim=0)
def calc_accuracy(test_preds, test_set, i):
    """Print how many predictions match the targets stored on ``test_set``.

    Args:
        test_preds: 2-D tensor of per-class scores, one row per sample; the
            predicted class is the arg-max along dimension 1.
        test_set: Object exposing ``targets`` (the reference labels) and
            supporting ``len()``.
        i: Model selector; ``1`` prints the MLP heading, any other value the
            CNN heading.

    Returns:
        None. The heading, the number of correct predictions and the accuracy
        are printed.
    """
    reference = torch.Tensor(test_set.targets)
    predicted = test_preds.argmax(dim=1)
    num_correct = predicted.eq(reference).sum().item()

    heading = 'Multi Layer Perceptrons' if i == 1 else 'Convolutional Neural Networks'
    print(heading)
    print('total correct:', num_correct)
    print('accuracy:', num_correct / len(test_set))
def main():
    """Ask which model to run, then train, evaluate and save it.

    The user enters ``1`` for the MLP or ``2`` for the CNN. The selected model
    is trained with SGD (learning rate 0.001, momentum 0.9), evaluated with
    ``calc_accuracy`` and its weights are written to ``./cifar100_mlp.pth`` or
    ``./cifar100_cnn.pth``. Any other input, including non-numeric text,
    prints an error message and returns without loading any data.
    """
    print('Enter 1 for MLP, 2 for CNN')
    try:
        i = int(input())
    except ValueError:
        # Treat non-numeric input like any other invalid choice instead of
        # crashing with a traceback.
        i = None

    if i == 1:
        print('Multiple Layer of Perceptrons')
        model = MLP_network()
        PATH = './cifar100_mlp.pth'
    elif i == 2:
        print('Convolutional Neural Network')
        model = CNN_network()
        PATH = './cifar100_cnn.pth'
    else:
        print('Wrong Choice...Try Again!!!')
        return

    # The data is prepared only after the choice is known to be valid, so a
    # mistyped answer does not cost a download and a pass over the dataset.
    train_loader, test_loader, _, test_set = cifar_preprocessor(64)

    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    train_nn(model, train_loader, optimizer, i)
    all_preds = get_all_preds(model, test_loader, i)
    calc_accuracy(all_preds, test_set, i)
    torch.save(model.state_dict(), PATH)
# Run the interactive workflow only when this file is executed as a script,
# so that importing it to reuse the models or helpers does not start training.
if __name__ == "__main__":
    main()