import dataiku
from dataiku.llm.agent_tools import BaseAgentTool
from pymongo import MongoClient

# Connection string handed to MongoClient. The bracketed parts are template
# placeholders, not real values.
# NOTE(review): avoid committing real credentials in this file — consider
# loading the URI from a secret store or environment instead.
MONGO_URI = "mongodb+srv://[username]:[password]@[cluster]"
# Database and collection that the vector search aggregation runs against.
MONGO_DB = "movies"
MONGO_COLLECTION = "movies_embeddings"
# Supplies the "index" setting of the $vectorSearch stage.
MONGO_INDEX = "movie_index"
# Document field returned to the caller as each result's "text".
TEXT_FIELD = "title"
# Supplies the "path" setting of the $vectorSearch stage (the stored vectors).
VECTOR_FIELD = "embedding"
# Supplies the "numCandidates" setting of the $vectorSearch stage.
NUM_CANDIDATES = 100
# Id passed to project.get_llm() to obtain the model that embeds the query.
EMBEDDING_MODEL_ID = "internal-embedding-id"

class MongoVectorSearchTool(BaseAgentTool):
    """Agent tool that answers a text query with a MongoDB Atlas vector search.

    The query text is embedded with the LLM identified by
    ``EMBEDDING_MODEL_ID``; the resulting vector is sent to a ``$vectorSearch``
    aggregation on ``MONGO_DB.MONGO_COLLECTION`` and the matching documents are
    returned as ``{"text", "score"}`` pairs.
    """

    # Result count used when the caller omits "k" or sends an unusable value.
    _DEFAULT_K = 5
    # Atlas caps numCandidates at 10000 and requires limit <= numCandidates,
    # so 10000 is also the largest limit that can be requested.
    _MAX_K = 10000

    def set_config(self, config, plugin_config):
        """Accept the tool configuration.

        Nothing is stored: every setting this tool uses is a module-level
        constant.
        """

    def get_descriptor(self, tool):
        """Return the tool description and the JSON schema of its input.

        The schema declares a required string ``query`` and an optional
        integer ``k`` (default 5).
        """
        return {
            "description": "Semantic search over a MongoDB Atlas Vector Search collection",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "User question or search text",
                    },
                    "k": {
                        "type": "integer",
                        "description": "Number of results to return",
                        "default": 5,
                    },
                },
                "required": ["query"],
            },
        }

    @classmethod
    def _resolve_limit(cls, raw_k):
        """Turn the caller-supplied ``k`` into a limit $vectorSearch accepts.

        Values that cannot be read as an integer (``None``, free text, ...)
        fall back to the default; anything else is clamped to
        ``[1, _MAX_K]``.
        """
        try:
            limit = int(raw_k)
        except (TypeError, ValueError, OverflowError):
            return cls._DEFAULT_K
        return min(max(limit, 1), cls._MAX_K)

    def invoke(self, input, trace):
        """Embed the query, run the vector search and return the hits.

        Args:
            input: Invocation payload; ``input["input"]`` holds the arguments
                described by :meth:`get_descriptor` (``query`` and optional
                ``k``).
            trace: Trace object supplied by the caller; not used here.

        Returns:
            ``{"output": [{"text": ..., "score": ...}, ...]}`` with one entry
            per document returned by the aggregation.

        Raises:
            KeyError: If ``input`` has no ``"input"`` entry or the arguments
                have no ``"query"``.
        """
        args = input["input"]
        query_text = args["query"]
        # The model calling the tool may send null, a string or an
        # out-of-range number for k; normalise it before it reaches Atlas.
        k = self._resolve_limit(args.get("k", self._DEFAULT_K))

        client = dataiku.api_client()
        project = client.get_default_project()
        llm = project.get_llm(EMBEDDING_MODEL_ID)

        emb_query = llm.new_embeddings()
        emb_query.add_text(query_text)
        emb_resp = emb_query.execute()
        query_vector = emb_resp.get_embeddings()[0]

        pipeline = [
            {
                "$vectorSearch": {
                    "index": MONGO_INDEX,
                    "path": VECTOR_FIELD,
                    "queryVector": query_vector,
                    # numCandidates may never be smaller than limit.
                    "numCandidates": max(NUM_CANDIDATES, k),
                    "limit": k,
                }
            },
            {
                "$project": {
                    # _id is never read below, so do not ship it back.
                    "_id": 0,
                    TEXT_FIELD: 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        # The context manager closes the client (and its connection pool)
        # even when the aggregation raises, so invocations do not leak
        # connections.
        with MongoClient(MONGO_URI) as mongo:
            raw = list(mongo[MONGO_DB][MONGO_COLLECTION].aggregate(pipeline))

        results = [
            {
                # $project omits TEXT_FIELD for documents that lack it; fall
                # back to an empty string instead of failing the whole call.
                "text": doc.get(TEXT_FIELD, ""),
                "score": doc["score"],
            }
            for doc in raw
        ]

        return {"output": results}
