{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataiku: *voilà* an agent application!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the tools\n",
    "\n",
    "Note: The SQL query might be written differently depending on your SQL Engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataiku\n",
    "from dataiku.langchain.dku_llm import DKUChatLLM\n",
    "from dataiku import SQLExecutor2\n",
    "from dataiku.sql import Constant, toSQL, Dialects\n",
    "from duckduckgo_search import DDGS\n",
    "\n",
    "from langchain.agents import AgentExecutor\n",
    "from langchain.agents import create_react_agent\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain.tools import BaseTool, StructuredTool\n",
    "from langchain.pydantic_v1 import BaseModel, Field\n",
    "from typing import Type\n",
    "\n",
    "from textwrap import dedent\n",
    "\n",
    "LLM_ID = \"\"\n",
    "DATASET_NAME = \"pro_customers_sql\"\n",
    "llm = DKUChatLLM(llm_id=LLM_ID, temperature=0)\n",
    "\n",
    "class CustomerInfo(BaseModel):\n",
    "    \"\"\"Parameter for GetCustomerInfo\"\"\"\n",
    "    id: str = Field(description=\"customer ID\")\n",
    "\n",
    "\n",
    "class GetCustomerInfo(BaseTool):\n",
    "    \"\"\"Gathering customer information\"\"\"\n",
    "\n",
    "    name: str = \"GetCustomerInfo\"\n",
    "    description: str = \"Provide a name, job title and company of a customer, given the customer's ID\"\n",
    "    args_schema: Type[BaseModel] = CustomerInfo\n",
    "\n",
    "    def _run(self, id: str):\n",
    "        dataset = dataiku.Dataset(DATASET_NAME)\n",
    "        table_name = dataset.get_location_info().get('info', {}).get('quotedResolvedTableName')\n",
    "        executor = SQLExecutor2(dataset=dataset)\n",
    "        cid = Constant(str(id))\n",
    "        escaped_cid = toSQL(cid, dialect=Dialects.POSTGRES)  # Replace by your DB\n",
    "        query_reader = executor.query_to_iter(\n",
    "            f\"\"\"SELECT \"name\", \"job\", \"company\" FROM {table_name} WHERE \"id\" = {escaped_cid}\"\"\")\n",
    "        for (name, job, company) in query_reader.iter_tuples():\n",
    "            return f\"The customer's name is \\\"{name}\\\", holding the position \\\"{job}\\\" at the company named {company}\"\n",
    "        return f\"No information can be found about the customer {id}\"\n",
    "\n",
    "    def _arun(self, id: str):\n",
    "        raise NotImplementedError(\"This tool does not support async\")\n",
    "class CompanyInfo(BaseModel):\n",
    "    \"\"\"Parameter for the GetCompanyInfo\"\"\"\n",
    "    name: str = Field(description=\"Company's name\")\n",
    "\n",
    "\n",
    "class GetCompanyInfo(BaseTool):\n",
    "    \"\"\"Class for gathering in the company information\"\"\"\n",
    "\n",
    "    name: str = \"GetCompanyInfo\"\n",
    "    description: str = \"Provide general information about a company, given the company's name.\"\n",
    "    args_schema: Type[BaseModel] = CompanyInfo\n",
    "\n",
    "    def _run(self, name: str):\n",
    "        results = DDGS().text(name + \" (company)\", max_results=1)\n",
    "        result = \"Information found about \" + name + \": \" + results[0][\"body\"] + \"\\n\" \\\n",
    "            if len(results) > 0 and \"body\" in results[0] \\\n",
    "            else None\n",
    "        if not result:\n",
    "            results = DDGS().text(name, max_results=1)\n",
    "            result = \"Information found about \" + name + \": \" + results[0][\"body\"] + \"\\n\" \\\n",
    "                if len(results) > 0 and \"body\" in results[0] \\\n",
    "                else \"No information can be found about the company \" + name\n",
    "        return result\n",
    "\n",
    "    def _arun(self, name: str):\n",
    "        raise NotImplementedError(\"This tool does not support async\")\n",
    "\n",
    "tools = [GetCustomerInfo(), GetCompanyInfo()]\n",
    "tool_names = [tool.name for tool in tools]\n",
    "\n",
    "# Initializes the agent\n",
    "prompt = ChatPromptTemplate.from_template(\n",
    "    \"\"\"Answer the following questions as best you can. You have only access to the following tools:\n",
    "{tools}\n",
    "Use the following format:\n",
    "Question: the input question you must answer\n",
    "Thought: you should always think about what to do\n",
    "Action: the action to take, should be one of [{tool_names}]\n",
    "Action Input: the input to the action\n",
    "Observation: the result of the action\n",
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
    "Thought: I now know the final answer\n",
    "Final Answer: the final answer to the original input question\n",
    "Begin!\n",
    "Question: {input}\n",
    "Thought:{agent_scratchpad}\"\"\")\n",
    "\n",
    "agent = create_react_agent(llm, tools, prompt)\n",
    "agent_executor = AgentExecutor(agent=agent, tools=tools,\n",
    "                               verbose=True, return_intermediate_steps=True, handle_parsing_errors=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Voilà application"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/dataiku/python-code-envs/devadv-voila-agent/lib/python3.9/site-packages/traitlets/traitlets.py:1385: DeprecationWarning: Passing unrecognized arguments to super(Output).__init__(value='').\n",
      "object.__init__() takes exactly one argument (the instance to initialize)\n",
      "This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.\n",
      "  warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4bb8b08252604e6e918604683a67680c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(Label(value='Enter the customer ID'), Text(value='', continuous_update=False, pl…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets as widgets\n",
    "import os\n",
    "\n",
    "\n",
    "label = widgets.Label(value=\"Enter the customer ID\")\n",
    "text = widgets.Text( placeholder=\"fdouetteau\", continuous_update=False)\n",
    "\n",
    "result = widgets.Output(value=\"\")\n",
    "\n",
    "def search(customer_id):\n",
    "    \"\"\"\n",
    "    Search information about a customer\n",
    "    Args:\n",
    "        customer_id: customer ID\n",
    "    Returns:\n",
    "        the agent result\n",
    "    \"\"\"\n",
    "    return agent_executor.invoke({\n",
    "        \"input\": f\"\"\"Give all the professional information you can about the customer with ID: {customer_id}. \n",
    "        Also include information about the company if you can.\"\"\",\n",
    "        \"tools\": tools,\n",
    "        \"tool_names\": tool_names\n",
    "    })['output']\n",
    "\n",
    "   \n",
    "def callback(customerId):\n",
    "    result.clear_output()\n",
    "    with result:\n",
    "        result.append_stdout(search(customerId.get('new', '')))\n",
    "\n",
    "\n",
    "text.observe(callback, 'value')\n",
    "display(widgets.VBox([widgets.HBox([label,text]),result]))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DSS Codeenv - devadv-voila-agent",
   "language": "python",
   "name": "py-dku-venv-devadv-voila-agent"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
