{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataiku: *voilà* an agent application!\n",
    "\n",
    "This notebook builds a LangChain ReAct agent with two tools (a customer lookup backed by a SQL dataset and a company lookup backed by DuckDuckGo answers) and exposes it through a small `ipywidgets` form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silence library warnings so that they do not clutter the rendered application.\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the tools\n",
    "\n",
    "`GetCustomerInfo` reads a customer's name, job and company from the SQL dataset; `GetCompanyInfo` searches DuckDuckGo answers for a company name. Both are then handed to a ReAct agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataiku\n",
    "from dataiku.langchain.dku_llm import DKUChatLLM\n",
    "from dataiku import SQLExecutor2\n",
    "from duckduckgo_search import DDGS\n",
    "\n",
    "from langchain.agents import AgentExecutor\n",
    "from langchain.agents import create_react_agent\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain.tools import BaseTool, StructuredTool\n",
    "from langchain.pydantic_v1 import BaseModel, Field\n",
    "from typing import Type\n",
    "\n",
    "from textwrap import dedent\n",
    "\n",
    "# Project-specific settings: the LLM connection ID and the SQL dataset holding the customers.\n",
    "LLM_ID = \"openai:maxime_openAI:gpt-3.5-turbo\"\n",
    "DATASET_NAME = \"pro_customers_sql\"\n",
    "llm = DKUChatLLM(llm_id=LLM_ID, temperature=0)\n",
    "\n",
    "\n",
    "def _sql_string_literal(value: str) -> str:\n",
    "    \"\"\"Return `value` as a single-quoted SQL string literal.\n",
    "\n",
    "    Single quotes are doubled, which is the standard SQL escape. Values\n",
    "    containing a backslash or a NUL character are rejected: backslash\n",
    "    handling differs between SQL dialects, so such values cannot be\n",
    "    escaped safely without knowing the target database.\n",
    "\n",
    "    Args:\n",
    "        value: the raw text to embed in a query.\n",
    "    Returns:\n",
    "        the quoted literal, including the surrounding single quotes.\n",
    "    Raises:\n",
    "        ValueError: if the value contains a backslash or a NUL character.\n",
    "    \"\"\"\n",
    "    if \"\\\\\" in value or \"\\x00\" in value:\n",
    "        raise ValueError(\"unsupported character in SQL string value\")\n",
    "    return \"'\" + value.replace(\"'\", \"''\") + \"'\"\n",
    "\n",
    "\n",
    "class CustomerInfo(BaseModel):\n",
    "    \"\"\"Parameter for GetCustomerInfo\"\"\"\n",
    "    id: str = Field(description=\"customer ID\")\n",
    "\n",
    "\n",
    "class GetCustomerInfo(BaseTool):\n",
    "    \"\"\"Gathering customer information\"\"\"\n",
    "\n",
    "    name: str = \"GetCustomerInfo\"\n",
    "    description: str = \"Provide a name, job title and company of a customer, given the customer's ID\"\n",
    "    args_schema: Type[BaseModel] = CustomerInfo\n",
    "\n",
    "    def _run(self, id: str):\n",
    "        \"\"\"Look up a customer by ID in the dataset and describe them in one sentence.\n",
    "\n",
    "        Args:\n",
    "            id: the customer ID, as supplied by the agent.\n",
    "        Returns:\n",
    "            a sentence with the customer's name, job and company, or a\n",
    "            \"no information\" message when no row matches.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            id_literal = _sql_string_literal(id)\n",
    "        except ValueError:\n",
    "            # Refuse IDs that cannot be embedded safely in the query text.\n",
    "            return f\"No information can be found about the customer {id}\"\n",
    "        dataset = dataiku.Dataset(DATASET_NAME)\n",
    "        table_name = dataset.get_location_info().get('info', {}).get('table')\n",
    "        executor = SQLExecutor2(dataset=dataset)\n",
    "        query_reader = executor.query_to_iter(\n",
    "            f\"\"\"SELECT name, job, company FROM \"{table_name}\" WHERE id = {id_literal}\"\"\")\n",
    "        # Only the first matching row is used.\n",
    "        for (name, job, company) in query_reader.iter_tuples():\n",
    "            return f\"The customer's name is \\\"{name}\\\", holding the position \\\"{job}\\\" at the company named {company}\"\n",
    "        return f\"No information can be found about the customer {id}\"\n",
    "\n",
    "    def _arun(self, id: str):\n",
    "        raise NotImplementedError(\"This tool does not support async\")\n",
    "\n",
    "\n",
    "class CompanyInfo(BaseModel):\n",
    "    \"\"\"Parameter for the GetCompanyInfo\"\"\"\n",
    "    name: str = Field(description=\"Company's name\")\n",
    "\n",
    "\n",
    "class GetCompanyInfo(BaseTool):\n",
    "    \"\"\"Class for gathering in the company information\"\"\"\n",
    "\n",
    "    name: str = \"GetCompanyInfo\"\n",
    "    description: str = \"Provide general information about a company, given the company's name.\"\n",
    "    args_schema: Type[BaseModel] = CompanyInfo\n",
    "\n",
    "    def _run(self, name: str):\n",
    "        \"\"\"Search DuckDuckGo answers for the company and return the first answer's text.\n",
    "\n",
    "        Args:\n",
    "            name: the company's name.\n",
    "        Returns:\n",
    "            the text of the first answer found, or a \"no information\"\n",
    "            message when neither query returns an answer with a text.\n",
    "        \"\"\"\n",
    "        # Try the disambiguated query first, then fall back to the bare name.\n",
    "        for query in (name + \" (company)\", name):\n",
    "            results = DDGS().answers(query)\n",
    "            if results and \"text\" in results[0]:\n",
    "                return \"Information found about \" + name + \": \" + results[0][\"text\"] + \"\\n\"\n",
    "        return \"No information can be found about the company \" + name\n",
    "\n",
    "    def _arun(self, name: str):\n",
    "        raise NotImplementedError(\"This tool does not support async\")\n",
    "\n",
    "\n",
    "tools = [GetCustomerInfo(), GetCompanyInfo()]\n",
    "tool_names = [tool.name for tool in tools]\n",
    "\n",
    "# Initializes the agent\n",
    "prompt = ChatPromptTemplate.from_template(\n",
    "    \"\"\"Answer the following questions as best you can. You have only access to the following tools:\n",
    "{tools}\n",
    "Use the following format:\n",
    "Question: the input question you must answer\n",
    "Thought: you should always think about what to do\n",
    "Action: the action to take, should be one of [{tool_names}]\n",
    "Action Input: the input to the action\n",
    "Observation: the result of the action\n",
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
    "Thought: I now know the final answer\n",
    "Final Answer: the final answer to the original input question\n",
    "Begin!\n",
    "Question: {input}\n",
    "Thought:{agent_scratchpad}\"\"\")\n",
    "\n",
    "agent = create_react_agent(llm, tools, prompt)\n",
    "agent_executor = AgentExecutor(agent=agent, tools=tools,\n",
    "                               verbose=True, return_intermediate_steps=True, handle_parsing_errors=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Voilà application\n",
    "\n",
    "Enter a customer ID in the text field and validate it: the agent's answer is shown below the field."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "import os\n",
    "\n",
    "\n",
    "label = widgets.Label(value=\"Enter the customer ID\")\n",
    "text = widgets.Text(placeholder=\"fdouetteau\", continuous_update=False)\n",
    "\n",
    "# Output has no `value` argument; passing one only triggers a traitlets DeprecationWarning.\n",
    "result = widgets.Output()\n",
    "\n",
    "\n",
    "def search(customer_id):\n",
    "    \"\"\"\n",
    "    Search information about a customer\n",
    "    Args:\n",
    "        customer_id: customer ID\n",
    "    Returns:\n",
    "        the agent result\n",
    "    \"\"\"\n",
    "    # create_react_agent already fills {tools} and {tool_names} in the prompt,\n",
    "    # so only the question itself is supplied here.\n",
    "    return agent_executor.invoke({\n",
    "        \"input\": f\"\"\"Give all the professional information you can about the customer with ID: {customer_id}. \n",
    "        Also include information about the company if you can.\"\"\"\n",
    "    })['output']\n",
    "\n",
    "\n",
    "def callback(change):\n",
    "    \"\"\"Run the agent on the newly entered customer ID and show its answer.\n",
    "\n",
    "    Args:\n",
    "        change: the change dictionary passed by `text.observe`; its 'new'\n",
    "            entry is read as the submitted text.\n",
    "    \"\"\"\n",
    "    customer_id = change.get('new', '').strip()\n",
    "    result.clear_output()\n",
    "    if not customer_id:\n",
    "        # Nothing to look up: avoid a pointless LLM call.\n",
    "        return\n",
    "    with result:\n",
    "        result.append_stdout(search(customer_id))\n",
    "\n",
    "\n",
    "text.observe(callback, 'value')\n",
    "display(widgets.VBox([widgets.HBox([label, text]), result]))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DSS Codeenv - devadv-voila-agent",
   "language": "python",
   "name": "py-dku-venv-devadv-voila-agent"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}