Load and re-use a PyTorch model
Prerequisites:
- Python >= 3.9
- A Code Environment with the `torch` and `torchvision` packages
Machine learning use cases can involve a lot of input data and compute-heavy, and thus expensive, model training. For common tasks like processing images or text, or during your initial experiments, you might not want to retrain a model from scratch. Instead, you can load pre-trained models retrieved from remote repositories and use them to generate predictions.
In this tutorial, you will use Dataiku’s Code Environment resources feature to download and save a pre-trained image classification model from PyTorch Hub. You will then re-use that model to predict the class of a downloaded image.
Loading the pre-trained model
The first step is to download the required assets for your pre-trained model. To do so, go to the Resources screen of your Code Environment, enter the following initialization script, then click Update:
```python
## Base imports
from dataiku.code_env_resources import clear_all_env_vars
from dataiku.code_env_resources import set_env_path
from dataiku.code_env_resources import set_env_var
from dataiku.code_env_resources import grant_permissions

# Import torchvision models
import torchvision.models as models

# Clears all environment variables defined by previously run script
clear_all_env_vars()

## PyTorch
# Set PyTorch cache directory
set_env_path("TORCH_HOME", "pytorch")

# Download pretrained model: automatically managed by PyTorch,
# does not download anything if model is already in TORCH_HOME
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Grant everyone read access to pretrained models in pytorch/ folder
# (by default, PyTorch makes them only readable by the owner)
grant_permissions("pytorch")
```
This script retrieves the ResNet18 model weights from PyTorch Hub and stores them on the Dataiku instance, in the `pytorch/` folder of the Code Environment's resources, which `TORCH_HOME` points to.
Note that the script only needs to run once: afterwards, all users allowed to use the Code Environment can load the pre-trained model without downloading it again.
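The comment in the initialization script says PyTorch does not download anything if the model is already in `TORCH_HOME`. To make that behavior concrete, here is a minimal, standard-library-only sketch of the lookup, assuming PyTorch's usual layout: downloaded weights go to `$TORCH_HOME/hub/checkpoints/`, with `TORCH_HOME` falling back to `$XDG_CACHE_HOME/torch` or `~/.cache/torch` when unset. It is not PyTorch's actual code: `cached_weights` and `fake_fetch` are hypothetical helpers, the temporary directory stands in for the Code Environment's resources, and the torchvision ResNet18 checkpoint URL is used only as an example.

```python
import os
import tempfile
from pathlib import Path
from typing import Callable, List
from urllib.parse import urlparse


def torch_home() -> Path:
    # Assumed convention: $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch
    default = os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")
    return Path(os.path.expanduser(os.getenv("TORCH_HOME", default)))


def cached_weights(url: str, fetch: Callable[[str, Path], None]) -> Path:
    """Return the local checkpoint for `url`, calling `fetch` only when it is missing."""
    checkpoints = torch_home() / "hub" / "checkpoints"
    checkpoints.mkdir(parents=True, exist_ok=True)
    # The checkpoint file name is the last component of the URL path
    target = checkpoints / os.path.basename(urlparse(url).path)
    if not target.exists():
        fetch(url, target)  # only reached when the file is not cached yet
    return target


downloads: List[str] = []


def fake_fetch(url: str, target: Path) -> None:
    # Hypothetical stand-in for the real download: records the call, writes a placeholder
    downloads.append(url)
    target.write_bytes(b"placeholder weights")


with tempfile.TemporaryDirectory() as resources:
    # Simulates set_env_path("TORCH_HOME", "pytorch") with a temporary resources directory
    os.environ["TORCH_HOME"] = os.path.join(resources, "pytorch")
    url = "https://download.pytorch.org/models/resnet18-f37072fd.pth"  # example URL
    first = cached_weights(url, fake_fetch)   # triggers the (fake) download
    second = cached_weights(url, fake_fetch)  # served from the cache
    print(first.relative_to(resources), len(downloads))
```

The second call finds the file and skips `fetch`, which is why later users of the Code Environment do not trigger a new download as long as they request the same weights.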
Using the pre-trained model for inference
You can now re-use this pre-trained model in a Python recipe or notebook of your Dataiku project, as long as it runs with this Code Environment. Here is an example, adapted from a tutorial on the PyTorch website, that downloads a sample image and predicts its class among the ImageNet labels:
```python
import torch
from PIL import Image
from torchvision import models, transforms

# Load the pre-trained model from the TORCH_HOME cache & set it to evaluation mode
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()

# Download an example image from PyTorch Hub (it's a doggie, but what kind?)
img_url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
img_file = "dog.jpg"
torch.hub.download_url_to_file(img_url, img_file)

# Pre-process the image & create a mini-batch as expected by the model
input_image = Image.open(img_file)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # shape: [1, 3, 224, 224]

# Run inference without tracking gradients; output holds unnormalized scores, shape [1, 1000]
with torch.no_grad():
    output = model(input_batch)

# Softmax over the 1,000 class scores of the single image to get probabilities
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Download ImageNet class labels
classes_url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
classes_file = "imagenet_classes.txt"
torch.hub.download_url_to_file(classes_url, classes_file)

# Map predictions to class labels and print the top 5 predicted classes
with open(classes_file, "r") as f:
    categories = [s.strip() for s in f.readlines()]
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())
```
Running this code should give you an output similar to this:
```
Samoyed 0.8846230506896973
Arctic fox 0.0458049401640892
white wolf 0.044276054948568344
Pomeranian 0.00562133826315403
Great Pyrenees 0.004651993978768587
```
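These probabilities come from a softmax over the 1,000 class scores of the single image in the batch, followed by a top-k selection. The minimal, dependency-free sketch below illustrates why the softmax has to run along the class axis: the model returns a batch of shape `[1, 1000]`, and normalizing along the batch axis instead turns every score into `1.0`. The scores are invented for illustration and only five hypothetical classes are used; `softmax` and `top_k` are plain-Python stand-ins for `torch.nn.functional.softmax` and `torch.topk`.

```python
import math
from typing import List, Tuple


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max for numerical stability, then normalize exponentials to sum to 1
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: List[float], labels: List[str], k: int) -> List[Tuple[str, float]]:
    # Pair each label with its probability and keep the k largest, like torch.topk
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up scores for a batch containing ONE image and five hypothetical classes
labels = ["Samoyed", "Arctic fox", "white wolf", "Pomeranian", "Great Pyrenees"]
batch = [[4.0, 1.0, 0.9, -1.1, -1.3]]

# Correct: normalize across the classes of the single image
probs = softmax(batch[0])

# Wrong: normalizing each column across the batch axis (a single element) yields 1.0
per_column = [softmax([row[j] for row in batch])[0] for j in range(len(labels))]

for label, p in top_k(probs, labels, 3):
    print(f"{label} {p:.4f}")
```

With a batch of one image, the batch-axis version cannot distinguish classes at all, whereas the class-axis version yields a ranking whose probabilities sum to 1.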
Your pre-trained model is now operational! From there you can easily reuse it, e.g. to directly classify other images stored in a Managed Folder or to fine-tune it for a more specific task.