import gradio as gr
import os
import re

import dataiku
from dataiku.langchain.dku_llm import DKUChatModel
from dataiku import SQLExecutor2
from dataiku.sql import Constant, toSQL, Dialects
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from typing import Type

from ddgs import DDGS


# ID of the Dataiku LLM that drives the agent.
LLM_ID = ""  # Fill in with a valid LLM ID
# Dataiku dataset queried by the GetCustomerInfo tool (must be SQL-backed).
DATASET_NAME = "pro_customers_sql"
# Which Gradio UI variant to launch: "V1", "V2" or "V3".
VERSION = "V3"

# Chat model shared by the agent; temperature=0 requests the least random sampling.
llm = DKUChatModel(temperature=0, llm_id=LLM_ID)


class CustomerInfo(BaseModel):
    """Parameter for GetCustomerInfo"""

    # Required field; the tool's _run receives it as its `id` argument.
    id: str = Field(..., description="customer ID")


class GetCustomerInfo(BaseTool):
    """LangChain tool looking up a customer's name, job and company in DATASET_NAME."""

    name: str = "GetCustomerInfo"
    description: str = "Provide a name, job title and company of a customer, given the customer's ID"
    args_schema: Type[BaseModel] = CustomerInfo

    def _run(self, id: str):
        """Query the customer dataset for the row whose "id" column equals `id`.

        Args:
            id: customer ID (the parameter name is imposed by the CustomerInfo schema).

        Returns:
            A sentence describing the first matching customer, or a
            "No information can be found" message when no row matches.

        Raises:
            ValueError: if the dataset's resolved SQL table name is unavailable
                (e.g. the dataset is not stored in an SQL connection).
        """
        dataset = dataiku.Dataset(DATASET_NAME)
        table_name = dataset.get_location_info().get('info', {}).get('quotedResolvedTableName')
        if not table_name:
            # Without this guard the query below would be built as "... FROM None ...".
            raise ValueError(
                f"Cannot resolve the SQL table of dataset {DATASET_NAME!r}; is it an SQL dataset?")
        executor = SQLExecutor2(dataset=dataset)
        # Render the ID as an escaped SQL literal instead of interpolating raw user input.
        escaped_cid = toSQL(Constant(str(id)), dialect=Dialects.POSTGRES)  # Replace by your DB
        query_reader = executor.query_to_iter(
            f"""SELECT "name", "job", "company" FROM {table_name} WHERE "id" = {escaped_cid}""")
        # Only the first matching row is reported.
        row = next(iter(query_reader.iter_tuples()), None)
        if row is None:
            return f"No information can be found about the customer {id}"
        customer_name, job, company = row
        return f"The customer's name is \"{customer_name}\", holding the position \"{job}\" at the company named {company}"

    async def _arun(self, id: str):
        # BaseTool declares _arun as a coroutine; async execution is deliberately unsupported.
        raise NotImplementedError("This tool does not support async")


class CompanyInfo(BaseModel):
    """Parameter for the GetCompanyInfo"""

    # Required field; the tool's _run receives it as its `name` argument.
    name: str = Field(..., description="Company's name")


class GetCompanyInfo(BaseTool):
    """LangChain tool returning a short web-search snippet about a company."""

    name: str = "GetCompanyInfo"
    description: str = "Provide general information about a company, given the company's name."
    args_schema: Type[BaseModel] = CompanyInfo

    def _run(self, name: str):
        """Describe a company using the top DDGS text-search hit.

        The query "<name> (company)" is tried first; when it yields no hit
        carrying a "body", the bare company name is searched instead.

        Args:
            name: company name (field of the CompanyInfo schema).

        Returns:
            "Information found about <name>: <snippet>" followed by a newline,
            or a "No information can be found" message when both searches fail.
        """
        def snippet_for(query):
            # Formatted snippet of the top hit for `query`, or None if there is no usable hit.
            hits = DDGS().text(query, max_results=1)
            if len(hits) == 0 or "body" not in hits[0]:
                return None
            return "Information found about " + name + ": " + hits[0]["body"] + "\n"

        return (snippet_for(name + " (company)")
                or snippet_for(name)
                or "No information can be found about the company " + name)

    def _arun(self, name: str):
        # Async execution is deliberately unsupported for this tool.
        raise NotImplementedError("This tool does not support async")


tools = [GetCustomerInfo(), GetCompanyInfo()]
tool_names = [tool.name for tool in tools]

# Initializes the agent.
# create_agent() uses system_prompt as a plain string (no template substitution), so the
# {tools} and {tool_names} placeholders are rendered here with .format(). The user's
# question reaches the agent as a HumanMessage, so the prompt has no {input} or
# {agent_scratchpad} slots.
# NOTE(review): the Thought/Action text format comes from a text-parsing ReAct agent;
# presumably create_agent binds the tools for native tool calling — confirm the model
# still emits real tool calls with this prompt.
prompt = """Answer the following questions as best you can. You have only access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!""".format(
    tools="\n".join(f"{tool.name}: {tool.description}" for tool in tools),
    tool_names=", ".join(tool_names),
)

