
Transfer Learning with PyTorch: Classifying Ants and Bees Using ResNet18


Transfer learning is a powerful technique in deep learning that allows us to leverage pre-trained models to solve new tasks with limited data. In this blog post, we’ll walk through a practical example of transfer learning using PyTorch. We’ll fine-tune a pre-trained ResNet18 model to classify images of ants and bees from the Hymenoptera dataset, downloaded from Kaggle. By the end, you’ll understand how to set up the dataset, apply data transformations, train the model, and visualize predictions—all with a few lines of code!

What is Transfer Learning?

Transfer learning involves taking a model trained on a large, general dataset (like ImageNet) and adapting it to a specific task. Instead of training a neural network from scratch, which requires massive data and computing resources, we start with a pre-trained model and tweak it for our needs. This approach is especially useful when working with small datasets, as it reduces training time and the need for extensive labeled data.

In this example, we’ll use ResNet18, pre-trained on ImageNet, and fine-tune it to distinguish between ants and bees—a binary classification task.
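
The whole idea fits in a few lines. Here is the core pattern in isolation, a minimal sketch of what the full script below applies:

import torch.nn as nn
from torchvision import models

# Load weights learned on ImageNet, freeze the backbone,
# and replace the classifier head with a fresh 2-class layer
model = models.resnet18(weights='IMAGENET1K_V1')
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 2)  # only this layer will be trained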

Code:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

# Kaggle API setup
os.environ["KAGGLE_USERNAME"] = "tspradeepkumar"  # Replace with your Kaggle username
os.environ["KAGGLE_KEY"] = "740d0138672a033f5e2020390c3cb021"  # Replace with your Kaggle API key

from kaggle.api.kaggle_api_extended import KaggleApi

# Enable CuDNN benchmarking for performance
cudnn.benchmark = True
plt.ion()  # Enable interactive mode for matplotlib

# Download Hymenoptera dataset from Kaggle if not already present
data_dir = './hymenoptera'
if not os.path.exists(os.path.join(data_dir, 'train')):
    print("Downloading Hymenoptera dataset from Kaggle...")
    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files('thedatasith/hymenoptera', path=data_dir, unzip=True)
    print("Dataset downloaded and extracted.")
else:
    print("Hymenoptera dataset already exists.")

# Data transformations for training and validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
}

# Load datasets and create dataloaders
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Helper to show images
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

# Show a batch of training data
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    since = time.time()
    with TemporaryDirectory() as tempdir:
        best_model_path = os.path.join(tempdir, 'best_model.pt')
        torch.save(model.state_dict(), best_model_path)
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # training mode: enable dropout, update batch-norm stats
                else:
                    model.eval()   # eval mode: use running stats, disable dropout
                running_loss = 0.0
                running_corrects = 0

                for inputs, labels in dataloaders[phase]:
                    inputs, labels = inputs.to(device), labels.to(device)
                    optimizer.zero_grad()

                    # Track gradients only in the training phase
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # loss.item() is the batch mean, so scale by batch size before summing
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_path)
            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')
        # Restore the best weights (weights_only=True requires PyTorch >= 1.13)
        model.load_state_dict(torch.load(best_model_path, weights_only=True))
    return model

# Visualization of model predictions
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images // 2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'Predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

# Predict a single image
def visualize_model_predictions(model, img_path):
    model.eval()
    img = Image.open(img_path).convert('RGB')  # ensure 3 channels for the normalization transform
    img = data_transforms['val'](img)
    img = img.unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img)
        _, preds = torch.max(outputs, 1)
        plt.figure()
        plt.title(f'Predicted: {class_names[preds[0]]}')
        imshow(img.cpu().data[0])

# Load pretrained ResNet18 and modify for binary classification
model_conv = models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
model_conv = model_conv.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# Train the model
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=5)

# Visualize model performance
visualize_model(model_conv)

# Predict on a specific image
visualize_model_predictions(model_conv, os.path.join(data_dir, 'val', 'bees', 'bees2.jpg'))

plt.ioff()
plt.show()

In this post, we demonstrated transfer learning with PyTorch using a pre-trained ResNet18 model. By fine-tuning only the final layer, we achieved good performance on the Hymenoptera dataset with minimal training time. The approach is versatile: try different models or datasets, or unfreeze more layers for even better results, as sketched below.
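
As an example of that last suggestion, here is a minimal sketch of partial fine-tuning, reusing the model_conv, criterion, and train_model defined above. It unfreezes layer4 (ResNet18's last residual stage); the per-group learning rates are illustrative assumptions, not tuned values:

# Unfreeze the last residual stage and fine-tune it together with the head
for param in model_conv.layer4.parameters():
    param.requires_grad = True

# Use a smaller learning rate for the unfrozen backbone than for the new head
optimizer_ft = optim.SGD([
    {'params': model_conv.layer4.parameters(), 'lr': 1e-4},
    {'params': model_conv.fc.parameters(), 'lr': 1e-3},
], momentum=0.9)
scheduler_ft = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_ft, scheduler_ft, num_epochs=5)

Because SGD only updates parameters with requires_grad=True, the rest of the backbone stays frozen while layer4 and the head adapt to the new task.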

Sample Output
[Figures: four "Transfer Learning" screenshots showing the script's visual output]



