Load and re-use a PyTorch model

Prerequisites

Introduction

Machine learning use cases can involve large amounts of input data and compute-heavy, and therefore expensive, model training. For common tasks such as image or text processing, or during your initial experiments, you may not want to retrain a model from scratch. Instead, you can load a pre-trained model from a remote repository and use it to generate predictions.

In this tutorial, you will use Dataiku’s Code Environment resources feature to download and save a pre-trained image classification model from PyTorch Hub. You will then re-use that model to predict the class of a downloaded image.

Loading the pre-trained model

The first step is to download the required assets for your pre-trained model. To do so, go to the Resources screen of your Code Environment, enter the following initialization script, then click Update:

## Base imports
from dataiku.code_env_resources import clear_all_env_vars
from dataiku.code_env_resources import set_env_path
from dataiku.code_env_resources import set_env_var
from dataiku.code_env_resources import grant_permissions

# Import torchvision models
import torchvision.models as models

# Clears all environment variables defined by previously run script
clear_all_env_vars()

## PyTorch
# Set PyTorch cache directory
set_env_path("TORCH_HOME", "pytorch")

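# Download pretrained weights: automatically managed by PyTorch,
# does not download anything if the weights are already in TORCH_HOME
# (ResNet18_Weights.DEFAULT replaces the deprecated boolean form weights=True)
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)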

# Grant everyone read access to pretrained models in pytorch/ folder
# (by default, PyTorch makes them only readable by the owner)
grant_permissions("pytorch")

This script retrieves a pre-trained ResNet18 model from PyTorch Hub and stores its weights in the pytorch folder of the Code Environment's resources on the Dataiku instance.

Note that the script only needs to run once: after that, all users allowed to use the Code Environment can leverage the pre-trained model without downloading it again.
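If you want to check where the weights ended up, torch.hub caches downloaded checkpoints under a hub/checkpoints folder of the PyTorch home directory; in a notebook running on this Code Environment, torch.hub.get_dir() returns the hub folder directly. The stdlib-only sketch below reproduces that lookup order as we understand it from recent PyTorch releases ($TORCH_HOME, then $XDG_CACHE_HOME/torch, then ~/.cache/torch). It assumes the Code Environment exports TORCH_HOME as configured by the initialization script; the helper name expected_checkpoint_dir is hypothetical, and overrides made with torch.hub.set_dir() are not modeled.

```python
import os


def expected_checkpoint_dir(env=None):
    """Guess where torch.hub stores downloaded weights (hypothetical helper).

    Assumed lookup order: $TORCH_HOME, else $XDG_CACHE_HOME/torch,
    else ~/.cache/torch. Overrides from torch.hub.set_dir() are not modeled.
    """
    env = os.environ if env is None else env
    torch_home = env.get("TORCH_HOME")
    if torch_home is None:
        cache_root = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
        torch_home = os.path.join(cache_root, "torch")
    # Checkpoints live in <torch_home>/hub/checkpoints
    return os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints")


checkpoint_dir = expected_checkpoint_dir()
print("Expected checkpoint folder:", checkpoint_dir)
if os.path.isdir(checkpoint_dir):
    # List cached weight files, e.g. the ResNet18 checkpoint fetched by the script
    for name in sorted(os.listdir(checkpoint_dir)):
        print(" -", name)
else:
    print("No cached checkpoints found; is TORCH_HOME set for this environment?")
```

If the folder is empty or missing when run from a recipe or notebook, the most likely cause is that the recipe is not using the Code Environment whose resources were initialized.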

Using the pre-trained model for inference

You can now re-use this pre-trained model in a Python recipe or notebook of your Dataiku project, as long as it runs on this Code Environment. Here is an example adapted from a tutorial on the PyTorch website that downloads a sample image and predicts its class from the ImageNet labels.

import torch

from torchvision import models, transforms
from PIL import Image

# Load the pre-trained weights from the TORCH_HOME cache & set evaluation mode
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()

# Download example image from pytorch (it's a doggie, but what kind?)
img_url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
img_file = "dog.jpg"
torch.hub.download_url_to_file(img_url, img_file)

# Pre-process image & create a mini-batch as expected by the model
# (convert to RGB in case the image is grayscale or has an alpha channel)
input_image = Image.open(img_file).convert("RGB")
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: shape (1, 3, 224, 224)

# Run inference without tracking gradients
with torch.no_grad():
    output = model(input_batch)
# The output has unnormalized scores: run softmax to get probabilities
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Download ImageNet class labels
classes_url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
classes_file = "imagenet_classes.txt"
torch.hub.download_url_to_file(classes_url, classes_file)

# Map prediction to class labels and print top 5 predicted classes
with open(classes_file, "r") as f:
    categories = [s.strip() for s in f.readlines()]
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

Running this code should give you an output similar to this:

Samoyed 0.8846230506896973
Arctic fox 0.0458049401640892
white wolf 0.044276054948568344
Pomeranian 0.00562133826315403
Great Pyrenees 0.004651993978768587
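The printed values come out of two stages of the script: preprocessing (Resize(256), CenterCrop(224), ToTensor() and Normalize(...)) and post-processing (softmax, then top-k). The stdlib-only sketch below reproduces the arithmetic of each stage so you can see what the model actually receives and how the scores become probabilities. It is an illustration, not torchvision's implementation: the rounding in resize_shorter_side and center_crop_box is our assumption of torchvision's behavior, resampling and antialiasing are not modeled, and the image size, scores and labels in the usage lines are hypothetical.

```python
import math

# --- Geometry: Resize(256) then CenterCrop(224) ---

def resize_shorter_side(width, height, size=256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return the (left, top, right, bottom) box of a centered crop x crop square."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# --- Per-pixel arithmetic: ToTensor() then Normalize(mean, std) ---

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map an 8-bit (R, G, B) pixel to the normalized values the model sees."""
    scaled = [channel / 255.0 for channel in rgb]  # ToTensor: 0..255 -> 0.0..1.0
    return [(c - m) / s for c, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


# --- Post-processing: softmax over raw scores, then top-k ---

def softmax(scores):
    """Turn unnormalized scores into probabilities that sum to 1."""
    highest = max(scores)  # subtracting the max avoids overflow, same result
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(labels[i], probabilities[i]) for i in ranked[:k]]


# Hypothetical 600x400 (width x height) input image
new_w, new_h = resize_shorter_side(600, 400)
print(new_w, new_h)                   # 384 256
print(center_crop_box(new_w, new_h))  # (80, 16, 304, 240)
print(normalize_pixel((255, 255, 255)))

# Hypothetical raw scores and labels, not real model output
probs = softmax([2.0, 1.0, 0.1, -1.0])
for label, p in top_k(probs, ["label_a", "label_b", "label_c", "label_d"], k=2):
    print(label, round(p, 3))
```

Because softmax preserves the ordering of the raw scores, taking the top 5 of the probabilities gives the same classes as taking the top 5 of the raw output; softmax only makes the values readable as confidences that sum to 1.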

Wrapping up

Your pre-trained model is now operational! From there, you can reuse it, for example to classify other images stored in a Managed Folder, or fine-tune it for a more specific task.