agent = create_agent(model=llm, tools=tools, system_prompt=prompt)


def search_V1(customer_id):
    """Ask the agent about a customer and return its final answer in one shot.

    Args:
        customer_id: customer ID entered by the user; surrounding whitespace is removed.

    Returns:
        The ``pretty_repr()`` text of the last message in the agent's final state.
    """
    # The agent's conversation lives under the "messages" key (as in search_V2/V3).
    # The former {"input": ..., "tools": ..., "tool_names": ...} payload is not part of
    # that state, so the question never reached the model.
    content = (
        "Give all the professional information you can about the customer with ID: "
        f"{customer_id.strip()}. \n"
        "        Also include information about the company if you can."
    )
    result = agent.invoke({"messages": [HumanMessage(content=content)]})
    return result["messages"][-1].pretty_repr()


async def search_V2(customer_id):
    """Stream the agent's reasoning for a customer lookup, one model message at a time.

    Args:
        customer_id: customer ID entered by the user; surrounding whitespace is removed.

    Yields:
        The ``pretty_repr()`` text of every message emitted by the agent's
        "model" node. Updates from other nodes (e.g. "tools") are skipped.
    """
    # NOTE(review): agent.stream() is a synchronous iterator consumed inside this
    # coroutine, so the event loop is blocked while each agent step runs.
    question = (
        "Give all the professional information you can about the customer with ID: "
        f"{customer_id.strip()}. \n"
        "        Also include information about the company if you can."
    )
    updates = agent.stream(
        {"messages": [HumanMessage(content=question)]},
        tools=tools,
        tool_names=tool_names,
        stream_mode="updates",
    )
    for update in updates:
        if "model" not in update:
            continue
        for model_message in update["model"]["messages"]:
            yield model_message.pretty_repr()


async def search_V3(customer_id):
    """Stream the agent's latest thought and latest tool results for a customer lookup.

    Args:
        customer_id: customer ID entered by the user; surrounding whitespace is removed.

    Yields:
        ``[agent_thought, tool_results]`` after every streamed update:
        ``agent_thought`` is the ``pretty_repr()`` of the most recent model message,
        ``tool_results`` lists every tool result of the most recent tool step, one
        "# tool: <name>  <content>" entry per line. Both start as "".
    """
    # NOTE(review): agent.stream() is a synchronous iterator consumed inside this
    # coroutine, so the event loop is blocked while each agent step runs.
    content = (
        "Give all the professional information you can about the customer with ID: "
        f"{customer_id.strip()}. \n"
        "        Also include information about the company if you can."
    )
    agent_thought = ""
    tool_results = ""
    # Tools are bound when the agent is created; they are not stream() arguments.
    for step in agent.stream(
            {"messages": [HumanMessage(content=content)]},
            stream_mode="updates"):
        if "model" in step:
            for message in step["model"]["messages"]:
                agent_thought = message.pretty_repr()
        elif "tools" in step:
            print("Calling tools:")
            # Join every tool message of this step; previously only the last one survived.
            tool_results = "\n".join(
                f"# tool: {tool_message.name}  {tool_message.content}"
                for tool_message in step["tools"]["messages"])
        yield [agent_thought, tool_results]



# VERSION -> (search function, factory for the Interface outputs). Factories keep
# output components from being created for the variants that are not launched.
_INTERFACE_VARIANTS = {
    "V1": (search_V1, lambda: "text"),
    "V2": (search_V2, lambda: "text"),
    "V3": (search_V3, lambda: [
        gr.Textbox(label="Agent thought"),
        gr.Textbox(label="Tool Result")]),
}

if VERSION not in _INTERFACE_VARIANTS:
    # Previously an unknown VERSION left `demo` undefined and failed later with a NameError.
    raise ValueError(f"Unknown VERSION {VERSION!r}; expected one of {sorted(_INTERFACE_VARIANTS)}")

_search_fn, _make_outputs = _INTERFACE_VARIANTS[VERSION]
demo = gr.Interface(
    fn=_search_fn,
    inputs=gr.Textbox(label="Enter a customer ID to get more information", placeholder="ID Here..."),
    outputs=_make_outputs()
)

# URL path under which port 7860 is exposed (presumably set by Code Studio — confirm).
# It may contain ${VAR} references to other environment variables. The "" default
# avoids a TypeError in the regex call when the variable is unset.
browser_path = os.getenv("DKU_CODE_STUDIO_BROWSER_PATH_7860", "")
# Expand every ${VAR} reference with its value ("" when unset). "[^}]*" stops at the
# first closing brace, so several references in one path are expanded separately
# (the former greedy ".*" merged them into a single bogus variable name).
env_var_pattern = re.compile(r'\$\{([^}]*)\}')
browser_path = env_var_pattern.sub(lambda match: os.getenv(match.group(1), ''), browser_path)

# WARNING: make sure to use the same params as the ones defined below when calling the launch method,
# otherwise your app might not be responding!
demo.queue().launch(server_port=7860, root_path=browser_path)