Transfer Learning on Cats-Dogs Classification - Feature Extraction

Features are extracted from a MobileNet-V2 model pre-trained on ImageNet data, then passed through a new classification head to classify cats vs. dogs.

Adapted from https://www.tensorflow.org/tutorials/images/transfer_learning

CIML Summer Institute

UC San Diego

Setup

In [ ]:
# --- Suppress Python warnings ---
# Installs a global "ignore" filter, so every warning raised through the
# `warnings` module after this point is silenced.

import warnings
warnings.filterwarnings("ignore")
In [ ]:
import os
import random
import shutil
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# TODO: Import PyTorch Lightning as 'pl'
==> YOUR CODE HERE

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score
)
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics.functional import accuracy
from torchvision import datasets, transforms
from torchvision.io import read_image
from torchvision.utils import make_grid
import torchvision

# Default every matplotlib figure to a white background.
plt.rcParams["figure.facecolor"] = "white"
In [ ]:
# Set global random seed for reproducibility

def set_seed(seed=1234):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device), requests deterministic cuDNN kernels, and finally calls
    ``pl.seed_everything`` with the same seed and ``workers=False``.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Always pinned to "0", independently of ``seed`` (disables hash
    # randomization).
    # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so
    # this presumably only affects processes launched afterwards — confirm.
    os.environ["PYTHONHASHSEED"] = str(0)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    pl.seed_everything(seed, workers=False)

# TODO:  Call set_seed() to set all seeds
==> YOUR CODE HERE
In [ ]:
!jupyter --version
!python --version
In [ ]:
# Report the PyTorch and PyTorch Lightning versions, one per line
version_report = "\n".join(
    [
        f"PyTorch version: {torch.__version__}",
        f"PyTorch Lightning version: {pl.__version__}",
    ]
)
print(version_report)
In [ ]:
# TODO:  Use linux shell command nvidia-smi to see GPU device.  

==> YOUR CODE HERE
In [ ]:
from os.path import expanduser

# --- Filesystem locations ---
HOME = expanduser("~")
DATA_DIR = f"{HOME}/data/catsVsDogs"  # dataset root; holds train/, val/ and test/
CHECKPOINT_DIR = "models/feature_extraction"  # where model checkpoints are written

# --- Image and loader settings ---
IMG_DIM = 224  # side length (pixels) images are resized to
BATCH_SIZE = 16  # default batch size of the data module
ROTATION_DEGREES = 72  # not referenced elsewhere in this notebook
NUM_CPUS = 4  # passed to the DataLoaders as num_workers

# --- Optimizer settings ---
LEARNING_RATE = 0.0001

# --- Compute device: CUDA when PyTorch can see a GPU, CPU otherwise ---
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

Define Preprocessing and Data Augmentation

In [ ]:
class Preprocess(nn.Module):
    def __init__(self, rescale: bool = False) -> None:
        super().__init__()
        self.rescale = rescale

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.rescale:
            self.transform = transforms.Compose(
                [
                    transforms.Resize(
                        size=(IMG_DIM, IMG_DIM),
                        interpolation=transforms.InterpolationMode.BILINEAR,
                    ),
                    # Rescale to [-1, 1] range:
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    # TODO: Add a transform here that resizes the image with bilinear interpolation
                    # HINT: look at the transforms used above in the rescale block
                    ==> YOUR CODE HERE
  
                    transforms.PILToTensor(),
                ]
            )
        return self.transform(x)


class DataAugmentation(nn.Module):
    """Random augmentation pipeline built from torchvision transforms.

    The pipeline is an ``nn.Sequential`` holding a random affine shear
    followed by a random resized crop to ``IMG_DIM``; one more transform is
    left as an exercise (see the TODO below). In ``CatsDogsData`` it is
    composed after ``Preprocess``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.transforms = nn.Sequential(
            # NOTE(review): torchvision documents ``shear`` in degrees, so
            # 0.2 would be a very small shear — confirm this is intended.
            transforms.RandomAffine(degrees=0, shear=0.2),  # Shear
            # NOTE(review): torchvision documents ``scale`` as a fraction of
            # the input area; an upper bound above 1 looks unusual — confirm.
            transforms.RandomResizedCrop(
                size=IMG_DIM,
                scale=(0.8, 1.2),
                interpolation=transforms.InterpolationMode.NEAREST,
            ),  # Zoom
            
            # TODO: Add a transform here that randomly flips images horizontally
            # HINT: use torchvision.transforms.RandomHorizontalFlip()
            ==> YOUR CODE HERE

        )

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the augmentation pipeline (no gradients tracked)."""
        return self.transforms(x)

Define Data Module

In [ ]:
class CatsDogsData(pl.LightningDataModule):
    """Cats and dogs dataset class."""

    def __init__(
        self,
        augment: bool = False,
        rescale: bool = False,
        batch_size: int = BATCH_SIZE,
        data_dir: str = DATA_DIR,
    ):
        super().__init__()
        self.augment = augment
        self.rescale = rescale
        self.batch_size = batch_size
        self.preprocess = Preprocess(self.rescale)
        self.transform = DataAugmentation()

    def prepare_data(self):
        """Load data and apply transforms."""
        if self.augment:
            train_orig = datasets.ImageFolder(
                root=os.path.join(DATA_DIR, "train"), transform=self.preprocess
            )
            train_aug = datasets.ImageFolder(
                root=os.path.join(DATA_DIR, "train"),
                transform=transforms.Compose([self.preprocess, self.transform]),
            )
            self.train_data = ConcatDataset([train_orig, train_aug])
        else:
            self.train_data = datasets.ImageFolder(
                root=os.path.join(DATA_DIR, "train"), transform=self.preprocess
            )
        self.val_data = datasets.ImageFolder(
            root=os.path.join(DATA_DIR, "val"), transform=self.preprocess
        )
        self.test_data = datasets.ImageFolder(
            root=os.path.join(DATA_DIR, "test"), transform=self.preprocess
        )

    def show_batch(self, win_size=(10, 10)):
        def _to_viz(data):
            return make_grid(data).permute(2, 1, 0)

        imgs, labels = (next(iter(self.train_dataloader())))
        plt.figure(figsize=win_size)
        plt.imshow(_to_viz(imgs))

    def train_dataloader(self):
        """Train DataLoader."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=NUM_CPUS)

    def val_dataloader(self):
        """Validation DataLoader."""
        # TODO:  Add a data loader here for the validation dataset. No shuffling is needed.
        ==> YOUR CODE HERE

    def test_dataloader(self):
        """Test DataLoader."""
        return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, num_workers=NUM_CPUS)

Define Model

In [ ]:
class MobileNetV2Model(pl.LightningModule):
    """MobileNetV2 model class."""

    def __init__(self):
        super().__init__()

        self.model = torch.hub.load(
            "pytorch/vision:v0.10.0", "mobilenet_v2", weights=torchvision.models.MobileNet_V2_Weights.DEFAULT, progress=False
        )

        self.accuracy = torchmetrics.Accuracy(task="binary").to(DEVICE)

        # Freeze the pretrained model weights
        for param in self.model.parameters():

            # TODO:  Freeze all weights in the pre-trained model by setting requires_grad 
            #        for each parameter to False
            ==> YOUR CODE HERE

        # Top model
        self.model.pooling = nn.AdaptiveAvgPool2d(output_size=1)
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2), nn.Linear(1280, 1), nn.Sigmoid()
        )


    def forward(self, x):
        return self.model(x.float())

    def configure_optimizers(self):
        return torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    def training_step(self, batch: int, batch_idx: int):
        """Training step. Perform data augmentation if specified."""
        X, y = batch
        logits = self(X.float()).view(-1)
        train_loss = F.binary_cross_entropy(logits, y.float())
        pred = logits >= 0.5
        train_acc = self.accuracy(pred, y)

        self.log("train_loss", train_loss, on_epoch = True, on_step=False, prog_bar=True)
        self.log("train_acc", train_acc, on_epoch = True, on_step=False, prog_bar=True)
        return train_loss

    def validation_step(self, batch: int, batch_idx: int):
        """Validation step. Perform data augmentation if specified."""
        X, y = batch
        logits = self(X.float()).view(-1)
        valid_loss = F.binary_cross_entropy(logits, y.float())
        pred = logits >= 0.5
        valid_acc = self.accuracy(pred, y)

        self.log("val_loss", valid_loss, prog_bar=True)
        self.log("val_acc", valid_acc, prog_bar=True)
        return valid_loss

    def test_step(self, batch: int, batch_idx: int):
        """Test step. Perform data augmentation if specified."""
        X, y = batch
        logits = self(X.float()).view(-1)
        test_loss = F.binary_cross_entropy(logits, y.float())
        pred = logits >= 0.5
        test_acc = self.accuracy(pred, y)

        return test_loss

Train Model

In [ ]:
# Prepare data and helpers

# Data module with the augmented training set and rescaled inputs enabled.
data = CatsDogsData(augment=True, rescale=True)

# TODO: Prepare the dataset by calling the appropriate method from the CatsDogsData class
# HINT: Look at the CatsDogsData class above.
==> YOUR CODE HERE

def get_predict(model, data_loader, device=None):
    """Collect true labels and thresholded predictions over a DataLoader.

    The model is called as-is: put it in evaluation mode (for example with
    ``model.eval()`` or ``model.freeze()``) before calling this function.

    Args:
        model: Callable mapping an image batch to one score per sample;
            scores of 0.5 and above are counted as class 1.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the image batches are moved to before the forward
            pass. Defaults to the module-level ``DEVICE``.

    Returns:
        ``(true_values, predicted_values)``: two flat lists of Python ints
        with one entry per sample, in iteration order.
    """
    if device is None:
        device = DEVICE

    true_values = []
    predicted_values = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for imgs, labels in data_loader:
            outputs = model(imgs.to(device))
            # Plain ints on both sides keep the two lists type-consistent
            # for the sklearn metrics that consume them.
            true_values.extend(labels.reshape(-1).tolist())
            predicted_values.extend((outputs >= 0.5).reshape(-1).int().tolist())

    return true_values, predicted_values
In [ ]:
# Number of training epochs
num_epochs = 5

# Fresh data module with the augmented training set and rescaled inputs
data = CatsDogsData(augment=True, rescale=True)
data.prepare_data()

# Checkpoint callback: keep only the single checkpoint with the lowest
# val_loss, weights only, named "<num_epochs>_<epoch>-<step>"
checkpoint_options = dict(
    dirpath=CHECKPOINT_DIR,
    filename=f"{num_epochs}_{{epoch:02d}}-{{step}}",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
    save_top_k=1,
)
checkpoint = pl.callbacks.ModelCheckpoint(**checkpoint_options)

# Trainer: GPU when PyTorch can see one, CPU otherwise
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"
progress_bar = TQDMProgressBar(refresh_rate=50)
trainer = pl.Trainer(
    accelerator=accelerator,
    max_epochs=num_epochs,
    callbacks=[checkpoint, progress_bar],
)

# Model to be fitted
model = MobileNetV2Model()

# TODO: Train the model using the model and data
# HINT: Use trainer.fit(...)
==> YOUR CODE HERE
In [ ]:
# Saving model weights to path
model_path = "models/feature_extraction/best_model.ckpt"

# The file is called "best_model", so export the checkpoint that the
# ModelCheckpoint callback recorded as having the lowest val_loss rather
# than whatever weights the trainer holds after the final epoch.
best_ckpt_path = checkpoint.best_model_path
if best_ckpt_path and os.path.isfile(best_ckpt_path):
    shutil.copyfile(best_ckpt_path, model_path)
else:
    # No best checkpoint was recorded: fall back to the trainer's current
    # weights.
    trainer.save_checkpoint(model_path, weights_only=True)

Evaluate Model

In [ ]:
# Reload the saved weights, move the model to DEVICE and freeze it for inference
model = MobileNetV2Model.load_from_checkpoint(checkpoint_path=model_path).to(DEVICE)
model.freeze()

# Labels and predictions for each split
train_loader = data.train_dataloader()
y_train, pred_train = get_predict(model, train_loader)
val_loader = data.val_dataloader()
y_val, pred_val = get_predict(model, val_loader)
test_loader = data.test_dataloader()
y_test, pred_test = get_predict(model, test_loader)

# Classification reports for the train and validation splits
print()
for split_name, y_true, y_pred in (
    ("Train", y_train, pred_train),
    ("Val", y_val, pred_val),
):
    print(f"{split_name}:\n {classification_report(y_true, y_pred, digits=4)}")

# TODO: Get performance numbers for test dataset using classification_report().
==> YOUR CODE HERE

Perform Inference

In [ ]:
# helper method to prepare image for model inference
def image_loader(image_path):
    """Load an image file, display it, and return it as a model-ready batch.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        Tensor of shape ``(1, 3, IMG_DIM, IMG_DIM)`` on ``DEVICE``,
        normalized with mean 0.5 and std 0.5 per channel.
    """
    # Same size as the training pipeline (IMG_DIM) so inference inputs stay
    # consistent with what the model was trained on.
    transform = transforms.Compose([
        transforms.Resize((IMG_DIM, IMG_DIM)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # Close the file handle as soon as the pixels are decoded; convert()
    # returns an independent in-memory image.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")

    # Show the untransformed image
    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.axis("off")
    plt.show()

    # unsqueeze(0) adds the batch dimension the model expects
    image = transform(image).unsqueeze(0).to(DEVICE)
    return image
In [ ]:
!ls ~/data/catsVsDogs/test/cats/cat.1070.jpg
In [ ]:
# Classify one cat image from the test folder
image_path = f"{DATA_DIR}/test/cats/cat.1070.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

# Scores of 0.5 and above are reported as "dog", anything lower as "cat"
if img_y_pred >= 0.5:
    predicted_label = "dog"
else:
    predicted_label = "cat"

print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# Classify one dog image from the test folder
image_path = f"{DATA_DIR}/test/dogs/dog.1233.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

# Scores of 0.5 and above are reported as "dog", anything lower as "cat"
if img_y_pred >= 0.5:
    predicted_label = "dog"
else:
    predicted_label = "cat"

print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# Classify one cat image from the test folder
image_path = f"{DATA_DIR}/test/cats/cat.1080.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

# Scores of 0.5 and above are reported as "dog", anything lower as "cat"
if img_y_pred >= 0.5:
    predicted_label = "dog"
else:
    predicted_label = "cat"

print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# Classify one dog image from the test folder
image_path = f"{DATA_DIR}/test/dogs/dog.1132.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

# Scores of 0.5 and above are reported as "dog", anything lower as "cat"
if img_y_pred >= 0.5:
    predicted_label = "dog"
else:
    predicted_label = "cat"

print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# TODO:  Set image_path for cat image 1311
==> YOUR CODE HERE

# Classify the image selected by image_path above
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

# Scores of 0.5 and above are reported as "dog", anything lower as "cat"
if img_y_pred >= 0.5:
    predicted_label = "dog"
else:
    predicted_label = "cat"

print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]: