Transfer Learning on Cats-Dogs Classification - Feature Extraction
Features are extracted from a MobileNet-V2 model pre-trained on ImageNet data, then passed through a new classification head to classify cats vs. dogs.
Adapted from https://www.tensorflow.org/tutorials/images/transfer_learning
CIML Summer Institute
UC San Diego
Setup
In [ ]:
# --- Suppress Python warnings to keep notebook output clean ---
import warnings
warnings.filterwarnings("ignore")
In [ ]:
import os
import random
import shutil
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# TODO: Import PyTorch Lightning as 'pl'
==> YOUR CODE HERE
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from sklearn.metrics import (
accuracy_score,
classification_report,
f1_score,
precision_score,
recall_score
)
from torch.utils.data import ConcatDataset, DataLoader
from torchmetrics.functional import accuracy
from torchvision import datasets, transforms
from torchvision.io import read_image
from torchvision.utils import make_grid
import torchvision
plt.rcParams["figure.facecolor"] = "white"
In [ ]:
# Set global random seed for reproducibility
def set_seed(seed=1234):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), asks cuDNN
    for deterministic kernels, and finally calls ``pl.seed_everything``.

    Args:
        seed: Integer seed passed to every generator. Defaults to 1234.
    """
    # Always "0", independent of `seed`: disable hash randomization
    os.environ["PYTHONHASHSEED"] = "0"
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    pl.seed_everything(seed, workers=False)
# TODO: Call set_seed() to set all seeds
==> YOUR CODE HERE
In [ ]:
!jupyter --version
!python --version
In [ ]:
# Report the framework versions this notebook is running against
print(
    f"PyTorch version: {torch.__version__}\n"
    f"PyTorch Lightning version: {pl.__version__}"
)
In [ ]:
# TODO: Use linux shell command nvidia-smi to see GPU device.
==> YOUR CODE HERE
In [ ]:
from os.path import expanduser

# --- Filesystem locations ---
HOME = expanduser("~")
DATA_DIR = f"{HOME}/data/catsVsDogs"  # expects train/, val/ and test/ subfolders
CHECKPOINT_DIR = "models/feature_extraction"

# --- Image and training hyperparameters ---
IMG_DIM = 224  # images are resized to IMG_DIM x IMG_DIM
BATCH_SIZE = 16
ROTATION_DEGREES = 72  # NOTE(review): not referenced by the code in this notebook
LEARNING_RATE = 0.0001
NUM_CPUS = 4  # number of DataLoader worker processes

# --- Compute device: GPU when available, otherwise CPU ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Define Preprocessing and Data Augmentation
In [ ]:
class Preprocess(nn.Module):
def __init__(self, rescale: bool = False) -> None:
super().__init__()
self.rescale = rescale
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.rescale:
self.transform = transforms.Compose(
[
transforms.Resize(
size=(IMG_DIM, IMG_DIM),
interpolation=transforms.InterpolationMode.BILINEAR,
),
# Rescale to [-1, 1] range:
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
)
else:
self.transform = transforms.Compose(
[
# TODO: Add a transform here that resizes the image with bilinear interpolation
# HINT: look at the transforms used above in the rescale block
==> YOUR CODE HERE
transforms.PILToTensor(),
]
)
return self.transform(x)
class DataAugmentation(nn.Module):
def __init__(self) -> None:
super().__init__()
self.transforms = nn.Sequential(
transforms.RandomAffine(degrees=0, shear=0.2), # Shear
transforms.RandomResizedCrop(
size=IMG_DIM,
scale=(0.8, 1.2),
interpolation=transforms.InterpolationMode.NEAREST,
), # Zoom
# TODO: Add a transform here that randomly flips images horizontally
# HINT: use torchvision.transforms.RandomHorizontalFlip()
==> YOUR CODE HERE
)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.transforms(x)
Define Data Module
In [ ]:
class CatsDogsData(pl.LightningDataModule):
"""Cats and dogs dataset class."""
def __init__(
self,
augment: bool = False,
rescale: bool = False,
batch_size: int = BATCH_SIZE,
data_dir: str = DATA_DIR,
):
super().__init__()
self.augment = augment
self.rescale = rescale
self.batch_size = batch_size
self.preprocess = Preprocess(self.rescale)
self.transform = DataAugmentation()
def prepare_data(self):
"""Load data and apply transforms."""
if self.augment:
train_orig = datasets.ImageFolder(
root=os.path.join(DATA_DIR, "train"), transform=self.preprocess
)
train_aug = datasets.ImageFolder(
root=os.path.join(DATA_DIR, "train"),
transform=transforms.Compose([self.preprocess, self.transform]),
)
self.train_data = ConcatDataset([train_orig, train_aug])
else:
self.train_data = datasets.ImageFolder(
root=os.path.join(DATA_DIR, "train"), transform=self.preprocess
)
self.val_data = datasets.ImageFolder(
root=os.path.join(DATA_DIR, "val"), transform=self.preprocess
)
self.test_data = datasets.ImageFolder(
root=os.path.join(DATA_DIR, "test"), transform=self.preprocess
)
def show_batch(self, win_size=(10, 10)):
def _to_viz(data):
return make_grid(data).permute(2, 1, 0)
imgs, labels = (next(iter(self.train_dataloader())))
plt.figure(figsize=win_size)
plt.imshow(_to_viz(imgs))
def train_dataloader(self):
"""Train DataLoader."""
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=NUM_CPUS)
def val_dataloader(self):
"""Validation DataLoader."""
# TODO: Add a data loader here for the validation dataset. No shuffling is needed.
==> YOUR CODE HERE
def test_dataloader(self):
"""Test DataLoader."""
return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, num_workers=NUM_CPUS)
Define Model
In [ ]:
class MobileNetV2Model(pl.LightningModule):
"""MobileNetV2 model class."""
def __init__(self):
super().__init__()
self.model = torch.hub.load(
"pytorch/vision:v0.10.0", "mobilenet_v2", weights=torchvision.models.MobileNet_V2_Weights.DEFAULT, progress=False
)
self.accuracy = torchmetrics.Accuracy(task="binary").to(DEVICE)
# Freeze the pretrained model weights
for param in self.model.parameters():
# TODO: Freeze all weights in the pre-trained model by setting requires_grad
# for each parameter to False
==> YOUR CODE HERE
# Top model
self.model.pooling = nn.AdaptiveAvgPool2d(output_size=1)
self.model.classifier = nn.Sequential(
nn.Dropout(p=0.2), nn.Linear(1280, 1), nn.Sigmoid()
)
def forward(self, x):
return self.model(x.float())
def configure_optimizers(self):
return torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
def training_step(self, batch: int, batch_idx: int):
"""Training step. Perform data augmentation if specified."""
X, y = batch
logits = self(X.float()).view(-1)
train_loss = F.binary_cross_entropy(logits, y.float())
pred = logits >= 0.5
train_acc = self.accuracy(pred, y)
self.log("train_loss", train_loss, on_epoch = True, on_step=False, prog_bar=True)
self.log("train_acc", train_acc, on_epoch = True, on_step=False, prog_bar=True)
return train_loss
def validation_step(self, batch: int, batch_idx: int):
"""Validation step. Perform data augmentation if specified."""
X, y = batch
logits = self(X.float()).view(-1)
valid_loss = F.binary_cross_entropy(logits, y.float())
pred = logits >= 0.5
valid_acc = self.accuracy(pred, y)
self.log("val_loss", valid_loss, prog_bar=True)
self.log("val_acc", valid_acc, prog_bar=True)
return valid_loss
def test_step(self, batch: int, batch_idx: int):
"""Test step. Perform data augmentation if specified."""
X, y = batch
logits = self(X.float()).view(-1)
test_loss = F.binary_cross_entropy(logits, y.float())
pred = logits >= 0.5
test_acc = self.accuracy(pred, y)
return test_loss
Train Model
In [ ]:
# Prepare data and helpers
data = CatsDogsData(augment=True, rescale=True)
# TODO: Prepare the dataset by calling the appropriate method from the CatsDogsData class
# HINT: Look at the CatsDogsData class above.
==> YOUR CODE HERE
def get_predict(model, data_loader):
    """Get predictions from model and DataLoader.

    Args:
        model: Callable returning one probability per input image.
        data_loader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        A ``(true_values, predicted_values)`` tuple of flat Python lists:
        the integer labels and the boolean predictions (output >= 0.5).
    """
    # Send inputs to the device the model lives on; fall back to the
    # notebook-level DEVICE when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = DEVICE

    true_values = []
    predicted_values = []
    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        for imgs, labels in data_loader:
            outputs = model(imgs.to(device))
            # Store plain Python values rather than 0-d tensors
            true_values.extend(labels.cpu().tolist())
            predicted_values.extend((outputs >= 0.5).view(-1).cpu().tolist())
    return true_values, predicted_values
In [ ]:
# Number of passes over the training data
num_epochs = 5

# Build the data module and load the train/val/test datasets
data = CatsDogsData(augment=True, rescale=True)
data.prepare_data()

# Checkpoint callback: keep only the weights with the lowest validation loss
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename=f"{num_epochs}_{{epoch:02d}}-{{step}}",
    monitor="val_loss",
    mode="min",
    save_weights_only=True,
    save_top_k=1,
)

# Trainer: use the GPU when one is available, otherwise the CPU
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
progress_bar = TQDMProgressBar(refresh_rate=50)
trainer = pl.Trainer(
    accelerator=accelerator,
    max_epochs=num_epochs,
    callbacks=[checkpoint, progress_bar],
)

# Model to be fitted
model = MobileNetV2Model()
# TODO: Train the model using the model and data
# HINT: Use trainer.fit(...)
==> YOUR CODE HERE
In [ ]:
# Save the trainer's current model weights next to the callback checkpoints.
# NOTE(review): despite the file name, this writes the trainer's current state,
# which is not necessarily the lowest-val_loss checkpoint tracked by `checkpoint`.
model_path = os.path.join(CHECKPOINT_DIR, "best_model.ckpt")
trainer.save_checkpoint(model_path, weights_only=True)
Evaluate Model
In [ ]:
# Reload the saved weights, move the model to DEVICE and freeze it for inference
model = MobileNetV2Model.load_from_checkpoint(checkpoint_path=model_path).to(DEVICE)
model.freeze()

# Collect true labels and predictions for every split
y_train, pred_train = get_predict(model, data.train_dataloader())
y_val, pred_val = get_predict(model, data.val_dataloader())
y_test, pred_test = get_predict(model, data.test_dataloader())

# Build and print the per-split classification reports
train_report = classification_report(y_train, pred_train, digits=4)
val_report = classification_report(y_val, pred_val, digits=4)
print()
print(f"Train:\n {train_report}")
print(f"Val:\n {val_report}")
# TODO: Get performance numbers for test dataset using classification_report().
==> YOUR CODE HERE
Perform Inference
In [ ]:
# helper method to prepare image for model inference
def image_loader(image_path):
    """Display an image and return it as a single-image batch on DEVICE.

    The image is converted to RGB, shown with matplotlib, then resized to
    ``IMG_DIM x IMG_DIM``, converted to a tensor and normalized with mean 0.5
    and std 0.5 per channel (the same steps as ``Preprocess(rescale=True)``).

    Args:
        image_path: Path of the image file to load.

    Returns:
        The transformed image with a leading batch dimension, on ``DEVICE``.
    """
    transform = transforms.Compose([
        # Use the shared IMG_DIM constant instead of a hard-coded size
        transforms.Resize((IMG_DIM, IMG_DIM)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # Close the file handle as soon as the pixels have been converted
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.axis("off")
    plt.show()
    image = transform(image).unsqueeze(0).to(DEVICE)
    return image
In [ ]:
!ls ~/data/catsVsDogs/test/cats/cat.1070.jpg
In [ ]:
# Classify one test image and print the predicted label with its score
image_path = f"{DATA_DIR}/test/cats/cat.1070.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()
predicted_label = "dog" if img_y_pred >= 0.5 else "cat"
print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# Classify one test image and print the predicted label with its score
image_path = f"{DATA_DIR}/test/dogs/dog.1233.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()
predicted_label = "dog" if img_y_pred >= 0.5 else "cat"
print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# Classify one test image and print the predicted label with its score
image_path = f"{DATA_DIR}/test/cats/cat.1080.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()
predicted_label = "dog" if img_y_pred >= 0.5 else "cat"
print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# Classify one test image and print the predicted label with its score
image_path = f"{DATA_DIR}/test/dogs/dog.1132.jpg"
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()
predicted_label = "dog" if img_y_pred >= 0.5 else "cat"
print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]:
# TODO: Set image_path for cat image 1311
==> YOUR CODE HERE
# Classify the image chosen above and print the predicted label with its score
img = image_loader(image_path)
with torch.no_grad():
    img_y_pred = model(img).item()
predicted_label = "dog" if img_y_pred >= 0.5 else "cat"
print()
print(f"Prediction for {image_path}: \n{predicted_label} ({img_y_pred:.4f})")
In [ ]: