Transfer Learning on Cats-Dogs Classification - Fine Tune¶

Fine-tune pre-trained CNN's top layers and classification layers to classify cats vs. dogs.¶

Adapted from https://www.tensorflow.org/tutorials/images/transfer_learning¶

CIML Summer Institute¶

UC San Diego¶

Setup¶

In [ ]:
import os
import random
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchmetrics
from pytorch_lightning import callbacks as pl_callbacks
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from sklearn.metrics import classification_report
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from torchmetrics.functional import accuracy
from torchvision import datasets, models, transforms
# from torchsummary import summary
import torchvision
In [ ]:
# Set global random seed for reproducibility

def set_seed(seed=1234):
    """Seed every random number generator used in this notebook for reproducibility.

    Args:
        seed: integer seed applied to Python, NumPy, PyTorch, and Lightning RNGs.
    """
    # NOTE: setting PYTHONHASHSEED at runtime only affects subprocesses, not the
    # current interpreter — kept for subprocess workers.
    os.environ["PYTHONHASHSEED"] = str(0)  # disable hash randomization
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seed ALL GPUs, not just the current device
    torch.backends.cudnn.deterministic = True  # deterministic cuDNN kernels
    pl.seed_everything(seed, workers=False)

set_seed()
In [ ]:
# Print environment/version information (lines starting with ! are IPython shell magics)
!jupyter --version
print (pl.__version__)
print (torch.__version__)
!python --version

# Show GPU availability and driver status
!nvidia-smi
In [ ]:
from os.path import expanduser
HOME = expanduser("~")

# Path to the cats-vs-dogs data in the home directory, under 'data/catsVsDogs'
DATA_DIR = os.path.join(HOME, "data", "catsVsDogs")

CHECKPOINT_DIR = "models/finetune"  # where fine-tuned checkpoints are written
NUM_CPUS = 4  # DataLoader worker processes

# Use the GPU when available, otherwise fall back to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
In [ ]:
IMAGE_DIM = 224  # square input image size used by the resize transforms
MEAN = (0.5, 0.5, 0.5)  # per-channel normalization mean
STD = (0.5, 0.5, 0.5)  # per-channel normalization std
BATCH_SIZE = 16  # mini-batch size for all DataLoaders
LEARNING_RATE = 1e-5  # learning rate for the Adam optimizer (small, for fine-tuning)

Define Transforms¶

These are the same transforms used in the feature-extraction notebook, expressed in a different format.

In [ ]:
# Image transforms: augmentation for training, deterministic resize for val/test
transform = {
    "train": transforms.Compose(
        [
            transforms.Resize(
                size=(IMAGE_DIM, IMAGE_DIM),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD),
            transforms.RandomAffine(degrees=0, shear=0.2),  # Shear
            transforms.RandomResizedCrop(
                size=IMAGE_DIM,
                scale=(0.8, 1.2),
                interpolation=transforms.InterpolationMode.NEAREST,
            ),  # Zoom

            # Randomly flip images horizontally (p=0.5 by default)
            transforms.RandomHorizontalFlip(),
        ]
    ),
    "val": transforms.Compose(
        [
            # Deterministic resize with bilinear interpolation (no augmentation)
            transforms.Resize(
                size=(IMAGE_DIM, IMAGE_DIM),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),

            transforms.ToTensor(),
            transforms.Normalize(mean=MEAN, std=STD),
        ]
    ),
}

Define Data Module¶

