Programmatic RAG with Dataiku’s LLM Mesh and Langchain#

While large language models (LLMs) can perform well on text generation, they can only do so by leveraging information learned from the data they have been trained on. However, in several use cases, you may need to rely on data that the model hasn’t seen (e.g., recent or domain-specific data).

In this tutorial, you will leverage OpenAI’s GPT-4 model in combination with a custom source of information: a PDF file. To implement this process, known as retrieval-augmented generation (RAG), you will take advantage of Dataiku’s LLM Mesh features, namely:

  • LLM connections,

  • Knowledge Banks,

  • Dataiku’s Langchain integrations for vector stores and LLMs.

Prerequisites#

  • Dataiku >= 12.4

  • “Use” permission on a Python code environment with the “Retrieval augmented generation models” package set installed.

  • OpenAI LLM connection with a GPT-4 model enabled.

Creating the Knowledge Bank#

In this section, you will retrieve an external document and index it so that it can be queried. As of version 12.3, Dataiku provides a native Flow item called Knowledge Bank that points to a vector database where embeddings are stored.

Getting the data#

The examples covered in this tutorial will be based on the World Bank’s Global Economic Prospects (GEP) 2023 report.

  • Create a new managed folder called documents.

  • Download the report’s PDF version from this page (Downloads -> Full report) into the documents managed folder and rename the file world_bank_2023.pdf (alternatively, you can upload it with code, as sketched below).
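
If you prefer to do the upload with code, here is a minimal sketch, assuming you have already downloaded the PDF locally; the folder id and the local path are placeholders to fill in:

import dataiku

MF_ID = ""  # Fill with your `documents` managed folder id
LOCAL_PDF_PATH = ""  # Fill with the local path of the downloaded PDF

docs_folder = dataiku.Folder(MF_ID)
with open(LOCAL_PDF_PATH, "rb") as f:
    # Store the report under the name expected by the rest of the tutorial
    docs_folder.upload_stream("world_bank_2023.pdf", f)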

Transforming your document into embeddings#

The first step of your Flow will extract the text from your document, transform it into embeddings, and store them in a vector database. Simply put, these embeddings capture the semantic meaning and context of the encoded text: querying the vector database with a text input will return the most similar elements in that database (in the semantic sense).

In practice you will use that vector database to enrich your prompt by adding relevant text from the document. This will allow the LLM to directly leverage the document’s data when asked about its content.

In the Dataiku LLM Mesh, the vector database is represented by a Knowledge Bank (KB) Flow item, which is itself created from a visual embedding recipe. Before running this recipe, you will first split your document’s text into smaller chunks to facilitate retrieval.

When indexing a document into a vector database, it is a common practice to split it into smaller chunks first. Searching across multiple smaller chunks instead of a single large document allows for a more granular match between the input prompt and the document’s content.

  • Create a Python recipe with the documents folder as input and a new output dataset called documents_splits, using the following code:

recipe_split.py#
import dataiku
import os
import tiktoken

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter

FILE_NAME = "world_bank_2023.pdf"
MF_ID = ""  # Fill with your `documents` managed folder id

CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100

enc = tiktoken.encoding_for_model("gpt-4")
splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE,
                                 separator='\n',
                                 chunk_overlap=CHUNK_OVERLAP,
                                 length_function=len)

docs_folder = dataiku.Folder(MF_ID)
docs_splits = dataiku.Dataset("documents_splits")
docs_splits.write_schema([
    {"name": "split_id", "type": "string"},
    {"name": "text", "type": "string"},
    {"name": "nb_tokens", "type": "int"}
])

# Read PDF file, split it into smaller chunks and write each chunk data + metadata
# in the output dataset
pdf_path = os.path.join(docs_folder.get_path(), FILE_NAME)
splits = PyPDFLoader(pdf_path) \
    .load_and_split(text_splitter=splitter)

with docs_splits.get_writer() as w:
    for i, s in enumerate(splits):
        d = s.dict()
        w.write_row_dict({"split_id": str(d["metadata"]["page"]),
                          "text": d["page_content"],
                          "nb_tokens": len(enc.encode(d["page_content"]))
                          })

Once built, each row from documents_splits will contain a distinct chunk with:

  • its id,

  • its text,

  • its length, measured in number of tokens (useful to quantify how much of the LLM’s context window will be consumed, as well as to estimate the cost of computing embeddings with online services such as OpenAI; see the sketch below).
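
To get a rough order of magnitude for that cost before running the embedding recipe, you can aggregate the nb_tokens column. The sketch below uses a hypothetical unit price purely for illustration; check your provider’s current embedding pricing:

import dataiku

# Hypothetical unit price, for illustration only; replace with your provider's pricing
PRICE_PER_1K_TOKENS = 0.0001  # USD

df = dataiku.Dataset("documents_splits").get_dataframe()
total_tokens = int(df["nb_tokens"].sum())
print(f"{len(df)} chunks, {total_tokens} tokens in total")
print(f"Estimated embedding cost: ${total_tokens / 1000 * PRICE_PER_1K_TOKENS:.4f}")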

  • Create a new embedding recipe with documents_splits as the input dataset and a new Knowledge Bank called documents_splits_embedded as the output, with the following settings:

    • Embedding model: “Embedding (Ada 002)” (in the creation modal)

    • Knowledge column: “text”

    • Metadata columns: split_id, nb_tokens

    • Document splitting method: “Do not split”

When accessing a KB from the Dataiku web UI, its URL is of the form:

https://dss.example/projects/YOUR_PROJECT_KEY/knowledge-bank/<DOCS_SPLITS_KB_ID>

Take note of <DOCS_SPLITS_KB_ID>. You can also retrieve this ID later on with the list_knowledge_banks() method.
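
For instance, here is a minimal sketch using the public API client (the exact shape of the returned items may vary slightly across versions):

import dataiku

client = dataiku.api_client()
project = client.get_default_project()

# List the Knowledge Banks defined in the project; inspect the printed items to find the id
for kb in project.list_knowledge_banks():
    print(kb)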

Knowledge Banks can be queried programmatically by handling them as Langchain vector stores. To test this, run the following code from a notebook:

sample_simsearch.py#
import dataiku

DOCS_SPLITS_KB_ID = ""  # Replace with your KB id

client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(DOCS_SPLITS_KB_ID).as_core_knowledge_bank()
vector_store = kb_core.as_langchain_vectorstore()
query = "Summarise the current global status on inflation."
search_result = vector_store.similarity_search(query, include_metadata=True)

for r in search_result:
    print(r.json())
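
The search returns Langchain Document objects, so the usual vector store options apply. For example, you can retrieve more chunks per query, or wrap the Knowledge Bank as a Langchain retriever; a minimal sketch relying on standard Langchain vector store methods:

# Retrieve more chunks for a single query (k is a standard Langchain parameter)
results = vector_store.similarity_search(query, k=8)

# Or expose the Knowledge Bank as a retriever, e.g. for use in chains
retriever = vector_store.as_retriever(search_kwargs={"k": 8})
docs = retriever.get_relevant_documents(query)
print(len(docs))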

Running an enriched LLM query#

Now that the Knowledge Bank is ready, you can use it in combination with an LLM. In practice, you will use your prompt as a similarity-search query against the Knowledge Bank to retrieve additional context. This context is then appended to the initial prompt before running inference on the LLM.

Run the following code from a notebook:

sample_rag.py#
import dataiku
from langchain.chains.question_answering import load_qa_chain
from dataiku.langchain.dku_llm import DKUChatLLM

DOCS_SPLITS_KB_ID = ""  # Fill with your KB id
GPT_4_LLM_ID = ""  # Fill with your GPT-4 LLM id

# Retrieve the knowledge base and LLM handles
client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(DOCS_SPLITS_KB_ID).as_core_knowledge_bank()
vector_store = kb_core.as_langchain_vectorstore()
gpt_4_lc = DKUChatLLM(llm_id=GPT_4_LLM_ID, temperature=0)

# Create and run the question answering chain
chain = load_qa_chain(gpt_4_lc, chain_type="stuff")
query = "Describe the state and causes of inflation in Europe."
search_results = vector_store.similarity_search(query, include_metadata=True)
resp = chain({"input_documents": search_results, "question": query})
print(resp["output_text"])

# In Europe, inflation persistence is notably influenced by energy prices. The pass-through of energy costs into broader
#  prices is contributing to the ongoing high inflation. Additionally, the sunsetting of fiscal programs that have 
# helped to mitigate price spikes for end-users may exacerbate inflation persistence. The absence of economic slack in
# Europe is also a factor, as it may increase the ability of firms and workers to exercise pricing power, making
#  inflation more responsive to economic activity.

If you don’t have your LLM ID at hand, you can use the list_llms() method to list all available LLMs for this project.
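
A minimal sketch, based on the public API client (attribute names may differ slightly between versions):

import dataiku

client = dataiku.api_client()
project = client.get_default_project()

# Print the LLMs exposed to this project through the LLM Mesh, with their ids
for llm in project.list_llms():
    print(f"- {llm.description} (id: {llm.id})")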

Both the Knowledge Bank and the LLM are translated from DSS-native objects into Langchain-compatible items, which are then used to build and run the question-answering chain. This allows you to extend the capabilities of Dataiku-managed LLMs to more complex cases.
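
As an illustration, here is a sketch that wires the same two handles into a standard Langchain RetrievalQA chain, letting the chain perform the retrieval step itself instead of calling similarity_search() manually (one possible approach, not the only one):

from langchain.chains import RetrievalQA

# Build a retrieval-augmented QA chain on top of the Dataiku-managed
# vector store and LLM (both already converted to Langchain objects above)
qa_chain = RetrievalQA.from_chain_type(
    llm=gpt_4_lc,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 4}),
)

print(qa_chain.run("Which regions face the highest risk of recession?"))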

Wrapping up#

Congratulations, now you know how to perform RAG using Dataiku’s programmatic LLM Mesh features! To go further, you can, for example:

  • test querying from multiple documents,

  • tweak the document retrieval process to add more results to your prompt,

  • customize the prompt to reinforce the retrieved context’s importance (see the sketch below).
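
On that last point, here is a minimal sketch passing a custom prompt to the question-answering chain; the template wording is only an illustration, to be adapted to your needs:

from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain

# Hypothetical prompt that instructs the model to rely only on the retrieved context
template = """Answer the question using only the context below.
If the context does not contain the answer, say that you don't know.

Context:
{context}

Question: {question}
Answer:"""

prompt = PromptTemplate(template=template, input_variables=["context", "question"])
chain = load_qa_chain(gpt_4_lc, chain_type="stuff", prompt=prompt)
resp = chain({"input_documents": search_results, "question": query})
print(resp["output_text"])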

Here are the complete versions of the code presented in this tutorial:

recipe_split.py
import dataiku
import os
import tiktoken

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter

FILE_NAME = "world_bank_2023.pdf"
MF_ID = ""  # Fill with your `documents` managed folder id

CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100

enc = tiktoken.encoding_for_model("gpt-4")
splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE,
                                 separator='\n',
                                 chunk_overlap=CHUNK_OVERLAP,
                                 length_function=len)

docs_folder = dataiku.Folder(MF_ID)
docs_splits = dataiku.Dataset("documents_splits")
docs_splits.write_schema([
    {"name": "split_id", "type": "string"},
    {"name": "text", "type": "string"},
    {"name": "nb_tokens", "type": "int"}
])

# Read PDF file, split it into smaller chunks and write each chunk data + metadata
# in the output dataset
pdf_path = os.path.join(docs_folder.get_path(), FILE_NAME)
splits = PyPDFLoader(pdf_path) \
    .load_and_split(text_splitter=splitter)

with docs_splits.get_writer() as w:
    for i, s in enumerate(splits):
        d = s.dict()
        w.write_row_dict({"split_id": str(d["metadata"]["page"]),
                          "text": d["page_content"],
                          "nb_tokens": len(enc.encode(d["page_content"]))
                          })
sample_simsearch.py
import dataiku

DOCS_SPLITS_KB_ID = ""  # Replace with your KB id

client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(DOCS_SPLITS_KB_ID).as_core_knowledge_bank()
vector_store = kb_core.as_langchain_vectorstore()
query = "Summarise the current global status on inflation."
search_result = vector_store.similarity_search(query, include_metadata=True)

for r in search_result:
    print(r.json())
sample_rag.py
import dataiku
from langchain.chains.question_answering import load_qa_chain
from dataiku.langchain.dku_llm import DKUChatLLM

DOCS_SPLITS_KB_ID = ""  # Fill with your KB id
GPT_4_LLM_ID = ""  # Fill with your GPT-4 LLM id

# Retrieve the knowledge base and LLM handles
client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(DOCS_SPLITS_KB_ID).as_core_knowledge_bank()
vector_store = kb_core.as_langchain_vectorstore()
gpt_4_lc = DKUChatLLM(llm_id=GPT_4_LLM_ID, temperature=0)

# Create and run the question answering chain
chain = load_qa_chain(gpt_4_lc, chain_type="stuff")
query = "Describe the state and causes of inflation in Europe."
search_results = vector_store.similarity_search(query, include_metadata=True)
resp = chain({"input_documents": search_results, "question": query})
print(resp["output_text"])

# In Europe, inflation persistence is notably influenced by energy prices. The pass-through of energy costs into broader
#  prices is contributing to the ongoing high inflation. Additionally, the sunsetting of fiscal programs that have 
# helped to mitigate price spikes for end-users may exacerbate inflation persistence. The absence of economic slack in
# Europe is also a factor, as it may increase the ability of firms and workers to exercise pricing power, making
#  inflation more responsive to economic activity.