Programmatic RAG with Dataiku’s LLM Mesh and Langchain#

While large language models (LLMs) perform text generation well, they can only leverage the information on which they have been trained. However, many use cases rely on data the model has not seen. This inability of LLMs to handle tasks outside their training data, for example with recent or domain-specific information, is a well-known shortcoming. This tutorial covers a technique that overcomes this common pitfall.

This tutorial implements this technique, known as retrieval-augmented generation (RAG), using an OpenAI GPT model over a custom source of information, namely a PDF file. In the process, you will use the following Dataiku LLM Mesh features:

  1. LLM connections

  2. Knowledge Banks

  3. Dataiku’s Langchain integrations for vector stores and LLMs

Prerequisites#

  • Dataiku >= 12.4

  • permission to use a Python code environment with the Retrieval augmented generation models package set installed, plus the tiktoken and pypdf packages

  • an OpenAI LLM connection with a GPT model enabled (preferably GPT-4)

Note

You will index a document so it can be queried by an LLM. From version 12.3, Dataiku provides a native Flow item called Knowledge Bank that points to the vector database where the document's embeddings are stored.

Converting a downloaded document into chunks#

Create a Python recipe with an output dataset named document_splits. In the recipe, use the following script, which downloads a PDF document, splits it into chunks, and writes them to the output dataset:

Python script - split document
recipe_split.py#
import dataiku
import tiktoken

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter

FILE_URL = "https://bit.ly/GEP-Jan-2024" # Update as needed

CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100

enc = tiktoken.encoding_for_model("gpt-4")
splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE,
                                 separator='\n',
                                 chunk_overlap=CHUNK_OVERLAP,
                                 length_function=len)

docs_dataset = dataiku.Dataset("document_splits")
docs_dataset.write_schema([
    {"name": "split_id", "type": "int"},
    {"name": "text", "type": "string"},
    {"name": "page", "type": "int"},
    {"name": "nb_tokens", "type": "int"}
])


# Read PDF file, split it into smaller chunks and write each chunk data + metadata
# in the output dataset
splits = PyPDFLoader(FILE_URL) \
    .load_and_split(text_splitter=splitter)

with docs_dataset.get_writer() as w:
    for i, s in enumerate(splits):
        d = s.dict()
        w.write_row_dict({"split_id": i,
            "text": d["page_content"],
            "page": d["metadata"]["page"],
            "nb_tokens": len(enc.encode(d["page_content"]))
            })

Caution

This tutorial uses the World Bank's Global Economic Prospects (GEP) report. If the referenced publication is no longer available, look for the latest report's PDF version on this page.

The next step is to transform the document text into embeddings and store them in a vector database. Before indexing a document this way, it is common practice to split the text into smaller chunks, which is what the recipe above does. Searching across multiple smaller chunks instead of a single large document allows for a more granular match between the input prompt and the document's content.

Once built, each row from document_splits will contain a distinct chunk with:

  1. an ID

  2. text

  3. origin page number

  4. length in tokens, which helps quantify how much of the LLM's context window will be consumed and the cost of computing embeddings via services like the OpenAI API

Storing embeddings in a vector database#

Embeddings capture the semantic meaning and context of the encoded text. Querying the vector database with a text input returns the most semantically similar elements stored in it. These results are used to enrich your prompt with relevant text from the document, allowing the LLM to leverage the document's data directly when asked about its content.

In Dataiku, a knowledge bank (KB) flow item represents the vector database. It is created using a visual embedding recipe.

Create and run a new embedding recipe (+RECIPE > LLM Recipes > Embed) with document_splits as the input dataset and a KB called document_embedded as the output, using the following settings:

  1. Embedding model: “Embedding (Ada 002)”

  2. Knowledge column: text

  3. Metadata columns (optional): page, split_id, nb_tokens

  4. Document splitting method: “Do not split”

Note

When accessing a KB from the Dataiku UI, its URL is of the form:

https://dss.example/projects/YOUR_PROJECT_KEY/knowledge-bank/<KB_ID>

Take note of the KB ID. You can also retrieve this identifier later on with the list_knowledge_banks() method.
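
For instance, you can list the knowledge banks of the current project from a notebook. The sketch below is a minimal example; the exact shape of the items returned by list_knowledge_banks() may vary across Dataiku versions:

import dataiku

client = dataiku.api_client()
project = client.get_default_project()

# Print the knowledge banks defined in this project, together with their ids
for kb_item in project.list_knowledge_banks():
    print(kb_item)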

By handling it as a Langchain vector store, you can query a KB programmatically. To test this, run the following code from a notebook:

Python notebook - test kb
sample_simsearch.py#
import dataiku

KB_ID = ""  # Replace with your KB id

# Retrieve the knowledge bank and expose it as a Langchain vector store
client = dataiku.api_client()
project = client.get_default_project()
kb = dataiku.KnowledgeBank(id=KB_ID,
                           project_key=project.project_key)
vector_store = kb.as_langchain_vectorstore()

# Run a similarity search and print the matching chunks with their metadata
query = "Summarize the current global status on inflation."
search_result = vector_store.similarity_search(query, include_metadata=True)

for r in search_result:
    print(r.json())

Running an enriched LLM query#

Now that the KB is ready to query, you can use it with an LLM. In practice, your prompt serves as the query for a similarity search that retrieves additional data as context. This context is then added to the initial prompt before running inference on the LLM.

Run the following code from a notebook:

Python notebook - RAG query
sample_rag.py#
import dataiku
from langchain.chains.question_answering import load_qa_chain
from dataiku.langchain.dku_llm import DKUChatLLM

KB_ID = "" # Fill with your KB id
GPT_LLM_ID = ""  # Fill with your LLM-Mesh id

# Retrieve the knowledge base and LLM handles
client = dataiku.api_client()
project = client.get_default_project()
kb = dataiku.KnowledgeBank(id=KB_ID, project_key=project.project_key)
vector_store = kb.as_langchain_vectorstore()
gpt_lc = DKUChatLLM(llm_id=GPT_LLM_ID, temperature=0)

# Create the question answering chain
chain = load_qa_chain(gpt_lc, chain_type="stuff")
query = "What will inflation in Europe look like and why?"
search_results = vector_store.similarity_search(query)

# ⚡ Get the results ⚡
resp = chain({"input_documents": search_results, "question": query})
print(resp["output_text"])

# Inflation in Europe is expected to remain high in the near term due to
# persistently high inflation that will prevent a rapid easing of monetary
# policy in most economies and weigh on private consumption. Projected fiscal
# consolidation further dampens the outlook. Risks such as an escalation of
# the conflict in the Middle East could increase energy prices, tighten
# financial conditions, and negatively affect confidence.

# Geopolitical risks in the region, including an escalation of the Russian
# Federation’s invasion of Ukraine, are elevated and could materialize.
# Higher-than-anticipated inflation or a weaker-than-expected recovery in
# the euro area would also negatively affect regional activity. However, by
# 2024-25, global inflation is expected to decline further, underpinned by
# the projected weakness in global demand growth and slightly lower
# commodity prices.

If you don't have your LLM ID at hand, you can use the list_llms() method to list all of the LLMs available to this project.
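
For example, the following minimal sketch prints the ids of the LLMs available through the LLM Mesh; it assumes the list items expose id and description attributes, which may vary slightly across Dataiku versions:

import dataiku

client = dataiku.api_client()
project = client.get_default_project()

# Print the id and description of every LLM available to this project
for llm in project.list_llms():
    print(f"{llm.id} - {llm.description}")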

The DSS-native knowledge bank and LLM objects are translated into Langchain-compatible items, which are then used to build and run the question-answering chain. This allows you to extend the capabilities of Dataiku-managed LLMs for more complex cases.
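
To make the prompt-enrichment step explicit, here is a minimal sketch that bypasses the chain, concatenates the retrieved chunks manually, and prompts the chat model directly. It reuses the vector_store and gpt_lc objects from sample_rag.py and assumes a Langchain version where chat models expose the invoke() method:

# Manual RAG: retrieve chunks, build a context block, then prompt the LLM
query = "What will inflation in Europe look like and why?"
chunks = vector_store.similarity_search(query, k=4)
context = "\n\n".join(c.page_content for c in chunks)

prompt = (
    "Answer the question using only the context below.\n\n"
    f"Context:\n{context}\n\n"
    f"Question: {query}"
)

answer = gpt_lc.invoke(prompt)
print(answer.content)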

Wrapping up#

Congratulations, you can now perform RAG using Dataiku's programmatic LLM Mesh features! To go further, you can:

  • query from multiple documents

  • retrieve more results from the knowledge bank to feed the LLM (see the sketch after this list)

  • reinforce the retrieved context’s importance in the prompt
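
For instance, to retrieve more results from the knowledge bank, pass a larger k to the similarity search before building the chain input. A minimal sketch reusing the objects from sample_rag.py (Langchain's similarity_search defaults to k=4):

# Retrieve more chunks from the knowledge bank and feed them to the chain
search_results = vector_store.similarity_search(query, k=10)
resp = chain({"input_documents": search_results, "question": query})
print(resp["output_text"])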