In [ ]:
class CatsDogsData(pl.LightningDataModule):
    """Lightning data module wrapping the cats-vs-dogs ImageFolder splits.

    Expects `data_dir` to contain 'train', 'val', and 'test' subdirectories,
    each organized as one folder per class.
    """

    def __init__(self, data_dir=DATA_DIR, batch_size=BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform  # module-level dict with 'train'/'val' pipelines

    def prepare_data(self):
        """Build the train/val/test ImageFolder datasets (val transforms reused for test)."""
        self.train_data = datasets.ImageFolder(
            root=os.path.join(self.data_dir, "train"), transform=self.transform["train"]
        )
        self.val_data = datasets.ImageFolder(
            root=os.path.join(self.data_dir, "val"), transform=self.transform["val"]
        )
        self.test_data = datasets.ImageFolder(
            root=os.path.join(self.data_dir, "test"), transform=self.transform["val"]
        )

    def train_dataloader(self):
        """Train DataLoader (shuffled; multiple workers speed up loading)."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=NUM_CPUS)

    def val_dataloader(self):
        """Validation DataLoader (shuffling is not needed for validation)."""
        return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False, num_workers=NUM_CPUS)

    def test_dataloader(self):
        """Test DataLoader (no shuffling)."""
        return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, num_workers=NUM_CPUS)

Define Model¶

In [ ]:
class MobileNetV2Model(pl.LightningModule):
    """MobileNetV2 fine-tuning model.

    Loads an ImageNet-pretrained MobileNetV2, freezes the lower layers, and
    replaces the classifier with a single sigmoid output for binary
    cats-vs-dogs classification.
    """

    def __init__(self):
        super().__init__()
        self.automatic_optimization = True  # let Lightning manage the optimizer step

        # Pre-trained ImageNet backbone
        self.model = torch.hub.load(
            "pytorch/vision:v0.10.0", "mobilenet_v2", weights=torchvision.models.MobileNet_V2_Weights.DEFAULT, progress=False
        )

        self.accuracy = torchmetrics.Accuracy(task="binary").to(DEVICE)

        # Freeze weights up to layer 116 so only the top layers are fine-tuned
        for i, param in enumerate(self.model.parameters()):
            if i <= 116:
                param.requires_grad = False

        # Top model: global pooling + binary classification head
        self.model.pooling = nn.AdaptiveAvgPool2d(output_size=1)
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2),   # regularization
            nn.Linear(1280, 1),  # map 1280 MobileNetV2 features to 1 output
            nn.Sigmoid(),        # probability for binary classification
        )

    def forward(self, x):
        """Return sigmoid probabilities for a batch of images."""
        return self.model(x.float())

    def configure_optimizers(self):
        """Adam over the unfrozen parameters with a linear warmup schedule."""
        opt = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()), lr=LEARNING_RATE
        )
        scheduler = LinearLR(opt, start_factor=0.5, total_iters=3)
        return [opt], [scheduler]

    def training_step(self, batch, batch_idx: int):
        """Training step: binary cross-entropy loss + accuracy, logged per epoch."""
        X, y = batch
        logits = self(X.float()).view(-1)
        train_loss = F.binary_cross_entropy(logits, y.float())
        pred = logits >= 0.5
        train_acc = self.accuracy(pred, y)

        self.log("train_loss", train_loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("train_acc", train_acc, prog_bar=True, on_epoch=True, on_step=False)
        return train_loss

    def validation_step(self, batch, batch_idx: int):
        """Validation step: logs val_loss/val_acc (monitored by early stopping and checkpointing)."""
        X, y = batch
        logits = self(X.float()).view(-1)
        valid_loss = F.binary_cross_entropy(logits, y.float())
        pred = logits >= 0.5
        valid_acc = self.accuracy(pred, y)

        self.log("val_loss", valid_loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_acc", valid_acc, prog_bar=True, on_epoch=True, on_step=False)

        return valid_loss

    def test_step(self, batch, batch_idx: int):
        """Test step: computes loss/accuracy on the test split."""
        X, y = batch
        logits = self(X.float()).view(-1)
        test_loss = F.binary_cross_entropy(logits, y.float())
        pred = logits >= 0.5
        test_acc = self.accuracy(pred, y)
        return test_loss

Train Model¶

In [ ]:
# Define max epochs
num_epochs = 30

# Define early stopping callback: stop when val_loss stops improving
early_stop = pl_callbacks.EarlyStopping(
    monitor="val_loss", patience=3, min_delta=1e-3, verbose=True, mode="min"
)

data = CatsDogsData()

# Prepare the dataset (builds the train/val/test ImageFolder datasets)
data.prepare_data()


def get_predict(model, data_loader):
    """Get predictions from model and DataLoader.

    Returns:
        (true_values, predicted_values): labels and boolean predictions
        obtained by thresholding the sigmoid output at 0.5.
    """
    true_values = []
    predicted_values = []
    with torch.no_grad():  # inference only — avoid building autograd graphs
        for imgs, labels in data_loader:
            imgs = imgs.to(DEVICE)
            outputs = model(imgs)
            true_values.extend(labels)
            predicted_values.extend((outputs >= 0.5).view(-1).cpu().numpy())

    return true_values, predicted_values


# Define model checkpoint callback: keep the single best model by val_loss
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename=str(num_epochs) + "_{epoch:02d}-{step}",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
    save_top_k=1,
    verbose=True,
)

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=num_epochs,
    callbacks=[checkpoint, early_stop, TQDMProgressBar(refresh_rate=50)]
)

# Load model weights saved by the feature-extraction notebook
# NOTE(review): confirm this path matches where feature extraction saved best_model.ckpt
model_path = os.path.join("models", "feature_extraction", "best_model.ckpt")

model = MobileNetV2Model.load_from_checkpoint(
    checkpoint_path=model_path,
    strict=False,  # classifier head may differ from the saved checkpoint
)

# summarize model
model = model.to(DEVICE)
# summary(model, (3, 224, 224))
In [ ]:
# Fit model and get best model path

# Train the model on the cats-vs-dogs data module
trainer.fit(model, data)

best_model_path = checkpoint.best_model_path
print(f"Best model saved at: {best_model_path}")

Evaluate Model¶

In [ ]:
# Reload the best checkpoint and freeze it for inference
model = MobileNetV2Model.load_from_checkpoint(checkpoint_path=best_model_path)
model = model.to(DEVICE)
model.freeze()

# Make predictions
y_train, pred_train = get_predict(model, data.train_dataloader())
y_val, pred_val = get_predict(model, data.val_dataloader())

# Predictions and classification report on the test set
y_test, pred_test = get_predict(model, data.test_dataloader())
print(f"Test:\n {classification_report(y_test, pred_test, digits=4)}")
In [ ]:
print(checkpoint.best_model_path)
In [ ]:
# Precision/recall/F1 reports for all three splits
print(f"Train:\n {classification_report(y_train, pred_train, digits=4)}")
print(f"Val:\n {classification_report(y_val, pred_val, digits=4)}")
print(f"Test:\n {classification_report(y_test, pred_test, digits=4)}")

Perform Inference¶

In [ ]:
model = model.to(DEVICE)
In [ ]:
# helper method to prepare image for model inference
def image_loader(image_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    image = Image.open(image_path).convert("RGB")
    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.axis("off")
    plt.show()

    image = transform(image).unsqueeze(0).to(DEVICE)
    return image
In [ ]:
image_path = DATA_DIR + "/test/cats/cat.1070.jpg"
img = image_loader(image_path)
with torch.no_grad():  # no gradients needed for inference
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")

# The closer the prediction is to 0, the more confident the model is it's a cat;
# the closer it is to 1, the more confident it is a dog.
In [ ]:
image_path = DATA_DIR + "/test/dogs/dog.1233.jpg"
img = image_loader(image_path)

# Run the model on the image and get the predicted probability (no gradients for inference)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
# More single-image predictions on test images (same pattern as the cells above).
image_path = DATA_DIR + "/test/cats/cat.1080.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/dogs/dog.1132.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/dogs/dog.1311.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/cats/cat.1338.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/cats/cat.1342.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/cats/cat.1180.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/cats/cat.1048.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/dogs/dog.1342.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
image_path = DATA_DIR + "/test/dogs/dog.1308.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()

print()
print(f"Prediction for {image_path}: \n{'dog' if img_y_pred >= 0.5 else 'cat'} ({img_y_pred:.4f})")
In [ ]:
 
In [ ]: