Using Langchain, Chroma, and GPT for document-based retrieval-augmented generation

Tip

As of version 12.3, Dataiku’s LLM Mesh features enhance the user experience by providing oversight, governance and centralization of LLM-powered capabilities. Please refer to this tutorial for an LLM Mesh-oriented example of a zero-shot classification problem.

While large language models (LLMs) perform well on text generation, they can only draw on the information contained in the data they were trained on. However, many use cases require data that the model hasn’t seen (e.g., recent or domain-specific data).

In this tutorial, you will combine OpenAI’s GPT model with a custom source of information, namely a PDF file. This process is often called retrieval-augmented generation (RAG) and relies on additional tools such as vector databases and the Langchain library.

Prerequisites

  • Dataiku >= 11.4

  • “Use” permission on a code environment using Python >= 3.9 with the following packages:

    • In the code environment screen, for core package versions select “Pandas 1.3 (Python 3.7 and above)”

    • openai (tested with version 0.27.8)

    • langchain (tested with version 0.0.200)

    • chromadb (tested with version 0.3.26)

    • pypdf (tested with version 3.9.1)

    • sentence-transformers (tested with version 2.2.2) with the following resource initialization script:

      code_env_init_script.py
      ######################## Base imports #################################
      import logging
      import os
      import shutil
      
      from dataiku.code_env_resources import clear_all_env_vars
      from dataiku.code_env_resources import grant_permissions
      from dataiku.code_env_resources import set_env_path
      from dataiku.code_env_resources import set_env_var
      from dataiku.code_env_resources import update_models_meta
      
      # Set-up logging
      logging.basicConfig()
      logger = logging.getLogger("code_env_resources")
      logger.setLevel(logging.INFO)
      
      # Clear all environment variables defined by a previously run script
      clear_all_env_vars()
      
      # Optionally restrict the GPUs this code environment can use (it can use all by default)
      # set_env_var("CUDA_VISIBLE_DEVICES", "") # Hide all GPUs
      # set_env_var("CUDA_VISIBLE_DEVICES", "0") # Allow only cuda:0
      # set_env_var("CUDA_VISIBLE_DEVICES", "0,1") # Allow only cuda:0 & cuda:1
      
      ######################## Sentence Transformers #################################
      # Set sentence_transformers cache directory
      set_env_path("SENTENCE_TRANSFORMERS_HOME", "sentence_transformers")
      
      import sentence_transformers
      
      # Download pretrained models
      model_repo = "sentence-transformers/all-MiniLM-L6-v2"
      model_revision = "7dbbc90392e2f80f3d3c277d6e90027e55de9125"
      
      MODELS_REPO_AND_REVISION = [model_repo, model_revision]
      
      sentence_transformers_cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
      logger.info("Loading pretrained SentenceTransformer model: {}".format(model_repo))
      model_path = os.path.join(sentence_transformers_cache_dir, model_repo.replace("/", "_"))
      # Uncomment below to overwrite (force re-download of) all existing models
      # if os.path.exists(model_path):
      #     logger.warning("Removing model: {}".format(model_path))
      #     shutil.rmtree(model_path)
      # This also skips same models with a different revision
      if not os.path.exists(model_path):
          model_path_tmp = sentence_transformers.util.snapshot_download(
              repo_id=model_repo,
              revision=model_revision,
              cache_dir=sentence_transformers_cache_dir,
              library_name="sentence-transformers",
              library_version=sentence_transformers.__version__,
              ignore_files=["flax_model.msgpack", "rust_model.ot", "tf_model.h5",],
          )
          os.rename(model_path_tmp, model_path)
      else:
          logger.info("Model already downloaded, skipping")
      # Add sentence embedding models to the code-envs models meta-data
      # (ensure that they are properly displayed in the feature handling)
      update_models_meta()
      # Grant everyone read access to pretrained models in sentence_transformers/ folder
      # (by default, sentence transformers makes them only readable by the owner)
      grant_permissions(sentence_transformers_cache_dir)
      
      

Creating the vector database

In this section, you will retrieve an external document and index it so that it can be queried. In your project, create a new empty local managed folder called documents and write down its id.

Getting the data

The examples covered in this tutorial are based on the World Bank’s Global Economic Prospects (GEP) 2023 report. Download its PDF version from this page (Downloads -> Full report) into the managed folder, then rename the file to world_bank_2023.pdf.

Indexing and persisting the database

The first step of your Flow extracts the text from your document, transforms it into embeddings, then stores them in a vector database. Simply put, embeddings are numerical vectors that capture the semantic meaning and context of the encoded text: querying the vector database with a text input returns the elements of that database that are most similar to it (in the semantic sense).
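To make "most similar" concrete, here is a minimal, stdlib-only sketch of a top-k similarity search. The three-dimensional vectors and chunk ids are made up for illustration (all-MiniLM-L6-v2 produces 384-dimensional vectors), and cosine similarity is an assumption of this sketch: Chroma computes distances and indexes vectors in its own way.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec: List[float], index: Dict[str, List[float]], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k stored chunks whose vectors are most similar to the query vector."""
    scored = [(chunk_id, cosine_similarity(query_vec, vec)) for chunk_id, vec in index.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy "embeddings" standing in for encoded document chunks
index = {
    "chunk-inflation": [0.9, 0.1, 0.0],
    "chunk-climate": [0.0, 0.2, 0.9],
    "chunk-prices": [0.7, 0.3, 0.1],
}
query_vec = [1.0, 0.2, 0.0]  # pretend encoding of a question about inflation

for chunk_id, score in top_k(query_vec, index):
    print(f"{chunk_id}: {score:.3f}")
```

The inflation and prices chunks come out ahead of the climate chunk because their vectors point in nearly the same direction as the query vector.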

In practice, you will use that vector database to enrich your prompt with relevant excerpts from the document, allowing the LLM to draw directly on the document’s data when asked about its content.

Create a new Python recipe using documents as input and a new local managed folder called vector_db as output, with the code below. Replace the "xxx" placeholders with your folder ids:

compute_vector_db.py
import dataiku
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

FILE_NAME = "world_bank_2023.pdf"

# Load the PDF file and split it into smaller chunks
docs_folder = dataiku.Folder("xxx") # Replace with your input folder id
pdf_path = os.path.join(docs_folder.get_path(),
                        FILE_NAME)
loader = PyPDFLoader(pdf_path)
doc = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1500, separator="\n")
chunks = text_splitter.split_documents(doc)

# Retrieve embedding function from code env resources
emb_model = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(
    model_name=emb_model,
    cache_folder=os.getenv('SENTENCE_TRANSFORMERS_HOME')
)

# Index the vector database by embedding then inserting document chunks
vector_db_folder = dataiku.Folder("xxx") # Replace with your output folder id 
vector_db_path = os.path.join(vector_db_folder.get_path(),
                              "world_bank_2023_emb")
db = Chroma.from_documents(chunks,
                           embedding=embeddings,
                           metadatas=[{"source": f"{i}-wb23"} for i in range(len(chunks))],
                           persist_directory=vector_db_path)

# Save vector database as persistent files in the output folder
db.persist()

The 3 key ingredients used in this recipe are:

  • The document loader (here PyPDFLoader): one of Langchain’s tools to easily load data from various files and sources.

  • The embedding function: it determines which sentence embedding is used to encode the document’s text. The recipe uses all-MiniLM-L6-v2, a sentence-transformers model that maps sentences and paragraphs to a 384-dimensional dense vector space.

  • The vector database: many options are available to store the embeddings. In this tutorial, you will use Chroma, a simple yet powerful open-source vector store that (in the chromadb 0.3.x version used here) can be efficiently persisted as Parquet files.

Note that the original document was split into smaller chunks before being indexed. This will allow you to find the most relevant pieces of the document for a given query and pass only those into your LLM’s prompt.
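The chunking step can be pictured with the simplified splitter below: it splits the text on a separator and greedily packs the pieces into chunks of at most chunk_size characters. This is a sketch of the idea, not Langchain’s implementation; in particular, it ignores the chunk overlap that CharacterTextSplitter can add between consecutive chunks, and the function name pack_chunks is hypothetical.

```python
from typing import List


def pack_chunks(text: str, chunk_size: int = 1500, separator: str = "\n") -> List[str]:
    """Greedily pack separator-delimited pieces into chunks of at most chunk_size characters.

    Simplified sketch: no chunk overlap, and a single piece longer than
    chunk_size is kept as its own (oversized) chunk.
    """
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for piece in text.split(separator):
        if not piece.strip():
            continue  # skip empty lines
        # Length added by this piece, including a separator if the chunk is not empty
        added = len(piece) + (len(separator) if current else 0)
        if current and current_len + added > chunk_size:
            # The piece does not fit: close the current chunk and start a new one
            chunks.append(separator.join(current))
            current, current_len = [], 0
            added = len(piece)
        current.append(piece)
        current_len += added
    if current:
        chunks.append(separator.join(current))
    return chunks


sample = "line one\nline two\n\nline three"
print(pack_chunks(sample, chunk_size=20))
```

Smaller chunks make retrieval more precise but carry less context each; the recipe’s chunk_size=1500 is one point on that trade-off.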

Running an enriched LLM query

Now that the vector database is ready, you can combine it with a call to your LLM. To do so, you will take advantage of several key components of the Langchain library: prompt templates, chains, loaders, and output parsers.

Creating the LLM object

The first object to define when working with Langchain is the LLM. Langchain’s LLM API lets you swap models without much refactoring. Since this tutorial relies on OpenAI’s GPT, you will use the corresponding chat model class, ChatOpenAI.

In your project’s Python library, create a new directory called gpt_utils, and inside that directory, create three files:

  • an empty __init__.py file

  • an auth.py file with the following code:

    auth.py
    import dataiku
    
    def get_api_key(secret_name: str) -> str:
        """Return the value of the user secret named `secret_name`."""
        client = dataiku.api_client()
        auth_info = client.get_auth_info(with_secrets=True)
        secret_value = None
        for secret in auth_info["secrets"]:
            if secret["key"] == secret_name:
                secret_value = secret["value"]
                break
        if not secret_value:
            raise Exception(f"Secret {secret_name} not found")
        return secret_value
    
  • a models.py file with the following code:

    models.py
    from langchain.chat_models import ChatOpenAI
    from gpt_utils.auth import get_api_key
    
    def get_gpt_llm(secret_name: str):
        chat_params = {
            "model": "gpt-3.5-turbo-16k", # Bigger context window
            "openai_api_key": get_api_key(secret_name),
            "temperature": 0.5, # To avoid pure copy-pasting from docs lookup
            "max_tokens": 8192
        }
        llm = ChatOpenAI(**chat_params)
        return llm
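The "Bigger context window" comment matters because the retrieved chunks, the question and the completion reserved by max_tokens all share the model’s context window. The back-of-the-envelope sketch below estimates whether a given number of retrieved chunks fits. Every constant in it is an assumption: roughly 4 characters per token is only a rule of thumb for English text, the 16,384-token window and the 300-token overhead are approximations, and only a real tokenizer gives exact counts.

```python
# Back-of-the-envelope token budget for the "stuff" approach (all constants are assumptions)
CONTEXT_WINDOW = 16_384   # assumed window shared by prompt and completion for gpt-3.5-turbo-16k
MAX_TOKENS = 8_192        # completion budget, as set in models.py
CHARS_PER_TOKEN = 4       # rule of thumb for English text, not an exact tokenizer
CHUNK_SIZE_CHARS = 1_500  # chunk_size used when building the vector database
OVERHEAD_TOKENS = 300     # hypothetical allowance for instructions, question and format hints


def estimated_prompt_tokens(n_chunks: int) -> int:
    """Approximate prompt size: n full-size chunks plus a fixed overhead."""
    return n_chunks * CHUNK_SIZE_CHARS // CHARS_PER_TOKEN + OVERHEAD_TOKENS


def fits_in_context(n_chunks: int) -> bool:
    """True if the estimated prompt plus the reserved completion fit in the window."""
    return estimated_prompt_tokens(n_chunks) + MAX_TOKENS <= CONTEXT_WINDOW


for k in (4, 10, 25):
    print(k, estimated_prompt_tokens(k), fits_in_context(k))
```

Under these assumptions, 10 full-size chunks need an estimated 4,050 prompt tokens, which still leaves room for the 8,192-token completion budget, whereas 25 chunks would not fit.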
    

From there, you can test your LLM. Try running this code from a notebook:

from gpt_utils.models import get_gpt_llm

chat = get_gpt_llm(secret_name="OPENAI_API_KEY")
chat.predict("Define the role of the World Bank in one sentence.")
The World Bank is an international financial institution that provides loans and grants to 
developing countries for the purpose of reducing poverty and promoting sustainable economic growth.

You will use more elaborate ways to call the model later on. Next comes the notion of the chain.

Creating the question-answering chain

The main objective of the Langchain library is to let users chain together different components (models, prompts, vector databases, etc.). In the API, a Chain is a sequence of calls to these components that takes a given input and returns a given output.

In this section, you will create a specific type of chain called a "stuff" documents chain, which inserts ("stuffs") the retrieved parts of a document into a prompt and then passes that prompt to an LLM.
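Leaving the LLM call aside, the "stuff" strategy boils down to string concatenation. The sketch below assumes a hypothetical prompt wording and a minimal Doc class standing in for Langchain’s Document objects; the real chain uses its own default prompt and then sends the result to the model.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    """Hypothetical stand-in for a retrieved chunk (Langchain uses its own Document class)."""
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


# Hypothetical prompt wording, not Langchain's built-in "stuff" prompt
QA_PROMPT = (
    "Use the following pieces of context to answer the question.\n\n"
    "{context}\n\n"
    "Question: {question}\nAnswer:"
)


def stuff_prompt(docs: List[Doc], question: str, separator: str = "\n\n") -> str:
    """'Stuff' every retrieved chunk, in retrieval order, into a single prompt string."""
    context = separator.join(d.page_content for d in docs)
    return QA_PROMPT.format(context=context, question=question)


docs = [Doc("First retrieved chunk."), Doc("Second retrieved chunk.")]
print(stuff_prompt(docs, "What are the 3 main perspectives regarding inflation?"))
```

The simplicity is also the limitation: every chunk ends up in one prompt, so the number of retrieved chunks is bounded by the model’s context window.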

Let’s see a usage example: in your notebook, run the code below, replacing the "xxx" placeholder with your vector_db folder id:

import dataiku
import os
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains.question_answering import load_qa_chain
from langchain.vectorstores import Chroma

# Load embedding function
emb_model = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=emb_model,
                                   cache_folder=os.getenv('SENTENCE_TRANSFORMERS_HOME')
)

# Load vector database
vector_db_folder_id = "xxx" # Replace with your vector db folder id
vector_db_name = "world_bank_2023_emb" 
vector_db_folder = dataiku.Folder(vector_db_folder_id)
persist_dir = os.path.join(vector_db_folder.get_path(), vector_db_name)
vector_db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)

# Run similarity search query
q = "What are the 3 main perspectives regarding inflation?"
v = vector_db.similarity_search(q, include_metadata=True)

# Run the chain by passing the output of the similarity search
# (`chat` is the LLM object created with get_gpt_llm() in the previous cell)
chain = load_qa_chain(chat, chain_type="stuff")
res = chain({"input_documents": v, "question": q})
print(res["output_text"])
Based on the given context, the three main perspectives regarding inflation are:

1. Favorable Base Effects: The deceleration of global inflation is largely attributed to favorable base effects 
from commodity prices falling below their peak levels in 2022. This suggests that the recent decrease in inflation 
is temporary and influenced by the fluctuations in commodity prices.

2. Excess Demand: In advanced economies, high inflation is primarily driven by excess demand, even as supply chain 
pressures ease and energy prices decline. The absence of economic slack and the ability of firms and workers to 
exercise pricing power contribute to the persistence of inflation.

3. Negative Supply Shocks: Negative supply shocks, such as significant disruptions to oil supplies caused by 
geopolitical disturbances or stronger-than-expected demand for commodities, can raise commodity prices and 
pass through to core consumer prices. This can lead to unanchored inflation expectations and prompt central banks 
to tighten monetary policy.

Let’s break down this code snippet:

  • First, it loads the embedding function that will be used to encode the question before running the similarity search.

  • Then, it loads into memory the Chroma vector database previously persisted in the vector_db folder, making it ready to be queried.

  • Next, it runs a similarity search that returns the document chunks most similar to the question (4 by default).

  • Finally, the output of that search is passed to the chain created via load_qa_chain(), which runs it through the LLM, and the text response is displayed. Check out Langchain’s API reference to learn more about document chains.

Adding output formatting

The previous code snippet returned raw text, but in many cases you may need a more structured output, depending on what will be done with it afterward. This is a good opportunity to introduce two other key Langchain concepts: prompt templates and output parsers.

  • Prompt templates provide an abstraction layer on top of the text you send to your LLM: they let you define a dynamic blueprint from which the prompt text is generated at execution time.

  • Output parsers help structure the LLM’s output: they provide formatting instructions to include in the prompt, then parse the model’s response into the expected format. They are particularly useful when the output is sent to another application that expects a specific data structure.
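Both mechanisms can be illustrated without Langchain. The class below is a hypothetical, simplified stand-in for PromptTemplate: placeholders listed in partial_variables (such as an output parser’s formatting instructions) are filled once at definition time, and the remaining ones must be supplied when the prompt is generated. The instruction text is made up for the example.

```python
from string import Formatter
from typing import Dict, List


class SimplePromptTemplate:
    """Hypothetical stand-in for a prompt template with pre-filled ("partial") variables."""

    def __init__(self, template: str, partial_variables: Dict[str, str]):
        self.template = template
        self.partial_variables = partial_variables
        # Placeholders found in the template that are not pre-filled must be provided later
        names = {name for _, name, _, _ in Formatter().parse(template) if name}
        self.input_variables: List[str] = sorted(names - set(partial_variables))

    def format(self, **kwargs: str) -> str:
        missing = [name for name in self.input_variables if name not in kwargs]
        if missing:
            raise KeyError(f"Missing input variables: {missing}")
        return self.template.format(**self.partial_variables, **kwargs)


tpl = SimplePromptTemplate(
    "{context}\n{question}\n{fmt}",
    partial_variables={"fmt": "Answer with comma-separated values, e.g. `a, b, c`."},
)
print(tpl.input_variables)  # ['context', 'question']
print(tpl.format(context="<retrieved chunks>", question="List the regions studied."))
```

Because fmt is pre-filled, only context and question remain as inputs, which is the same split the tutorial uses when it passes the parser’s get_format_instructions() output through partial_variables.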

The following example combines these elements with a question-answering chain to retrieve information in the form of a string with comma-separated values:

from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate

reg_query = """
List all specific regions that were studied. Write the results in a comma-separated string.
"""

# Define the output parser
reg_parser = CommaSeparatedListOutputParser()
reg_pfi = reg_parser.get_format_instructions()

# Define the prompt template
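reg_prompt = PromptTemplate(template="{context}\n{question}\n{fmt}",
                            input_variables=["context", "question"],
                            partial_variables={"fmt": reg_pfi})

# Define the question-answering chain
reg_chain = load_qa_chain(chat, chain_type="stuff", prompt=reg_prompt)

# Run similarity search query
reg_simsearch = vector_db.similarity_search(reg_query, include_metadata=True)

# Run the chain on the retrieved chunks, then parse its text output into a Python list
reg_res = reg_chain({"input_documents": reg_simsearch, "question": reg_query})
regions = reg_parser.parse(reg_res["output_text"])
print(regions)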
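The parsing step can be approximated with plain string handling. The function below is a hypothetical, more forgiving variant of a comma-separated list parser: besides splitting on commas, it strips whitespace and trailing periods and drops empty items. The sample answer is invented for the example.

```python
from typing import List


def parse_comma_separated(text: str) -> List[str]:
    """Turn an answer like 'A, B,C.' into ['A', 'B', 'C']."""
    items = [part.strip().rstrip(".").strip() for part in text.strip().split(",")]
    return [item for item in items if item]  # drop empty entries, e.g. from a trailing comma


raw_answer = " East Asia and Pacific, Europe and Central Asia,South Asia. "
print(parse_comma_separated(raw_answer))
```

Tolerating small formatting deviations like these is useful because an LLM does not always follow the formatting instructions to the letter.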

Linking multiple enriched LLM queries

You now have all the building blocks to add a more complex component to your Flow. You already know how to build a chain that retrieves the list of the main regions covered by the report: in this section, you will implement a recipe that combines it with a second chain generating summary reports on a given list of topics for each region.

This translates into running two chains sequentially. While Langchain also has utilities for this kind of operation, for the sake of simplicity, you will use plain string interpolation to pass the output of the first chain to the second one.
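Without retrieval and without a real model, the pattern is simply: run the first step, then interpolate its raw text output into the second question. In the sketch below, fake_llm is a hypothetical stub that returns canned text so the flow can run offline; in the recipe, each step is a question-answering chain backed by the vector database.

```python
from typing import Callable

LLM = Callable[[str], str]  # any function mapping a prompt to a completion


def fake_llm(prompt: str) -> str:
    """Hypothetical stub standing in for the chat model, so the flow runs offline."""
    if prompt.startswith("List the main regions"):
        return "East Asia and Pacific, South Asia"
    return f"(summary for: {prompt.strip()})"


def run_two_step(llm: LLM, topics: str) -> str:
    # Step 1: ask for the list of regions
    regions = llm("List the main regions studied inside a comma-separated string.")
    # Step 2: interpolate step 1's raw text output into the second question
    question = f"For each region in {regions} describe the state of: {topics}."
    return llm(question)


print(run_two_step(fake_llm, "environment, economy, society"))
```

Plain interpolation keeps the data flow explicit; its cost is that nothing checks the first answer before it is injected into the second prompt.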

Since the final result of this operation is meant to be written in a Dataiku dataset, your first step is to define how this output should be formatted so that it can be easily passed to the dataset writer.

In your project library, under gpt_utils create a new file called wb_reporter.py and add the following code:

wb_reporter.py
from pydantic import BaseModel, Field
from typing import List
from langchain.prompts import PromptTemplate
from langchain.prompts import ChatPromptTemplate
from langchain.prompts import HumanMessagePromptTemplate
from langchain.prompts import SystemMessagePromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
from langchain.chains.question_answering import load_qa_chain


TOPICS = {
    "environment": "environmental sustainability and measures against climate change",
    "economy": "economic growth and development",
    "society": "poverty and inequality reduction"
}


class RegionOutlook(BaseModel):
    region_name: str = Field(description='The name of the region of interest')
    environment: str = Field(description=TOPICS["environment"])
    economy: str = Field(description=TOPICS["economy"])
    society: str = Field(description=TOPICS["society"])


class RegionOutlookList(BaseModel):
    items: List[RegionOutlook] = Field(description="The list of region states")

        
def build_qa_chain(llm: BaseChatModel,
                   parser: BaseOutputParser) -> Chain:

    # -- Create system prompt template
    sys_tpl = "You are a helpful assistant with expertise in trade and international economics."
    sys_msg_pt = SystemMessagePromptTemplate.from_template(sys_tpl)

    usr_pt = PromptTemplate(template="{context}\n{question}\n{fmt}",
                            input_variables=["context", "question"],
                            partial_variables={"fmt": parser.get_format_instructions()})
    usr_msg_pt = HumanMessagePromptTemplate(prompt=usr_pt)

    # -- Combine (system, user) into a chat prompt template
    prompt = ChatPromptTemplate.from_messages([sys_msg_pt, usr_msg_pt])

    # Create chain for QA and pass the prompt instead of plain text query
    chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=prompt)
    return chain


def run_qa_chain(chain, query, vec_db) -> str:
    # Lookup
    docs = vec_db.similarity_search(query, k=10, include_metadata=True)
    res = chain({"input_documents": docs, "question": query})
    return res["output_text"]
        

    

  • TOPICS contains the names and descriptions of each topic to generate a report on.

  • The RegionOutlook and RegionOutlookList Pydantic models define the expected structure of the output: a JSON object containing, for each region, a summary report per topic. A PydanticOutputParser built from them generates the matching formatting instructions and parses the LLM’s answer. To learn more about Langchain’s PydanticOutputParser, check out its documentation.

  • build_qa_chain() and run_qa_chain() wrap the prompt templating and chain definition/run steps into functions to make the recipe’s code more modular. build_qa_chain() uses a more elaborate prompt template called ChatPromptTemplate, whose structure follows the system and user message roles of OpenAI’s chat API; see its documentation for more details. run_qa_chain() retrieves the 10 most similar chunks (k=10) before running the chain.
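Pydantic does the validation work in the recipe. To illustrate what that step involves, here is a stdlib-only sketch that locates the JSON object in the model’s answer, checks that every item has the RegionOutlook fields, and builds typed records. The dataclass and function names are hypothetical, and this is not a reimplementation of PydanticOutputParser.

```python
import json
from dataclasses import dataclass, fields
from typing import List


@dataclass
class RegionOutlookSketch:
    """Hypothetical stdlib stand-in for the RegionOutlook Pydantic model (same fields)."""
    region_name: str
    environment: str
    economy: str
    society: str


def parse_region_outlooks(raw: str) -> List[RegionOutlookSketch]:
    """Extract the outermost JSON object from an LLM answer and validate its items."""
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end < start:
        raise ValueError("No JSON object found in the LLM output")
    payload = json.loads(raw[start:end + 1])
    expected = {f.name for f in fields(RegionOutlookSketch)}
    outlooks = []
    for item in payload["items"]:
        missing = expected - set(item)
        if missing:
            raise ValueError(f"Item is missing fields: {sorted(missing)}")
        # Ignore unexpected keys and coerce values to strings
        outlooks.append(RegionOutlookSketch(**{name: str(item[name]) for name in expected}))
    return outlooks


answer = (
    'Here is the report: {"items": [{"region_name": "Region A", '
    '"environment": "...", "economy": "...", "society": "..."}]}'
)
print(parse_region_outlooks(answer)[0].region_name)
```

Validating against an explicit schema turns a malformed answer into an immediate, descriptive error instead of a silently incomplete dataset.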

You can now embed those components into a recipe in your Flow! Create a new Python recipe using vector_db as input and a new dataset called wb_regional_reports as output, with the code below. Replace the "xxx" placeholder with vector_db’s id.

compute_wb_regional_reports.py
import dataiku
import os
import json
from gpt_utils.wb_reporter import build_qa_chain
from gpt_utils.wb_reporter import run_qa_chain
from gpt_utils.wb_reporter import RegionOutlookList
from gpt_utils.wb_reporter import TOPICS
from gpt_utils.models import get_gpt_llm
from langchain.output_parsers import PydanticOutputParser
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate


# Set up LLM
chat = get_gpt_llm(secret_name="OPENAI_API_KEY")

# Load the embedding function
emb_model = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=emb_model,
                                   cache_folder=os.getenv('SENTENCE_TRANSFORMERS_HOME')
)

# Load vector database
vector_db_folder_id = "xxx" # Replace with your vector db folder id
vector_db_name = "world_bank_2023_emb" 
vector_db_folder = dataiku.Folder(vector_db_folder_id)
persist_dir = os.path.join(vector_db_folder.get_path(), vector_db_name)
vector_db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)


# --- Chain #1: retrieve the list of regions

# Retrieve formatting instructions from output parser
reg_parser = CommaSeparatedListOutputParser()
reg_pfi = reg_parser.get_format_instructions()

# Define prompt template
reg_prompt = PromptTemplate(template="{context}\n{question}\n{fmt}",
                            input_variables=["context", "question"],
                            partial_variables={"fmt": reg_pfi})

# Define and run question-answering chain
# (run_qa_chain() performs its own similarity search, so no separate lookup is needed)
reg_query = "List the main regions studied inside a comma-separated string."
reg_chain = load_qa_chain(chat, chain_type="stuff", prompt=reg_prompt)
regions = run_qa_chain(chain=reg_chain,
                       query=reg_query,
                       vec_db=vector_db)

# --- Chain #2: write summary for each region

# Retrieve formatting instructions from output parser
rpt_parser = PydanticOutputParser(pydantic_object=RegionOutlookList)
rpt_chain = build_qa_chain(llm=chat, parser=rpt_parser)

topics_str = ",".join(TOPICS.keys())
q = f"""
For each region in {regions} describe in a few sentences the state of the 
following topics: {topics_str}.
"""

reports = run_qa_chain(chain=rpt_chain,
                       query=q,
                       vec_db=vector_db)

# Write final results in output dataset
output_dataset = dataiku.Dataset("wb_regional_reports")
output_schema = [
    {"name": "region_name", "type": "string"}
]
for k in TOPICS.keys():
    output_schema.append({"name": k, "type": "string"})
output_dataset.write_schema(output_schema)
with output_dataset.get_writer() as w:
    # Validate the LLM's JSON answer against RegionOutlookList, then write one row per region
    for item in rpt_parser.parse(reports).items:
        w.write_row_dict(item.dict())

This recipe first runs a chain (reg_chain) that outputs the list of regions. Then, rpt_chain uses that list to generate a summary analysis of each topic of interest, and the final result is written to a dataset in which each record represents a region and each topic has its own column.
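The last step maps each parsed region onto one row of the output dataset. The sketch below reproduces that mapping with plain dictionaries, outside Dataiku: build_schema and to_rows are hypothetical helpers, and filling a missing topic with an empty string is a choice made for this sketch only.

```python
from typing import Dict, List

# Same keys and descriptions as TOPICS in wb_reporter.py
TOPICS = {
    "environment": "environmental sustainability and measures against climate change",
    "economy": "economic growth and development",
    "society": "poverty and inequality reduction",
}


def build_schema(topics: Dict[str, str]) -> List[Dict[str, str]]:
    """One string column for the region name, then one string column per topic."""
    return [{"name": "region_name", "type": "string"}] + [
        {"name": topic, "type": "string"} for topic in topics
    ]


def to_rows(items: List[Dict[str, str]], schema: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Keep only the schema's columns, in order, filling absent values with ''."""
    columns = [col["name"] for col in schema]
    return [{col: item.get(col, "") for col in columns} for item in items]


schema = build_schema(TOPICS)
items = [{"region_name": "Region A", "economy": "Sample text.", "extra": "dropped"}]
print(to_rows(items, schema))
```

Deriving the schema from TOPICS keeps the dataset columns and the prompt’s topic list in sync: adding a topic in one place updates both.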

Wrapping up

Congratulations, you have implemented a full example of RAG-based LLM usage to extract information from a document! From there, you can dig deeper. For example, you can:

  • try uploading your own document(s) and adjust the prompt templates accordingly

  • use Langchain’s own tools like SequentialChain to run linked chains

  • play with the number of similarity search results (the k parameter of similarity_search()) to widen the information given to the model

If you want a high-level introduction to LLMs in the context of Dataiku, check out this guide.

Here are the complete versions of the code presented in this tutorial:

compute_vector_db.py
import dataiku
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

FILE_NAME = "world_bank_2023.pdf"

# Load the PDF file and split it into smaller chunks
docs_folder = dataiku.Folder("xxx") # Replace with your input folder id
pdf_path = os.path.join(docs_folder.get_path(),
                        FILE_NAME)
loader = PyPDFLoader(pdf_path)
doc = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1500, separator="\n")
chunks = text_splitter.split_documents(doc)

# Retrieve embedding function from code env resources
emb_model = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(
    model_name=emb_model,
    cache_folder=os.getenv('SENTENCE_TRANSFORMERS_HOME')
)

# Index the vector database by embedding then inserting document chunks
vector_db_folder = dataiku.Folder("xxx") # Replace with your output folder id 
vector_db_path = os.path.join(vector_db_folder.get_path(),
                              "world_bank_2023_emb")
db = Chroma.from_documents(chunks,
                           embedding=embeddings,
                           metadatas=[{"source": f"{i}-wb23"} for i in range(len(chunks))],
                           persist_directory=vector_db_path)

# Save vector database as persistent files in the output folder
db.persist()
auth.py
import dataiku

def get_api_key(secret_name: str) -> str:
    """Return the value of the user secret named `secret_name`."""
    client = dataiku.api_client()
    auth_info = client.get_auth_info(with_secrets=True)
    secret_value = None
    for secret in auth_info["secrets"]:
        if secret["key"] == secret_name:
            secret_value = secret["value"]
            break
    if not secret_value:
        raise Exception(f"Secret {secret_name} not found")
    return secret_value
models.py
from langchain.chat_models import ChatOpenAI
from gpt_utils.auth import get_api_key

def get_gpt_llm(secret_name: str):
    chat_params = {
        "model": "gpt-3.5-turbo-16k", # Bigger context window
        "openai_api_key": get_api_key(secret_name),
        "temperature": 0.5, # To avoid pure copy-pasting from docs lookup
        "max_tokens": 8192
    }
    llm = ChatOpenAI(**chat_params)
    return llm
wb_reporter.py
from pydantic import BaseModel, Field
from typing import List
from langchain.prompts import PromptTemplate
from langchain.prompts import ChatPromptTemplate
from langchain.prompts import HumanMessagePromptTemplate
from langchain.prompts import SystemMessagePromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseOutputParser
from langchain.chains.base import Chain
from langchain.chains.question_answering import load_qa_chain


TOPICS = {
    "environment": "environmental sustainability and measures against climate change",
    "economy": "economic growth and development",
    "society": "poverty and inequality reduction"
}


class RegionOutlook(BaseModel):
    region_name: str = Field(description='The name of the region of interest')
    environment: str = Field(description=TOPICS["environment"])
    economy: str = Field(description=TOPICS["economy"])
    society: str = Field(description=TOPICS["society"])


class RegionOutlookList(BaseModel):
    items: List[RegionOutlook] = Field(description="The list of region states")

        
def build_qa_chain(llm: BaseChatModel,
                   parser: BaseOutputParser) -> Chain:

    # -- Create system prompt template
    sys_tpl = "You are a helpful assistant with expertise in trade and international economics."
    sys_msg_pt = SystemMessagePromptTemplate.from_template(sys_tpl)

    usr_pt = PromptTemplate(template="{context}\n{question}\n{fmt}",
                            input_variables=["context", "question"],
                            partial_variables={"fmt": parser.get_format_instructions()})
    usr_msg_pt = HumanMessagePromptTemplate(prompt=usr_pt)

    # -- Combine (system, user) into a chat prompt template
    prompt = ChatPromptTemplate.from_messages([sys_msg_pt, usr_msg_pt])

    # Create chain for QA and pass the prompt instead of plain text query
    chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=prompt)
    return chain


def run_qa_chain(chain, query, vec_db) -> str:
    # Lookup
    docs = vec_db.similarity_search(query, k=10, include_metadata=True)
    res = chain({"input_documents": docs, "question": query})
    return res["output_text"]
        

    

compute_wb_regional_reports.py
import dataiku
import os
import json
from gpt_utils.wb_reporter import build_qa_chain
from gpt_utils.wb_reporter import run_qa_chain
from gpt_utils.wb_reporter import RegionOutlookList
from gpt_utils.wb_reporter import TOPICS
from gpt_utils.models import get_gpt_llm
from langchain.output_parsers import PydanticOutputParser
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate


# Set up LLM
chat = get_gpt_llm(secret_name="OPENAI_API_KEY")

# Load the embedding function
emb_model = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=emb_model,
                                   cache_folder=os.getenv('SENTENCE_TRANSFORMERS_HOME')
)

# Load vector database
vector_db_folder_id = "xxx" # Replace with your vector db folder id
vector_db_name = "world_bank_2023_emb" 
vector_db_folder = dataiku.Folder(vector_db_folder_id)
persist_dir = os.path.join(vector_db_folder.get_path(), vector_db_name)
vector_db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)


# --- Chain #1: retrieve the list of regions

# Retrieve formatting instructions from output parser
reg_parser = CommaSeparatedListOutputParser()
reg_pfi = reg_parser.get_format_instructions()

# Define prompt template
reg_prompt = PromptTemplate(template="{context}\n{question}\n{fmt}",
                            input_variables=["context", "question"],
                            partial_variables={"fmt": reg_pfi})

# Define and run question-answering chain
# (run_qa_chain() performs its own similarity search, so no separate lookup is needed)
reg_query = "List the main regions studied inside a comma-separated string."
reg_chain = load_qa_chain(chat, chain_type="stuff", prompt=reg_prompt)
regions = run_qa_chain(chain=reg_chain,
                       query=reg_query,
                       vec_db=vector_db)

# --- Chain #2: write summary for each region

# Retrieve formatting instructions from output parser
rpt_parser = PydanticOutputParser(pydantic_object=RegionOutlookList)
rpt_chain = build_qa_chain(llm=chat, parser=rpt_parser)

topics_str = ",".join(TOPICS.keys())
q = f"""
For each region in {regions} describe in a few sentences the state of the 
following topics: {topics_str}.
"""

reports = run_qa_chain(chain=rpt_chain,
                       query=q,
                       vec_db=vector_db)

# Write final results in output dataset
output_dataset = dataiku.Dataset("wb_regional_reports")
output_schema = [
    {"name": "region_name", "type": "string"}
]
for k in TOPICS.keys():
    output_schema.append({"name": k, "type": "string"})
output_dataset.write_schema(output_schema)
with output_dataset.get_writer() as w:
    # Validate the LLM's JSON answer against RegionOutlookList, then write one row per region
    for item in rpt_parser.parse(reports).items:
        w.write_row_dict(item.dict())