Load and re-use a TensorFlow Hub model

Prerequisites

  • Dataiku >= 10.0.0.

  • A Code Environment with the following packages:

    • tensorflow==2.8.0

    • tensorflow-estimator==2.6.0

    • tensorflow-hub==0.12.0

    • protobuf>=3.20,<3.21

    • requests==2.28.1

    • Pillow==9.3.0
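To confirm that a Code Environment matches the exact `==` pins above, you could run a quick check from a notebook that uses it. The sketch below is only an illustration, not a Dataiku feature. It relies on the standard-library `importlib.metadata` module (Python 3.8+), and the helper name `find_pin_mismatches` is hypothetical. It skips the `protobuf` range, because evaluating version ranges would need an extra library such as `packaging`.

```python
from importlib import metadata
from typing import Callable, Dict, List, Optional

# Exact pins from the prerequisites list; the protobuf range is not checked here.
EXACT_PINS = {
    "tensorflow": "2.8.0",
    "tensorflow-estimator": "2.6.0",
    "tensorflow-hub": "0.12.0",
    "requests": "2.28.1",
    "Pillow": "9.3.0",
}


def _installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_pin_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], Optional[str]] = _installed_version,
) -> List[str]:
    """Hypothetical helper: describe every package whose installed version differs from its pin."""
    problems = []
    for name, expected in pins.items():
        found = get_version(name)
        if found is None:
            problems.append(f"{name}: not installed (expected {expected})")
        elif found != expected:
            problems.append(f"{name}: found {found}, expected {expected}")
    return problems


if __name__ == "__main__":
    for line in find_pin_mismatches(EXACT_PINS) or ["All exact pins match."]:
        print(line)
```

Passing `get_version` as a parameter keeps the comparison logic independent of the environment, so it can be checked with a plain dictionary lookup.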


Machine learning use cases can involve large amounts of input data and compute-heavy, and thus expensive, model training. For common tasks such as pre-processing images or text, you might prefer to reuse existing artifacts instead: you can load pre-trained models from remote repositories and embed them in your code.

In this tutorial, you will use Dataiku’s Code Environment resources feature to download and save a pre-trained image classification model from TensorFlow Hub. You will then reuse it to classify a picture downloaded from the Internet.

Loading the pre-trained model

The first step is to download the required assets for your pre-trained model. To do so, go to the Resources screen of your Code Environment, enter the following initialization script, then click Update:

## Base imports
from dataiku.code_env_resources import clear_all_env_vars
from dataiku.code_env_resources import set_env_path
from dataiku.code_env_resources import set_env_var

# Clear all environment variables defined by a previously run script
clear_all_env_vars()

## TensorFlow
# Set TensorFlow cache directory
set_env_path("TFHUB_CACHE_DIR", "tensorflow")

# Import TensorFlow Hub
import tensorflow_hub as hub

# Download pretrained model: automatically managed by TensorFlow,
# does not download anything if model is already in TFHUB_CACHE_DIR
model_hub_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
hub.KerasLayer(model_hub_url)

This script retrieves a MobileNet v2 model from TensorFlow Hub and stores it on the Dataiku instance, in the resources directory of the Code Environment.

The script only needs to run once: afterwards, every user allowed to use the Code Environment can leverage the pre-trained model without downloading it again.
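To confirm that the initialization script actually cached the model, you could inspect the directory that `TFHUB_CACHE_DIR` points to when the Code Environment is in use. This is a minimal sketch that assumes TensorFlow Hub stores each downloaded model as a SavedModel directory (one containing a `saved_model.pb` file) somewhere under the cache directory; the helper name `find_cached_models` is hypothetical.

```python
import os
from typing import List


def find_cached_models(cache_dir: str) -> List[str]:
    """Hypothetical helper: return every directory under cache_dir that holds a saved_model.pb."""
    found = []
    for root, _dirs, files in os.walk(cache_dir):
        if "saved_model.pb" in files:
            found.append(root)
    return sorted(found)


if __name__ == "__main__":
    # TFHUB_CACHE_DIR is defined by the Code Environment resources script.
    cache_dir = os.environ.get("TFHUB_CACHE_DIR")
    if cache_dir is None:
        print("TFHUB_CACHE_DIR is not set: is this code running in the right Code Environment?")
    else:
        models = find_cached_models(cache_dir)
        print(f"{len(models)} cached model(s) in {cache_dir}")
        for path in models:
            print(" -", path)
```

An empty result after clicking Update would suggest that the script did not run or that the download failed.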

Using the pre-trained model for inference

You can now re-use this pre-trained model in a Python recipe or notebook of your Dataiku project. The following example, adapted from a tutorial on the TensorFlow website, downloads a sample image and predicts its class among the ImageNet labels.

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

from PIL import Image

# Same URL as in the resources script: TensorFlow Hub loads the cached copy from TFHUB_CACHE_DIR
model_name = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"

# Load the pre-trained model as a Keras layer expecting 224x224 RGB images
img_shape = (224, 224)
classifier = tf.keras.Sequential([
    hub.KerasLayer(model_name, input_shape=img_shape+(3,))
])

# Download image and compute prediction
img_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg"
img = tf.keras.utils.get_file('image.jpg', img_url)
img = Image.open(img).resize(img_shape)
img = np.array(img)/255.0
result = classifier.predict(img[np.newaxis, ...])

# Map the prediction result to the corresponding class label
labels_url = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
predicted_class = tf.math.argmax(result[0], axis=-1)
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt', labels_url)
imagenet_labels = np.array(open(labels_path).read().splitlines())
predicted_class_name = imagenet_labels[predicted_class]

print(f"Predicted class name: {predicted_class_name}")

Running this code should give you the following output:

Predicted class name: tiger
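If you want more than the single best guess, you could list the few most likely classes instead of keeping only the `argmax`. The sketch below assumes, as in the example above, that `result[0]` is the model's vector of 1001 raw scores (logits) and that `imagenet_labels` is the matching label array; the helper name `top_k_labels` is hypothetical. It applies a softmax so the scores read as probabilities.

```python
from typing import List, Sequence, Tuple

import numpy as np


def top_k_labels(logits: np.ndarray, labels: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Hypothetical helper: return the k most likely (label, probability) pairs."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # argsort is ascending, so reverse it and keep the first k indices.
    best = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in best]


# Possible use with the variables from the inference example:
# for name, p in top_k_labels(result[0], imagenet_labels):
#     print(f"{name}: {p:.3f}")
```

Because softmax preserves the ordering of the logits, the first entry always matches the class returned by `argmax`.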

Wrapping up

Your pre-trained model is now operational! From here, you can reuse it, for example to classify other images stored in a Managed Folder or to fine-tune it for a more specific task.
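As a possible starting point for classifying images stored in a Managed Folder, the sketch below separates preprocessing from I/O so it can be reused. It assumes the images are available as raw bytes; in a Dataiku Python recipe you might read them with `dataiku.Folder`, `list_paths_in_partition()` and `get_download_stream()`, but the folder name `"images"` and the helper `preprocess_image_bytes` are hypothetical. The preprocessing mirrors the inference example: convert to RGB, resize to 224x224, and scale pixel values to [0, 1].

```python
import io
from typing import Tuple

import numpy as np
from PIL import Image


def preprocess_image_bytes(data: bytes, img_shape: Tuple[int, int] = (224, 224)) -> np.ndarray:
    """Hypothetical helper: decode image bytes into a (224, 224, 3) float array in [0, 1]."""
    # convert("RGB") drops any alpha channel and expands grayscale images to 3 channels.
    img = Image.open(io.BytesIO(data)).convert("RGB")
    img = img.resize(img_shape)  # PIL expects (width, height); square here, so order is moot
    return np.asarray(img, dtype=np.float32) / 255.0


# Possible use inside a Dataiku Python recipe (folder name "images" is hypothetical):
# import dataiku
# folder = dataiku.Folder("images")
# for path in folder.list_paths_in_partition():
#     with folder.get_download_stream(path) as stream:
#         batch = preprocess_image_bytes(stream.read())[np.newaxis, ...]
#     print(path, classifier.predict(batch).argmax(axis=-1))
```

Converting to RGB up front avoids shape errors when a folder mixes PNG files with transparency, grayscale images and regular JPEGs.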