Appendix D: Creating a CNN to Predict Bitmojis#

import numpy as np
import torch
from PIL import Image
from torch import nn, optim
from torchvision import datasets, transforms, utils
from torchsummary import summary
import matplotlib.pyplot as plt
plt.style.use('ggplot')
plt.rcParams.update({'font.size': 16, 'axes.labelweight': 'bold', 'axes.grid': False})
from utils.set_seed import set_seed
set_seed(1)
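
Note that utils.set_seed is a small helper module from the course repository and isn’t shown here. For context, here is a minimal sketch of what such a helper typically contains (an assumption about its contents, not the actual file):

import random

def set_seed(seed):
    """A guess at the helper: seed Python, NumPy and PyTorch RNGs for reproducibility."""
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy RNG (np is imported above)
    torch.manual_seed(seed)  # PyTorch RNG (CPU, and accelerators where applicable)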

1. Introduction#


This code-based appendix contains the code used to develop and save the Bitmoji CNNs from Lecture 7.

2. CNN from Scratch#


TRAIN_DIR = "data/eva_bitmoji_rgb/train/"
VALID_DIR = "data/eva_bitmoji_rgb/valid/"
IMAGE_SIZE = 64
BATCH_SIZE = 64

# Transforms
data_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
# Load data and create dataloaders
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataset = datasets.ImageFolder(root=VALID_DIR, transform=data_transforms)
validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Plot samples
sample_batch = next(iter(trainloader))
plt.figure(figsize=(10, 8)); plt.axis("off"); plt.title("Sample Training Images")
plt.imshow(np.transpose(utils.make_grid(sample_batch[0], padding=1, normalize=True),(1,2,0)));
(Figure: a grid of sample training Bitmoji images.)
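
As a quick sanity check on the data pipeline, we can confirm that ImageFolder inferred the two classes from the folder names and that batches have the expected shape. This is a minimal sketch; the class names and the 64 x 64 spatial size are assumptions based on the directory layout and the Resize transform above (which resizes the shorter edge, so square source images are assumed):

print(train_dataset.classes)  # folder names become labels, e.g. ['eva', 'not_eva']
images, labels = sample_batch
print(images.shape)           # expected: torch.Size([64, 3, 64, 64])
print(labels[:8])             # integer class indices for the first few samples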
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')  # use Apple's MPS backend if available, else CPU
class bitmoji_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 8, (5, 5)),   # (N, 3, 64, 64) -> (N, 8, 60, 60)
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),      # -> (N, 8, 30, 30)
            nn.Conv2d(8, 4, (3, 3)),   # -> (N, 4, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d((3, 3)),      # -> (N, 4, 9, 9)
            nn.Dropout(0.2),
            nn.Flatten(),              # -> (N, 4 * 9 * 9) = (N, 324)
            nn.Linear(324, 128),
            nn.ReLU(),
            nn.Linear(128, 1)          # single logit for binary classification
        )
        
    def forward(self, x):
        out = self.main(x)
        return out

def trainer(model, criterion, optimizer, trainloader, validloader, epochs=5, verbose=True):
    """Simple training wrapper for PyTorch network."""
    
    train_loss, valid_loss, valid_accuracy = [], [], []
    for epoch in range(epochs):  # for each epoch
        train_batch_loss = 0
        valid_batch_loss = 0
        valid_batch_acc = 0
        
        # Training
        for X, y in trainloader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(X).flatten()
            loss = criterion(y_hat, y.type(torch.float32))
            loss.backward()
            optimizer.step()
            train_batch_loss += loss.item()
        train_loss.append(train_batch_loss / len(trainloader))
        
        # Validation
        model.eval()
        
        with torch.no_grad():  # no gradient tracking needed during validation
            for X, y in validloader:
                X, y = X.to(device), y.to(device)
                y_hat = model(X).flatten()
                y_hat_labels = torch.sigmoid(y_hat) > 0.5
                loss = criterion(y_hat, y.type(torch.float32))
                valid_batch_loss += loss.item()
                valid_batch_acc += (y_hat_labels == y).type(torch.float32).mean().item()
        valid_loss.append(valid_batch_loss / len(validloader))
        valid_accuracy.append(valid_batch_acc / len(validloader))  # accuracy
        
        model.train()
        
        # Print progress
        if verbose:
            print(f"Epoch {epoch + 1}:",
                  f"Train Loss: {train_loss[-1]:.3f}.",
                  f"Valid Loss: {valid_loss[-1]:.3f}.",
                  f"Valid Accuracy: {valid_accuracy[-1]:.2f}.")
    
    results = {"train_loss": train_loss,
               "valid_loss": valid_loss,
               "valid_accuracy": valid_accuracy}
    return results    
    
model = bitmoji_CNN().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-3)
results = trainer(model, criterion, optimizer, trainloader, validloader, epochs=30)
Epoch 1: Train Loss: 0.661. Valid Loss: 0.608. Valid Accuracy: 0.67.
Epoch 2: Train Loss: 0.612. Valid Loss: 0.578. Valid Accuracy: 0.68.
Epoch 3: Train Loss: 0.588. Valid Loss: 0.577. Valid Accuracy: 0.68.
Epoch 4: Train Loss: 0.571. Valid Loss: 0.574. Valid Accuracy: 0.70.
Epoch 5: Train Loss: 0.556. Valid Loss: 0.565. Valid Accuracy: 0.69.
Epoch 6: Train Loss: 0.560. Valid Loss: 0.559. Valid Accuracy: 0.70.
Epoch 7: Train Loss: 0.541. Valid Loss: 0.561. Valid Accuracy: 0.70.
Epoch 8: Train Loss: 0.509. Valid Loss: 0.547. Valid Accuracy: 0.69.
Epoch 9: Train Loss: 0.490. Valid Loss: 0.528. Valid Accuracy: 0.71.
Epoch 10: Train Loss: 0.475. Valid Loss: 0.511. Valid Accuracy: 0.71.
Epoch 11: Train Loss: 0.453. Valid Loss: 0.497. Valid Accuracy: 0.74.
Epoch 12: Train Loss: 0.444. Valid Loss: 0.501. Valid Accuracy: 0.74.
Epoch 13: Train Loss: 0.409. Valid Loss: 0.487. Valid Accuracy: 0.73.
Epoch 14: Train Loss: 0.386. Valid Loss: 0.493. Valid Accuracy: 0.75.
Epoch 15: Train Loss: 0.385. Valid Loss: 0.476. Valid Accuracy: 0.75.
Epoch 16: Train Loss: 0.358. Valid Loss: 0.490. Valid Accuracy: 0.77.
Epoch 17: Train Loss: 0.335. Valid Loss: 0.502. Valid Accuracy: 0.76.
Epoch 18: Train Loss: 0.314. Valid Loss: 0.539. Valid Accuracy: 0.76.
Epoch 19: Train Loss: 0.307. Valid Loss: 0.534. Valid Accuracy: 0.76.
Epoch 20: Train Loss: 0.282. Valid Loss: 0.508. Valid Accuracy: 0.78.
Epoch 21: Train Loss: 0.277. Valid Loss: 0.507. Valid Accuracy: 0.79.
Epoch 22: Train Loss: 0.256. Valid Loss: 0.498. Valid Accuracy: 0.79.
Epoch 23: Train Loss: 0.238. Valid Loss: 0.547. Valid Accuracy: 0.79.
Epoch 24: Train Loss: 0.233. Valid Loss: 0.506. Valid Accuracy: 0.79.
Epoch 25: Train Loss: 0.217. Valid Loss: 0.588. Valid Accuracy: 0.77.
Epoch 26: Train Loss: 0.224. Valid Loss: 0.515. Valid Accuracy: 0.82.
Epoch 27: Train Loss: 0.178. Valid Loss: 0.527. Valid Accuracy: 0.82.
Epoch 28: Train Loss: 0.165. Valid Loss: 0.550. Valid Accuracy: 0.82.
Epoch 29: Train Loss: 0.183. Valid Loss: 0.521. Valid Accuracy: 0.82.
Epoch 30: Train Loss: 0.177. Valid Loss: 0.529. Valid Accuracy: 0.83.
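
Where does the 324 in nn.Linear(324, 128) come from? The two convolution/pooling stages reduce a 64 x 64 input to a 4 x 9 x 9 feature map (see the shape comments in the model definition), and 4 x 9 x 9 = 324. Here is a quick empirical check, a minimal sketch assuming square 64 x 64 inputs; model.main[:8] slices off everything up to and including the Flatten layer:

dummy = torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)
with torch.no_grad():
    features = model.main[:8](dummy)  # conv/pool/flatten portion of the network only
print(features.shape)                 # expected: torch.Size([1, 324])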
# Save model
PATH = "models/eva_cnn.pt"
torch.save(model.state_dict(), PATH)
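
Since only the state_dict is saved, loading the model back (for example in the Lecture 7 notebook) requires the bitmoji_CNN class definition to be available. A minimal sketch of how that might look (map_location='cpu' is an optional precaution for machines without the same accelerator):

model_loaded = bitmoji_CNN()
model_loaded.load_state_dict(torch.load(PATH, map_location='cpu'))
model_loaded.eval()  # remember to disable dropout before making predictions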

3. CNN from Scratch with Data Augmentation#


Is this eva or not eva?

image = Image.open('img/test-examples/eva-picnic.png')
image
(Image: the eva-picnic.png test image.)

Let’s check whether our CNN is able to figure it out.

model.eval()  # switch to evaluation mode so dropout is inactive at inference time
image_tensor = transforms.functional.to_tensor(image.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
with torch.no_grad():
    prediction = int(torch.sigmoid(model(image_tensor.to(device))) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: eva

Great! But what happens if I rotate the image 180 degrees, turning it upside down like this:

image_rotated = image.rotate(180)
image_rotated
(Image: the test image rotated 180 degrees.)

You can still tell that it’s eva, but can our CNN?

image_tensor_flipped = transforms.functional.to_tensor(image_rotated.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
with torch.no_grad():
    prediction = int(torch.sigmoid(model(image_tensor_flipped.to(device))) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: not_eva

It looks like our CNN is not robust to this kind of rotation in the input image. A 180-degree rotation is the same as a horizontal flip followed by a vertical flip, so we can try to fix this with data augmentation that randomly flips the training images. Let’s do that now:

# Transforms
data_transforms = transforms.Compose([transforms.Resize(IMAGE_SIZE),
                                      transforms.RandomVerticalFlip(p=0.5),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.ToTensor()])

# Load data and re-create training loader
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=data_transforms)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Plot samples
sample_batch = next(iter(trainloader))
plt.figure(figsize=(10, 8)); plt.axis("off"); plt.title("Sample Training Images")
plt.imshow(np.transpose(utils.make_grid(sample_batch[0], padding=1, normalize=True),(1,2,0)));
(Figure: a grid of sample training images after random flips.)

Okay, let’s train again with our new augmented dataset:

# Define and train model
model_aug = bitmoji_CNN()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model_aug.parameters())
results = trainer(model_aug.to(device), criterion, optimizer, trainloader, validloader, epochs=40)
Epoch 1: Train Loss: 0.680. Valid Loss: 0.642. Valid Accuracy: 0.60.
Epoch 2: Train Loss: 0.645. Valid Loss: 0.618. Valid Accuracy: 0.64.
Epoch 3: Train Loss: 0.630. Valid Loss: 0.597. Valid Accuracy: 0.67.
Epoch 4: Train Loss: 0.616. Valid Loss: 0.582. Valid Accuracy: 0.71.
Epoch 5: Train Loss: 0.610. Valid Loss: 0.578. Valid Accuracy: 0.70.
Epoch 6: Train Loss: 0.606. Valid Loss: 0.574. Valid Accuracy: 0.70.
Epoch 7: Train Loss: 0.586. Valid Loss: 0.568. Valid Accuracy: 0.70.
Epoch 8: Train Loss: 0.597. Valid Loss: 0.591. Valid Accuracy: 0.66.
Epoch 9: Train Loss: 0.597. Valid Loss: 0.558. Valid Accuracy: 0.72.
Epoch 10: Train Loss: 0.587. Valid Loss: 0.560. Valid Accuracy: 0.71.
Epoch 11: Train Loss: 0.576. Valid Loss: 0.579. Valid Accuracy: 0.69.
Epoch 12: Train Loss: 0.580. Valid Loss: 0.561. Valid Accuracy: 0.69.
Epoch 13: Train Loss: 0.576. Valid Loss: 0.559. Valid Accuracy: 0.70.
Epoch 14: Train Loss: 0.569. Valid Loss: 0.534. Valid Accuracy: 0.73.
Epoch 15: Train Loss: 0.545. Valid Loss: 0.523. Valid Accuracy: 0.75.
Epoch 16: Train Loss: 0.532. Valid Loss: 0.490. Valid Accuracy: 0.77.
Epoch 17: Train Loss: 0.506. Valid Loss: 0.468. Valid Accuracy: 0.79.
Epoch 18: Train Loss: 0.473. Valid Loss: 0.443. Valid Accuracy: 0.80.
Epoch 19: Train Loss: 0.469. Valid Loss: 0.469. Valid Accuracy: 0.78.
Epoch 20: Train Loss: 0.443. Valid Loss: 0.417. Valid Accuracy: 0.82.
Epoch 21: Train Loss: 0.408. Valid Loss: 0.397. Valid Accuracy: 0.82.
Epoch 22: Train Loss: 0.395. Valid Loss: 0.381. Valid Accuracy: 0.84.
Epoch 23: Train Loss: 0.380. Valid Loss: 0.385. Valid Accuracy: 0.84.
Epoch 24: Train Loss: 0.363. Valid Loss: 0.380. Valid Accuracy: 0.85.
Epoch 25: Train Loss: 0.361. Valid Loss: 0.360. Valid Accuracy: 0.86.
Epoch 26: Train Loss: 0.357. Valid Loss: 0.355. Valid Accuracy: 0.85.
Epoch 27: Train Loss: 0.342. Valid Loss: 0.354. Valid Accuracy: 0.85.
Epoch 28: Train Loss: 0.323. Valid Loss: 0.354. Valid Accuracy: 0.86.
Epoch 29: Train Loss: 0.316. Valid Loss: 0.346. Valid Accuracy: 0.86.
Epoch 30: Train Loss: 0.318. Valid Loss: 0.363. Valid Accuracy: 0.85.
Epoch 31: Train Loss: 0.320. Valid Loss: 0.349. Valid Accuracy: 0.86.
Epoch 32: Train Loss: 0.303. Valid Loss: 0.326. Valid Accuracy: 0.87.
Epoch 33: Train Loss: 0.300. Valid Loss: 0.331. Valid Accuracy: 0.87.
Epoch 34: Train Loss: 0.292. Valid Loss: 0.321. Valid Accuracy: 0.87.
Epoch 35: Train Loss: 0.295. Valid Loss: 0.343. Valid Accuracy: 0.87.
Epoch 36: Train Loss: 0.292. Valid Loss: 0.340. Valid Accuracy: 0.87.
Epoch 37: Train Loss: 0.291. Valid Loss: 0.352. Valid Accuracy: 0.86.
Epoch 38: Train Loss: 0.292. Valid Loss: 0.316. Valid Accuracy: 0.87.
Epoch 39: Train Loss: 0.273. Valid Loss: 0.302. Valid Accuracy: 0.88.
Epoch 40: Train Loss: 0.272. Valid Loss: 0.291. Valid Accuracy: 0.88.
# Save model
PATH = "models/eva_cnn_augmented.pt"
torch.save(model_aug.state_dict(), PATH)

Let’s try predicting this one again:

image_rotated
(Image: the rotated test image again.)
image_tensor_flipped = transforms.functional.to_tensor(image_rotated.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0)
model_aug.eval()  # switch to evaluation mode so dropout is inactive at inference time
with torch.no_grad():
    prediction = int(torch.sigmoid(model_aug(image_tensor_flipped.to(device))) > 0.5)
print(f"Prediction: {train_dataset.classes[prediction]}")
Prediction: eva

Got it now!
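
To wrap up, here is a small helper (not part of the original notebook) that compares the two models side by side on the upright and rotated test images. The exact predictions depend on the training runs above, but the augmented model should handle the rotated image better:

def predict(net, img):
    """Return the predicted class label for a PIL image."""
    t = transforms.functional.to_tensor(img.resize((IMAGE_SIZE, IMAGE_SIZE))).unsqueeze(0).to(device)
    net.eval()
    with torch.no_grad():
        return train_dataset.classes[int(torch.sigmoid(net(t)) > 0.5)]

for name, net in [("no augmentation", model), ("with augmentation", model_aug)]:
    print(f"{name:>17} | upright: {predict(net, image)} | rotated: {predict(net, image_rotated)}")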