LLM Mesh#
The LLM Mesh is the common backbone for Enterprise Generative AI Applications. For more details on the LLM Mesh features of Dataiku, please visit Generative AI and LLM Mesh.
The LLM Mesh API allows you to:
Send completion and embedding queries to all LLMs supported by the LLM Mesh
Stream responses from LLMs that support it
Query LLMs using multimodal inputs (image and text)
Query the LLM Mesh from LangChain code
Interact with knowledge banks, and perform semantic search
Create a fine-tuned saved model
Read LLM Mesh metadata#
List and get LLMs#
import dataiku
client = dataiku.api_client()
project = client.get_default_project()
llm_list = project.list_llms()
By default, list_llms() returns a list of DSSLLMListItem. To get more details:
for llm in llm_list:
print(f"- {llm.description} (id: {llm.id})")
Perform completion queries on LLMs#
Your first simple completion query#
This sample retrieves an LLM handle and uses a completion query to ask the LLM to “write a haiku on GPT models.”
import dataiku
# Fill with your LLM id. For example, if you have an OpenAI connection called "myopenai", LLM_ID can be "openai:myopenai:gpt-4o"
# To get the list of LLM ids, you can use project.list_llms() (see above)
LLM_ID = ""
# Create a handle for the LLM of your choice
client = dataiku.api_client()
project = client.get_default_project()
llm = project.get_llm(LLM_ID)
# Create and run a completion query
completion = llm.new_completion()
completion.with_message("Write a haiku on GPT models")
resp = completion.execute()
# Display the LLM output
if resp.success:
print(resp.text)
# GPT, a marvel,
# Deep learning's symphony plays,
# Thoughts dance, words unveil.
Multi-turn and system prompts#
You can have multiple messages in the completion object, each with a role:
completion = llm.new_completion()
# First, put a system prompt
completion.with_message("You are a poetic assistant who always answers in haikus", role="system")
# Then, give an example, or send the conversation history
completion.with_message("What is a transformer", role="user")
completion.with_message("Transformers, marvels\nOf the deep learning research\nAttention, you need", role="assistant")
# Then, the last query of the user
completion.with_message("What's your name", role="user")
resp = completion.execute()
Multimodal input#
Multimodal input is supported on a subset of the LLMs in the LLM Mesh:
OpenAI
Bedrock Anthropic Claude
Azure OpenAI
Gemini Pro
completion = llm.new_completion()
with open("myimage.jpg", "rb") as f:
image = f.read()
mp_message = completion.new_multipart_message()
mp_message.with_text("The image represents an artwork. Describe it as it would be described by art critics")
mp_message.with_inline_image(image)
# Add it to the completion request
mp_message.add()
resp = completion.execute()
Completion settings#
You can adjust settings on the completion query:
completion = llm.new_completion()
completion.with_message("Write a haiku on GPT models")
completion.settings["temperature"] = 0.7
completion.settings["topK"] = 10
completion.settings["topP"] = 0.3
completion.settings["maxOutputTokens"] = 2048
completion.settings["stopSequences"] = [".", "\n"]
completion.settings["presencePenalty"] = 0.6
completion.settings["frequencyPenalty"] = 0.9
completion.settings["logitBias"] = {
1489: 60, # apply a logit bias of 60 on token value "1489"
}
completion.settings["logProbs"] = True
completion.settings["topLogProbs"] = 3
resp = completion.execute()
Response streaming#
from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter
completion = llm.new_completion()
completion.with_message("Please explain special relativity")
for chunk in completion.execute_streamed():
if isinstance(chunk, DSSLLMStreamedCompletionChunk):
print("Received text: %s" % chunk.data["text"])
elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
print("Completion is complete: %s" % chunk.data)
Text embedding#
import dataiku
EMBEDDING_MODEL_ID = "" # Fill with your embedding model id, for example: openai:myopenai:text-embedding-3-small
# Create a handle for the embedding model of your choice
client = dataiku.api_client()
project = client.get_default_project()
emb_model = project.get_llm(EMBEDDING_MODEL_ID)
# Create and run an embedding query
txt = "The quick brown fox jumps over the lazy dog."
emb_query = emb_model.new_embeddings()
emb_query.add_text(txt)
emb_resp = emb_query.execute()
# Display the embedding output
print(emb_resp.get_embeddings())
# [[0.000237455,
# -0.103262354,
# ...
# ]]
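You can also embed several texts in one query by calling add_text() multiple times before execute(). A minimal sketch, assuming get_embeddings() returns one vector per added text:
emb_query = emb_model.new_embeddings()
for txt in ["The quick brown fox", "jumps over", "the lazy dog"]:
    emb_query.add_text(txt)
emb_resp = emb_query.execute()
# One embedding vector per input text
print(len(emb_resp.get_embeddings()))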
Tool calls#
Tool calls (sometimes referred to as “function calling”) allow you to augment an LLM with “tools”: functions that the LLM can choose to call, specifying the arguments. Your client code then performs those calls and provides the output back to the LLM so that it can generate the next response.
Tool calls are supported on the compatible completion models of the following LLM connections:
OpenAI
Azure OpenAI
Azure LLM
Anthropic Claude
Anthropic Claude models on AWS Bedrock connections
MistralAI
Define tools#
You can define tools as settings in the completion query. Tool parameters are defined as JSON Schema objects. See the JSON Schema reference for documentation about the format.
Tools can also be automatically prepared and invoked from Python code, e.g. using Langchain.
completion = llm.new_completion()
completion.settings["tools"] = [
{
"type": "function",
"function": {
"name": "multiply",
"description": "Multiply integers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "integer",
"description": "The first integer to multiply",
},
"b": {
"type": "integer",
"description": "The other integer to multiply",
},
},
"required": ["a", "b"],
}
}
}
]
completion.with_message("What is 3 * 6 ?")
resp = completion.execute()
print(resp.tool_calls)
# [{'type': 'function',
# 'function': {'name': 'multiply', 'arguments': '{"a":3,"b":6}'},
# 'id': 'call_gEB9fOdroydyxYuRs0Ge6Izg'}]
Response streaming with tool calls#
LLM responses that include tool calls can also be streamed. Depending on the LLM, response chunks may include either complete or partial tool calls. When the LLM sends partial tool calls, each streamed chunk contains an extra index field that allows you to reconstruct the whole LLM response.
for chunk in completion.execute_streamed():
if isinstance(chunk, DSSLLMStreamedCompletionChunk):
if "text" in chunk.data:
print("Received text: %s" % chunk.data["text"])
if "toolCalls" in chunk.data:
print("Received tool call: %s" % chunk.data["toolCalls"])
elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
print("Completion is complete: %s" % chunk.data)
Provide tool outputs#
Tool calls can then be parsed and executed. In order to provide the tool response in the chat messages, use the following methods:
import json
# Function to handle the tool call
def multiply(llm_arguments):
try:
json_arguments = json.loads(llm_arguments)
a = json_arguments["a"]
b = json_arguments["b"]
return str(a * b)
except Exception as e:
return f"Cannot call the 'multiply' tool: {str(e)}"
tool_calls = resp.tool_calls
call_id = tool_calls[0]["id"]
llm_arguments = tool_calls[0]["function"]["arguments"]
result = multiply(llm_arguments)
completion.with_tool_calls(tool_calls)
completion.with_tool_output(result, tool_call_id=call_id)
resp = completion.execute()
print(resp.text)
# 3 multiplied by 6 is 18.
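Putting it together, a typical client-side loop keeps executing tools and feeding their outputs back until the LLM returns a plain text answer. The sketch below reuses the llm handle, the “multiply” tool definition (assumed here to be stored in a variable named tool_definitions) and the multiply() helper from the examples above:
completion = llm.new_completion()
completion.settings["tools"] = tool_definitions
completion.with_message("What is 3 * 6? And what is that result multiplied by 7?")
resp = completion.execute()
while resp.success and resp.tool_calls:
    # Echo the tool calls back, then attach one output per call
    completion.with_tool_calls(resp.tool_calls)
    for tool_call in resp.tool_calls:
        # With several tools, dispatch on tool_call["function"]["name"] instead
        result = multiply(tool_call["function"]["arguments"])
        completion.with_tool_output(result, tool_call_id=tool_call["id"])
    resp = completion.execute()
print(resp.text)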
Control tool usage#
Tool usage can be constrained in the completion settings:
completion = llm.new_completion()
# Let the LLM decide whether to call a tool
completion.settings["toolChoice"] = {"type": "auto"}
# The LLM must call at least one tool
completion.settings["toolChoice"] = {"type": "required"}
# The LLM must not call any tool
completion.settings["toolChoice"] = {"type": "none"}
# The LLM must call the tool with name 'multiply'
completion.settings["toolChoice"] = {"type": "tool_name", "name": "multiply"}
Knowledge Banks (KB)#
List and get KBs#
To list the KBs present in a project:
import dataiku
client = dataiku.api_client()
project = client.get_default_project()
kb_list = project.list_knowledge_banks()
By default, list_knowledge_banks() returns a list of DSSKnowledgeBankListItem. To get more details:
for kb in kb_list:
print(f"{kb.name} (id: {kb.id})")
To get a “core handle” on the KB (i.e. to retrieve a KnowledgeBank object):
KB_ID = "" # Fill with your KB id
kb_public_api = project.get_knowledge_bank(KB_ID)
kb_core = kb_public_api.as_core_knowledge_bank()
LangChain integration#
Dataiku LLM model objects can be turned into langchain-compatible objects, making it easy to:
stream responses
run asynchronous queries
batch queries
chain several models and adapters
integrate with the wider langchain ecosystem
Transforming LLM handles into LangChain models#
# In this sample, llm is the result of calling project.get_llm() (see above)
# Turn a regular LLM handle into a langchain-compatible one
langchain_llm = llm.as_langchain_llm()
# Run a single completion query
langchain_llm.invoke("Write a haiku on GPT models")
# Run a batch of completion queries
langchain_llm.batch(["Write a haiku on GPT models", "Write a haiku on GPT models in German"])
# Run a completion query and stream the response
for chunk in langchain_llm.stream("Write a haiku on GPT models"):
print(chunk, end="", flush=True)
See the langchain documentation for more details.
You can also turn it into a langchain “chat model”, a specific type of LLM geared towards conversation:
# In this sample, llm is the result of calling project.get_llm() (see above)
# Turn a regular LLM handle into a langchain-compatible one
langchain_llm = llm.as_langchain_chat_model()
# Run a simple query
langchain_llm.invoke("Write a haiku on GPT models")
# Run a chat query
from langchain_core.messages import HumanMessage, SystemMessage
messages = [
SystemMessage(content="You're a helpful assistant"),
HumanMessage(content="What is the purpose of model regularization?"),
]
langchain_llm.invoke(messages)
# Streaming and chaining
from langchain.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}")
chain = prompt | langchain_llm
for chunk in chain.stream({"topic": "parrot"}):
print(chunk.content, end="", flush=True)
See the langchain documentation for more details.
Creating Langchain models directly#
If running from inside DSS, you can also directly create the Langchain model:
from dataiku.langchain.dku_llm import DKULLM, DKUChatLLM
langchain_llm = DKUChatLLM(llm_id="your llm id") # For example: openai:myopenai:gpt-4o
Response streaming#
LangChain adapters (DKULLM and DKUChatLLM) also support streaming of answers:
from dataiku.langchain.dku_llm import DKULLM, DKUChatLLM
from langchain_core.messages import HumanMessage, SystemMessage
langchain_llm = DKUChatLLM(llm_id="your llm id") # For example: openai:myopenai:gpt-4o
messages = [
SystemMessage(content="You're a helpful assistant"),
HumanMessage(content="What is the purpose of model regularization?"),
]
for gen in langchain_llm.stream(messages):
print(gen)
Using knowledge banks as LangChain retrievers#
Core handles allow users to leverage the Langchain library and, through it:
query the KB for semantic similarity search
combine the KB with an LLM to form a chain and perform complex workflows such as retrieval-augmented generation (RAG).
In practice, core handles expose KBs as a Langchain-native vector store through two different methods:
as_langchain_retriever() returns a generic VectorStoreRetriever object
as_langchain_vectorstore() returns an object whose class corresponds to the KB type. For example, for a FAISS-based KB, you will get a langchain.vectorstores.faiss.FAISS object.
import dataiku
client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(KB_ID).as_core_knowledge_bank()
# Returns a langchain.vectorstores.base.VectorStoreRetriever
lc_generic_vs = kb_core.as_langchain_retriever()
# Returns an object whose type depends on the KB type
lc_vs = kb_core.as_langchain_vectorstore()
# [...] Move forward with similarity search or RAG
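For example, you can run a semantic search on the KB and chain the retriever with an LLM Mesh model into a simple RAG pipeline. This is a sketch, assuming llm is an LLM handle obtained with project.get_llm() as shown earlier; the exact similarity_search() signature may vary with the KB type.
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# Semantic search directly on the vector store
docs = lc_vs.similarity_search("How do transformers work?", k=4)

# Simple RAG chain: retrieve, stuff the documents into a prompt, ask the LLM
chat_model = llm.as_langchain_chat_model()

def format_docs(docs):
    return "\n\n".join(d.page_content for d in docs)

prompt = ChatPromptTemplate.from_template(
    "Answer the question using only this context:\n{context}\n\nQuestion: {question}"
)
rag_chain = (
    {"context": lc_generic_vs | format_docs, "question": RunnablePassthrough()}
    | prompt
    | chat_model
    | StrOutputParser()
)
print(rag_chain.invoke("How do transformers work?"))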
Hybrid Search#
Hybrid search combines similarity search (the default behaviour) with keyword search to retrieve more relevant documents. It is only supported by Azure AI Search and Elasticsearch, and is not compatible with the diversity option.
Additionally, both vector stores offer advanced reranking capabilities to enhance the mix of retrieved documents. Each has its own specific configuration.
Azure AI Search#
import dataiku
client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(KB_ID).as_core_knowledge_bank()
# 1 using as_langchain_retriever
azure_classic_retriever = kb_core.as_langchain_retriever(search_type="similarity")
azure_hybrid_retriever = kb_core.as_langchain_retriever(search_type="hybrid")
azure_hybrid_advanced_retriever = kb_core.as_langchain_retriever(search_type="semantic_hybrid")
# 2 using as_langchain_vectorstore to get retriever
azure_classic_retriever = kb_core.as_langchain_vectorstore().as_retriever(
search_type="similarity")
azure_hybrid_retriever = kb_core.as_langchain_vectorstore().as_retriever(
search_type="hybrid")
azure_hybrid_advanced_retriever = kb_core.as_langchain_vectorstore().as_retriever(
search_type="semantic_hybrid")
# 3 using as_langchain_vectorstore to perform query
query = "A text to match some doccuments"
azure_classic_result = kb_core.as_langchain_vectorstore().similarity_search(query)
azure_hybrid_result = kb_core.as_langchain_vectorstore().hybrid_search(query)
azure_hybrid_advanced_result = kb_core.as_langchain_vectorstore().semantic_hybrid_search(query)
Elasticsearch#
For Elasticsearch, the search strategy must be provided when the vector store is instantiated, which is why it is passed through vectorstore_kwargs. Only the similarity search type is allowed when using a hybrid strategy.
import dataiku
from elasticsearch.helpers.vectorstore import DenseVectorStrategy
client = dataiku.api_client()
project = client.get_default_project()
kb_core = project.get_knowledge_bank(KB_ID).as_core_knowledge_bank()
hybrid_strategy = DenseVectorStrategy(hybrid=True)
hybrid_advanced_strategy = DenseVectorStrategy(hybrid=True, rrf=True)
# 1 using as_langchain_retriever
elastic_classic = kb_core.as_langchain_retriever()
elastic_hybrid = kb_core.as_langchain_retriever(
vectorstore_kwargs={"strategy": hybrid_strategy})
elastic_hybrid_advanced = kb_core.as_langchain_retriever(
vectorstore_kwargs={"strategy": hybrid_advanced_strategy})
# 2 using as_langchain_vectorstore
elastic_classic = kb_core.as_langchain_vectorstore().as_retriever()
elastic_hybrid = kb_core.as_langchain_vectorstore(
strategy=hybrid_strategy).as_retriever(search_type="similarity")
elastic_hybrid_advanced = kb_core.as_langchain_vectorstore(
strategy=hybrid_advanced_strategy).as_retriever(search_type="similarity")
Using tool calls#
The LangChain chat model adapter supports tool calling, assuming that the underlying LLM supports it too.
import dataiku
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
# Define tools
@tool
def add(a: int, b: int) -> int:
"""Adds a and b."""
return a + b
@tool
def multiply(a: int, b: int) -> int:
"""Multiplies a and b."""
return a * b
tools_by_name = {"add": add, "multiply": multiply}
tools = [add, multiply]
tool_choice = {"type": "auto"}
# Get the LangChain chat model, bind it to the tools
client = dataiku.api_client()
project = client.get_default_project()
llm_id = "<your llm id>" # For example: "openai:myopenai:gpt-4o"
llm = project.get_llm(llm_id).as_langchain_chat_model()
llm_with_tools = llm.bind_tools(tools, tool_choice=tool_choice)
# Ask your question
messages = [HumanMessage("What is 3 * 12? and 6 + 4?")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)
# Retrieve tool calls, run them and put the results in the chat messages
for tool_call in ai_msg.tool_calls:
tool_name = tool_call["name"]
selected_tool = tools_by_name[tool_name]
tool_msg = selected_tool.invoke(tool_call)
messages.append(tool_msg)
# Get the final response
ai_msg = llm_with_tools.invoke(messages)
ai_msg.content
# '3 * 12 is 36, and 6 + 4 is 10.'
Code Agents#
In Dataiku, you can implement a custom agent in code that leverages models from the LLM Mesh, LangChain, and its wider ecosystem.
The resulting agent becomes part of the LLM Mesh, seamlessly integrating into your AI workflows.
Dataiku includes basic code examples to help you get started. Below are more advanced samples that showcase full-fledged agents built with LangChain and LangGraph. Both use the internal code environment for retrieval-augmented generation, to avoid any code environment issues.
Note
To be able to use Code Agents, you will need the Advanced LLM Mesh add-on.
This support agent is designed to handle customer inquiries efficiently. With its tools, it can:
retrieve relevant information from an FAQ database
log issues for follow-up when immediate answers aren’t available
escalate complex requests to a human agent when necessary.
We have tested it on this Paris Olympics FAQ dataset,
which we used to create a knowledge bank with the Embed recipe.
We have embedded a column containing both the question and the corresponding answer.
Use the agent on inquiries like “How will transportation work in Paris during the Olympic Games?” or “I booked a hotel for the Olympic games in Paris, but never received any confirmation. What's happening?” and see how it reacts!
import dataiku
from dataiku.llm.python import BaseLLM
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.langchain import LangchainToDKUTracer
from langchain.tools import Tool
from langchain.agents import create_openai_tools_agent, AgentExecutor
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
## 1. Set Up Vector Search for FAQs
# Here, we are using a knowledge bank from the Flow, built with the native Embed recipe.
# We make it a Langchain retriever and pass it to our first tool.
faq_retriever = dataiku.KnowledgeBank(id="<KB_ID>").as_langchain_retriever()
faq_retriever_tool = Tool(
name="FAQRetriever",
func=faq_retriever.get_relevant_documents,
description="Retrieves answers from the FAQ database based on user questions."
)
## 2. Define (fake) Ticketing & Escalation Tools
# Simulated ticket creation function
def create_ticket(issue: str):
# Here, you would typically use the API to your internal ticketing tool.
return f"Ticket created: {issue}"
ticketing_tool = Tool(
name="CreateTicket",
func=create_ticket,
description="Creates a support ticket when the issue cannot be resolved automatically."
)
# Simulated escalation function
def escalate_to_human(issue: str):
# This function could send a notification to the support engineers, for instance.
# It can be useful to attach info about the customer's request, sentiment, and history.
return f"Escalation triggered: {issue} has been sent to a human agent."
escalation_tool = Tool(
name="EscalateToHuman",
func=escalate_to_human,
description="Escalates the issue to a human when it's too complex, or the user is upset."
)
## 3. Build the LangChain Agent
# Define LLM for agent reasoning
llm = DKUChatLLM(llm_id="<valid:model:id_from_the_llm_mesh>")
# Agent tools (FAQ retrieval + ticketing + escalation)
tools = [faq_retriever_tool, ticketing_tool, escalation_tool]
tool_names = [tool.name for tool in tools]
# Define the prompt
prompt = ChatPromptTemplate.from_messages(
[
("system",
"""You are an AI customer support agent. Your job is to assist users by:
- Answering questions using the FAQ retriever tool.
- Creating support tickets for unresolved issues.
- Escalating issues to a human when necessary."""),
MessagesPlaceholder("chat_history", optional=True),
("human", "{input}"),
MessagesPlaceholder("agent_scratchpad"),
]
)
# Initialize an agent with tools.
# Here, we define it as an agent that uses OpenAI tools.
# More options are available at https://python.langchain.com/api_reference/langchain/agents.html
agent = create_openai_tools_agent(llm=llm, tools=tools, prompt=prompt)
class MyLLM(BaseLLM):
def __init__(self):
pass
def process(self, query, settings, trace):
prompt = query["messages"][0]["content"]
tracer = LangchainToDKUTracer(dku_trace=trace)
# Wrap the agent in an executor
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
response = agent_executor.invoke({"input": prompt}, config={"callbacks": [tracer]})
return {"text": response["output"]}
This data analysis agent is designed to automate insights from data.
Given a table (from an SQL database) and its schema (list of columns with information about what they contain), it can:
take a user question
translate it into an SQL query
run the query and fetch the result
interpret the result and convert it back into natural language.
The code below was written for this dataset about car sales.
We used a Prepare recipe to remove some columns and parse the date to a proper format.
Once implemented, test your agent with questions like “What were the top 5 best-selling car models in 2023?” or “What was the year-over-year evolution in the Scottsdale region regarding the number of sales?”.
import dataiku
from dataiku.llm.python import BaseLLM
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.langchain import LangchainToDKUTracer
from langchain.prompts import ChatPromptTemplate
from dataiku import SQLExecutor2
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
# Basic configuration
# Initialize LLM
llm = DKUChatLLM(llm_id="<valid:model:id_from_the_llm_mesh>")
# Connect to the sales database
dataset = dataiku.Dataset("car_data_prepared_sql")
table_name = dataset.get_location_info().get('info', {}).get('table')
table_schema = """
- `Car_id` (TEXT): Unique car ID
- `Date` (DATE): Date of the sale
- `Dealer_Name` (TEXT): Name of the car dealer
- `Company` (TEXT): Company or brand of the car
- `Model` (TEXT): Model of the car
- `Transmission` (TEXT): Type of transmission in the car
- `Color` (TEXT): Color of the car's exterior
- `Price` (INTEGER): Listed price of the car sold
- `Body_Style` (TEXT): Style or design of the car's body
- `Dealer_Region` (TEXT): Geographic region of the car dealer
"""
# Here, we are adding a dispatcher as the first step of our graph. If the user query is not related to car sales,
# the agent will simply answer that it can't talk about anything other than car sales.
def dispatcher(state):
"""
Decides if the query is related to car sales data or just a general question.
Args:
state (dict): The current graph state
Returns:
str: Binary decision for the next node to call
"""
user_query = state["user_query"]
# Classification prompt
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a classifier that determines whether a user's query is related to car sales data.\n\n"
"The table contains information about car sales, including:\n"
"- Sale date & price\n"
"- Info about the cars (brand, model, transmission, body style & color) \n"
"- Dealer name and region\n\n"
"If the query is related to analyzing car sales data, return 'SQL'.\n"
"Otherwise, return 'GENERIC'."
),
(
"human", "{query}"
)
]
)
# Get the classification result
classification = llm.invoke(
prompt.format_messages(query=user_query)
).content.strip()
return classification
# First node, take the user input and translate it into a coherent SQL query.
def sql_translation(state):
"""
Translates a natural language query into SQL using the database schema.
Args:
state (dict): The current graph state that contains the user_query
Returns:
state (dict): New key added to state -- sql_query -- that contains the query to execute.
"""
print("---Translate to SQL---")
user_query = state["user_query"]
# We need to pass the model our table name and schema. Adapt instructions according to your needs, of course.
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an AI assistant that converts natural language questions into SQL queries.\n\n"
"Here are the table name: {name} and schema:\n{schema}\n\n"
"Here are some important rules:\n"
"- Use correct table and column names\n"
"- Do NOT use placeholders (e.g., '?', ':param').\n"
"- The SQL should be executable in PostgreSQL, which means that table and column names should ALWAYS be double-quoted.\n"
"- Never return your answer with SQL Markdown decorators. Just the SQL query, nothing else."
),
(
"human",
"Convert the following natural language query into an SQL query:\n\n{query}"
)
]
)
# Invoke LLM with formatted prompt
sql_query = llm.invoke(
prompt.format_messages(name=table_name, schema=table_schema, query=user_query)
).content
return {"sql_query": sql_query}
# Second node, run the SQL query on the table. For this, we are using Dataiku's API.
def database_query(state):
"""
Executes the SQL query and retrieves results.
Args:
state (dict): The current graph state that contains the query to execute.
Returns:
state (dict): New key added to state -- query_result -- that contains the result of the query.
Returns an error key if not working.
"""
print("---Run SQL query---")
sql_query = state["sql_query"]
try:
executor = SQLExecutor2(dataset=dataset)
df = executor.query_to_df(sql_query)
return {"query_result": df.to_dict(orient="records")}
except Exception as e:
return {"error": str(e)}
# Third node, interpret the results and convert it back into natural language.
def result_interpreter(state):
"""
Takes the raw database output and converts it into a natural language response.
Args:
state (dict): The current graph state, that contains the result of the query (or an error if it didn't work)
Returns:
state (dict): New key added to state -- response -- that contains the final agent response.
"""
print("---Interpret results---")
query_result = state.get("query_result", [])
if not query_result:
return {"response": "No results were found, or the query failed."}
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an AI assistant that summarizes findings from database results in a clear, human-readable format.\n"
),
(
"human", "{query_result}"
)
]
)
formatted_prompt = prompt.format_messages(query_result=query_result)
if len(formatted_prompt) > 1000:
return {"response": "The returned results were too long to be analyzed. Rephrase your query."}
summary = llm.invoke(formatted_prompt).content
return {"response": summary}
# On the other branch of our graph, if the question is too generic, the agent will just answer with a generic response.
def generic_response(state):
return {
"response": "I'm an agent specialized in car sales data analysis. I only have access to info like "
"sales date, price, car characteristics, and dealer name or region. "
"Ask me anything about car sales!"
}
class AgentState(TypedDict):
"""State object for the agent workflow."""
user_query: str
sql_query: str
query_result: list
response: str
# Create graph
graph = StateGraph(AgentState)
# Add nodes
graph.add_node("sql_translation", sql_translation)
graph.add_node("database_query", database_query)
graph.add_node("result_interpreter", result_interpreter)
graph.add_node("generic_response", generic_response)
# Define decision edges
graph.add_conditional_edges(
START,
dispatcher,
{
"SQL": "sql_translation", # If query is about sales, go to SQL path
"GENERIC": "generic_response" # Otherwise, respond with a generic answer
}
)
# Define SQL query flow
graph.add_edge("sql_translation", "database_query")
graph.add_edge("database_query", "result_interpreter")
graph.add_edge("result_interpreter", END)
class MyLLM(BaseLLM):
def __init__(self):
pass
def process(self, query, settings, trace):
prompt = query["messages"][0]["content"]
tracer = LangchainToDKUTracer(dku_trace=trace)
# Compile the graph
query_analyzer = graph.compile()
result = query_analyzer.invoke({"user_query": prompt}, config={"callbacks": [tracer]})
resp_text = result["response"]
sql_query = result.get("sql_query", [])
if not sql_query:
return {"text": resp_text}
# If the agent did succeed, then we return the final response, as well as the sql_query, for audit purposes.
full_resp_text = f"{resp_text}\n\nHere is the SQL query I ran:\n\n{sql_query}"
return {"text": full_resp_text}
Fine-tuning#
Create a Fine-tuned LLM Saved Model version#
Note
Visual model fine-tuning is also available to customers with the Advanced LLM Mesh add-on.
With a Python recipe or notebook, it is possible to fine-tune an LLM from the
HuggingFace Hub and save it as a Fine-tuned LLM Saved Model version.
This is done with the create_finetuned_llm_version()
method, which takes an LLM Mesh connection name as input.
Settings on this connection, such as usage permissions, guardrails, code environment, or container configuration, will apply at inference time.
The above method must be called on an existing Saved Model. Create one
either programmatically (if you are in a notebook and don’t have one yet) with
create_finetuned_llm_saved_model()
or visually from the Saved Models list via +New Saved Model > Fine-tuned LLM
(if you want to do this in a Python recipe, its output Saved Model must exist before you can create the recipe).
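For instance, the output Saved Model can be created programmatically from a notebook before fine-tuning. This is only a sketch: the method name comes from the documentation above, and we assume here that it simply takes a display name (check the API reference for the exact signature).
import dataiku

client = dataiku.api_client()
project = client.get_default_project()
# Assumption: create_finetuned_llm_saved_model() takes the display name of the new Saved Model
saved_model = project.create_finetuned_llm_saved_model("my-finetuned-llm")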
Here we fine-tune using several open-source frameworks from HuggingFace: transformers, trl & peft.
Attention
Note that fine-tuning a local LLM requires significant computational resources (GPU). The code samples below show state-of-the-art techniques to optimize memory usage and processing time, but this depends on your setup and might not always work. Also, beware that the size of your training (and optionally validation) dataset(s) greatly impacts the memory use and storage during fine-tuning.
With only a small GPU available, you can fine-tune a smaller LLM. Phi3 Mini is a good example, with “only” 3.8B parameters.
There are many techniques available to reduce memory usage and speed up computation. One of them is called Low-Rank Adaptation (LoRA). It consists of freezing the weights of the base model and adding new, trainable matrices to the Transformer architecture. This drastically reduces the number of trainable parameters and, hence, the GPU memory requirement.
import datasets
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from dataiku import recipe
from dataiku.llm.finetuning import formatters
base_model_name = "microsoft/Phi-3-mini-4k-instruct"
assert base_model_name, ("please specify a base LLM, it must be available"
" on HuggingFace hub")
connection_name = "a_huggingface_connection_name"
assert connection_name, ("please specify a connection name, the fine-tuned "
"LLM will be available from this connection")
##################
# Initial setup
##################
# Here, we're assuming that your training dataset is composed of 2 columns:
# the input (user message) and expected output (assistant message).
# If using a validation dataset, format should be the same.
user_message_column = "input"
assistant_message_column = "output"
columns = [user_message_column, assistant_message_column]
system_message_column = "" # optional
static_system_message = "" # optional
if system_message_column:
columns.append(system_message_column)
# Turn Dataiku datasets into SFTTrainer datasets.
training_dataset = recipe.get_inputs()[0]
df = training_dataset.get_dataframe(columns=columns)
train_dataset = datasets.Dataset.from_pandas(df)
validation_dataset = None
eval_dataset = None
if len(recipe.get_inputs()) > 1:
validation_dataset = recipe.get_inputs()[1]
df = validation_dataset.get_dataframe(columns=columns)
eval_dataset = datasets.Dataset.from_pandas(df)
saved_model = recipe.get_outputs()[0]
##################
# Model loading
##################
model = AutoModelForCausalLM.from_pretrained(base_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# It is mandatory to define a formatting function for fine-tuning,
# because ultimately, the model is fed with only one string:
# the concatenation of your input columns, in a specific format.
# Here, we leverage the apply_chat_template method, which depends on
# the tokenizer. For more information, see
# https://huggingface.co/docs/transformers/v4.43.3/chat_templating
formatting_func = formatters.ConversationalPromptFormatter(tokenizer.apply_chat_template,
*columns)
##################
# Fine-tune using SFTTrainer
##################
with saved_model.create_finetuned_llm_version(connection_name) as finetuned_llm_version:
# feel free to customize, the only requirement is for a transformers model
# to be created in finetuned_model_version.working_directory
# TRL package offers many possibilities to configure the training job.
# For the full list,
# see https://huggingface.co/docs/transformers/v4.43.3/en/main_classes/trainer#transformers.TrainingArguments
train_conf = SFTConfig(
output_dir=finetuned_llm_version.working_directory,
save_safetensors=True,
gradient_checkpointing=True,
num_train_epochs=1,
logging_steps=5,
eval_strategy="steps" if eval_dataset else "no",
)
# LoRA is one of the most popular adapter-based methods to reduce memory-usage
# and speed up fine-tuning
peft_conf = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
target_modules="all-linear",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
formatting_func=formatting_func,
args=train_conf,
peft_config=peft_conf,
)
trainer.train()
trainer.save_model()
# Finally, we are logging training information to the Saved Model version
config = finetuned_llm_version.config
config["trainingDataset"] = training_dataset.short_name
if validation_dataset:
config["validationDataset"] = validation_dataset.short_name
config["userMessageColumn"] = user_message_column
config["assistantMessageColumn"] = assistant_message_column
config["systemMessageColumn"] = system_message_column
config["staticSystemMessage"] = static_system_message
config["batchSize"] = trainer.state.train_batch_size
config["eventLog"] = trainer.state.log_history
It is also possible to fine-tune larger models, for instance Mistral 7B. In that case, quantization can help further reduce the memory footprint. A paper called QLoRA shows how the LoRA technique can efficiently fine-tune quantized LLMs while limiting the performance loss.
import datasets
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from dataiku import recipe
from dataiku.llm.finetuning import formatters
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
assert base_model_name, ("please specify a base LLM, it must be available"
" on HuggingFace hub")
connection_name = "a_huggingface_connection_name"
assert connection_name, ("please specify a connection name, the fine-tuned"
" LLM will be available from this connection")
##################
# Initial setup
##################
# Here, we're assuming that your training dataset is composed of 2 columns:
# the input (user message) and expected output (assistant message).
# If using a validation dataset, format should be the same.
user_message_column = "input"
assistant_message_column = "output"
columns = [user_message_column, assistant_message_column]
system_message_column = "" # optional
static_system_message = "" # optional
if system_message_column:
columns.append(system_message_column)
# Turn Dataiku datasets into SFTTrainer datasets.
training_dataset = recipe.get_inputs()[0]
df = training_dataset.get_dataframe(columns=columns)
train_dataset = datasets.Dataset.from_pandas(df)
validation_dataset = None
eval_dataset = None
if len(recipe.get_inputs()) > 1:
validation_dataset = recipe.get_inputs()[1]
df = validation_dataset.get_dataframe(columns=columns)
eval_dataset = datasets.Dataset.from_pandas(df)
saved_model = recipe.get_outputs()[0]
##################
# Model loading
##################
# Here, we are quantizing the Mistral model. It means that the weights
# are represented with lower-precision data types (like "Normal Float 4"
# from the [QLoRA paper](https://arxiv.org/pdf/2305.14314)) to optimize
# memory usage.
# We also change the data type used for matrix multiplication to speed
# up compute.
# One can of course use double (nested) quantization, but with an inevitable and significant precision loss.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=False,
)
model = AutoModelForCausalLM.from_pretrained(base_model_name,
quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"
# It is mandatory to define a formatting function for fine-tuning,
# because ultimately, the model is fed with only one string:
# the concatenation of your input columns, in a specific format.
# Here, we leverage the apply_chat_template method, which depends
# on the tokenizer. For more information,
# see https://huggingface.co/docs/transformers/v4.43.3/chat_templating
formatting_func = formatters.ConversationalPromptFormatter(tokenizer.apply_chat_template,
*columns)
##################
# Fine-tune using SFTTrainer
##################
with saved_model.create_finetuned_llm_version(connection_name) as finetuned_llm_version:
# feel free to customize, the only requirement is for a transformers model
# to be created in finetuned_model_version.working_directory
# TRL package offers many possibilities to configure the training job.
# For the full list, see
# https://huggingface.co/docs/transformers/v4.43.3/en/main_classes/trainer#transformers.TrainingArguments
train_conf = SFTConfig(
output_dir=finetuned_llm_version.working_directory,
save_safetensors=True,
gradient_checkpointing=True,
num_train_epochs=1,
logging_steps=5,
eval_strategy="steps" if eval_dataset else "no",
)
# LoRA is one of the most popular adapter-based methods to reduce memory-usage
# and speed up fine-tuning
peft_conf = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
target_modules="all-linear",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
formatting_func=formatting_func,
args=train_conf,
peft_config=peft_conf,
)
trainer.train()
trainer.save_model()
# Finally, we are logging training information to the Saved Model version
config = finetuned_llm_version.config
config["trainingDataset"] = training_dataset.short_name
if validation_dataset:
config["validationDataset"] = validation_dataset.short_name
config["userMessageColumn"] = user_message_column
config["assistantMessageColumn"] = assistant_message_column
config["systemMessageColumn"] = system_message_column
config["staticSystemMessage"] = static_system_message
config["batchSize"] = trainer.state.train_batch_size
config["eventLog"] = trainer.state.log_history
Direct Preference Optimization (DPO) is a stable, efficient, and lightweight way to fine-tune LLMs using preference data. It was introduced in 2023 as a simpler alternative to complex reinforcement learning algorithms.
Instead of training a separate reward model, DPO uses a simple cross-entropy loss function to directly teach the LLM to assign high probability to preferred responses.
In this example, we leverage DPO with both quantization and low-rank adapters (LoRA) from the PEFT framework.
For more on DPO, see the original paper and TRL implementation. Other RLHF alternatives are also supported, like IPO or KTO. For traditional reinforcement learning algorithms, see PPO.
Note
Requirements to run this notebook:
Create an updated INTERNAL_huggingface_local_vX code environment with: trl==0.13.0, datasets==2.21.0
Use this code env for training and in the selected HuggingFace Local connection.
import dataiku
from dataiku import recipe
from datasets import Dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import huggingface_hub
from trl import DPOTrainer, DPOConfig
from peft import LoraConfig
#################################
# Model & Tokenizer Preparation #
#################################
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
MODEL_REVISION = "c72e5d1908b1e2929ec8fc4c8820e9706af1f80f"
connection_name = "a_huggingface_connection_name"
saved_model = recipe.get_outputs()[0]
# Here, we're assuming that your training dataset is composed of 3 columns:
# a question (we'll make it a prompt later), the chosen response and rejected response.
# If using a validation dataset, format should be the same.
train_dataset = Dataset.from_pandas(
dataiku.Dataset("po_train").get_dataframe()
)
validation_dataset = Dataset.from_pandas(
dataiku.Dataset("po_validation").get_dataframe()
)
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
for secret in auth_info["secrets"]:
if secret["key"] == "hf_token":
huggingface_hub.login(token=secret["value"])
break
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
revision=MODEL_REVISION,
device_map="auto",
quantization_config=quantization_config,
use_cache=False # Because the model will change as it is fine-tuned
)
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
revision=MODEL_REVISION
)
tokenizer.pad_token = tokenizer.eos_token
####################
# Data Preparation #
####################
def return_prompt_and_responses(samples):
"""
Transform a batch of examples in a format suitable for DPO.
"""
return {
"prompt": [
f'[INST] Answer the following question in a concise manner: "{question}" [/INST]'
for question in samples["question"]
],
"chosen": samples["chosen"],
"rejected": samples["rejected"]
}
def transform(ds):
"""
Prepare the datasets in a format suitable for DPO.
"""
return ds.map(
return_prompt_and_responses,
batched=True,
remove_columns=ds.column_names
)
train_dataset = transform(train_dataset)
validation_dataset = transform(validation_dataset)
#####################
# Fine Tuning Model #
#####################
with saved_model.create_finetuned_llm_version(connection_name) as finetuned_llm_version:
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
# Define the training parameters
training_args = DPOConfig(
per_device_train_batch_size=4,
num_train_epochs=1,
output_dir=finetuned_llm_version.working_directory,
gradient_checkpointing=True
)
dpo_trainer = DPOTrainer(
model,
None, # The reference model is the base model (without LoRA adaptation)
peft_config=peft_config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=validation_dataset,
tokenizer=tokenizer,
)
# Fine-tune the model
dpo_trainer.train()
dpo_trainer.save_model()
config = finetuned_llm_version.config
config["batchSize"] = dpo_trainer.state.train_batch_size
config["eventLog"] = dpo_trainer.state.log_history
In these examples, we used popular techniques to optimize memory usage and processing time, like LoRA, quantization, or gradient checkpointing. Note that the research and open-source community is constantly coming up with new ways to make fine-tuning more accessible while trying to avoid too much performance loss. For more information on other techniques you could try, see for instance the Transformers or PEFT documentation.
OpenAI-compatible API#
The OpenAI-compatible API provides an easy way to query the LLM Mesh: it is built on top of the LLM Mesh API and implements the most commonly used parts of OpenAI's API for text completion.
The OpenAI-compatible API allows you to send chat completion queries to all LLMs supported by the LLM Mesh, using a standard OpenAI format. This includes, for models that support it:
Streamed chat completion responses
Multimodal inputs (image and text)
Tool calls
JSON output mode
Attention
Some arguments from the OpenAI API reference are not supported.
Chat completion request:
n
response_format
seed
service_tier
parallel_tool_calls
user
function_call (deprecated)
functions (deprecated)
Chat completion response:
choices.message.refusal
choices.logprobs.refusal
created
service_tier
system_fingerprint
usage.completion_tokens_details
Your first OpenAI completion query#
from openai import OpenAI
# Specify the DSS OpenAI-compatible public API URL, e.g. http://my.dss/public/api/projects/PROJECT_KEY/llms/openai/v1/
BASE_URL = ""
# Fill with your DSS API Key
API_KEY = ""
# Fill with your LLM id. For example, if you have a HuggingFace connection called "myhf", LLM_ID can be "huggingfacelocal:myhf:meta-llama/Meta-Llama-3.1-8B-Instruct:TEXT_GENERATION_LLAMA_2:promptDriven=true"
# To get the list of LLM ids, you can use openai_client.models.list() or project.list_llms() through the dataiku client
LLM_ID = ""
# Create an OpenAI client
openai_client = OpenAI(
base_url=BASE_URL,
api_key=API_KEY
)
resp = openai_client.chat.completions.create(
model=LLM_ID,
messages=[{"role": "user", "content": "Write a haiku on GPT models" }],
)
if resp and resp.choices:
print(resp.choices[0].message.content)
# GPT, a marvel,
# Deep learning's symphony plays,
# Thoughts dance, words unveil.
The same query can be streamed by passing stream=True and iterating over the returned chunks:
from openai import OpenAI
# Specify the DSS OpenAI-compatible public API URL, e.g. http://my.dss/public/api/projects/PROJECT_KEY/llms/openai/v1/
BASE_URL = ""
# Fill with your DSS API Key
API_KEY = ""
# Fill with your LLM id. For example, if you have a HuggingFace connection called "myhf", LLM_ID can be "huggingfacelocal:myhf:meta-llama/Meta-Llama-3.1-8B-Instruct:TEXT_GENERATION_LLAMA_2:promptDriven=true"
# To get the list of LLM ids, you can use openai_client.models.list() or project.list_llms() through the dataiku client
LLM_ID = ""
# Create an OpenAI client
openai_client = OpenAI(
base_url=BASE_URL,
api_key=API_KEY
)
resp = openai_client.chat.completions.create(
model=LLM_ID,
messages=[{"role": "user", "content": "Write a haiku on GPT models" }],
stream=True
)
for chunk in resp:
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content)
# Words
# weave
# through
# the
# code
# ,
#
#
# Silent
# thoughts
# brought
# into
# light
# ,
# M
# inds
# connect
# in
# spark
# .
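Tool calls can also be issued through the OpenAI-compatible API by passing the standard tools argument of the OpenAI client. A minimal sketch, reusing the openai_client and LLM_ID defined above with a hypothetical multiply tool:
tools = [{
    "type": "function",
    "function": {
        "name": "multiply",
        "description": "Multiply two integers",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "integer"},
                "b": {"type": "integer"},
            },
            "required": ["a", "b"],
        },
    },
}]
resp = openai_client.chat.completions.create(
    model=LLM_ID,
    messages=[{"role": "user", "content": "What is 3 * 6?"}],
    tools=tools,
)
if resp and resp.choices:
    print(resp.choices[0].message.tool_calls)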
Image generation using the LLM Mesh#
Your first image-generation query#
This sample shows how to send an image generation query with the LLM Mesh to ask the image generation model to generate an image of a blue bird.
import dataiku
client = dataiku.api_client()
project = client.get_default_project()
# To list the image generation model ids, you can use project.list_llms(purpose="IMAGE_GENERATION")
IMAGE_GENERATION_MODELS = project.list_llms(purpose="IMAGE_GENERATION")
IMAGE_GENERATION_MODEL_ID = "" # Fill with your image generation model id, for example: openai:my_openai_connection:dall-e-3
# Create a handle for the image generation model of your choice
img_gen_model = project.get_llm(IMAGE_GENERATION_MODEL_ID)
prompt_text = "Vibrant blue bird in a serene scene on a blooming cherry blossom branch. Tranquil morning sky background with soft pastel colors of dawn, gently blending pinks, purples, and soft oranges. Distant view of a calm lake reflecting the colors of the sky and surrounded by lush greenery."
img_gen_query = img_gen_model.new_images_generation()
img_gen_query.with_prompt(prompt_text)
img_gen_resp = img_gen_query.execute()
image_data = img_gen_resp.first_image()
# You can display the image in your notebook
from IPython.display import Image, display
if img_gen_resp.success:
display(Image(image_data))
# Or you can save the image to a managed folder
FOLDER_ID = "" # Enter your managed folder id here
my_images_folder = dataiku.Folder(FOLDER_ID)
with my_images_folder.get_writer("blue_bird.png") as writer:
writer.write(image_data)
You can parameterize the query to impact the resulting image or generate more images.
The LLM Mesh maps each parameter to the corresponding parameter for the underlying model provider. Support varies across models/providers, and in particular not all models can generate more than one image.
If you want to generate multiple images with different prompts, you must query the LLM Mesh multiple times.
import dataiku
IMAGE_GENERATION_MODEL_ID = "" # Fill with your image generation model id
# Create a handle for the image generation model of your choice
client = dataiku.api_client()
project = client.get_default_project()
img_gen_model = project.get_llm(IMAGE_GENERATION_MODEL_ID)
generation = img_gen_model.new_images_generation()
generation.height = 1024
generation.width = 1024
generation.seed = 3
# If the underlying model supports weighted prompts they will be passed with
# their specified weight, otherwise they will just be merged and sent as a single prompt.
generation.with_prompt("meat pizza", weight=0.8).with_prompt("rustic wooden table", weight=0.6)
# Not all models or providers support more than one
generation.images_to_generate = 1
# Regardless of what parameter the underlying provider expects for the image dimensions,
# when using the LLM Mesh API you can specify either the height and width or the aspect_ratio.
# The LLM Mesh will do the translation between its API and the underlying provider.
# Not all models support the same dimensions.
generation.aspect_ratio = 21 / 9
# The following parameters are not relevant for all models
generation.with_negative_prompt("tomatoes, basil, green leaf", weight=1)
generation.fidelity = 0.5 # from 0.1 to 1, how strongly to adhere to prompt
# valid values depend on the targeted model
generation.quality = "hd"
generation.style = "anime"
resp = generation.execute()
Image-to-image query#
Some models can generate an image from another image; see this documentation.
Mask-free variation generates another image guided by a prompt
Some models can generate unprompted variations
Inpainting uses a mask (either black pixels in a second input image, or transparent pixels on the original image) to fill the corresponding pixels of the input image
In this example, we ask the model for an image variation by passing an image and a prompt using the MASK_FREE
mode.
import dataiku
IMAGE_GENERATION_MODEL_ID = "" # Fill with your image generation model id
img_gen_model = dataiku.api_client().get_default_project().get_llm(IMAGE_GENERATION_MODEL_ID)
# Your image to use as an input.
# Here we're retrieving it from a managed folder but it could also be an image from a previous generation
my_images_folder = dataiku.Folder("my_folder_id")
with my_images_folder.get_download_stream("cat_on_the_beach.png") as img_file:
input_img_data = img_file.read()
# Create the generation query
generation = img_gen_model.new_images_generation()
generation.with_original_image(input_img_data, mode="MASK_FREE", weight=0.3)
generation.with_prompt("dog on the beach")
resp = generation.execute()
Image-to-image generation with a prompt can also be used with the CONTROLNET_STRUCTURE
and CONTROLNET_SKETCH
modes.
In this example, we ask the model for an image variation by sending an image without a prompt.
import dataiku
IMAGE_GENERATION_MODEL_ID = "" # Fill with your image generation model id
img_gen_model = dataiku.api_client().get_default_project().get_llm(IMAGE_GENERATION_MODEL_ID)
# Your image to use as an input.
# Here we're retrieving it from a managed folder but it could also be an image from a previous generation
my_images_folder = dataiku.Folder("my_folder_id")
with my_images_folder.get_download_stream("cat_on_the_beach.png") as img_file:
input_img_data = img_file.read()
# Create the generation query
generation = img_gen_model.new_images_generation()
generation.with_original_image(input_img_data, mode="VARY", weight=0.3)
resp = generation.execute()
When using the MASK_IMAGE_BLACK
mask mode, you need to specify a mask with black pixels to fill.
import dataiku
IMAGE_GENERATION_MODEL_ID = "" # Fill with your image generation model id
img_gen_model = dataiku.api_client().get_default_project().get_llm(IMAGE_GENERATION_MODEL_ID)
# Your image to use as an input.
# Here we're retrieving it from a managed folder but it could also be an image from a previous generation
my_images_folder = dataiku.Folder("my_folder_id")
with my_images_folder.get_download_stream("cat_on_the_beach.png") as img_file:
input_img_data = img_file.read()
with my_images_folder.get_download_stream("cat_black_mask.png") as img_file:
black_mask_image_data = img_file.read()
# Create the generation query
generation = img_gen_model.new_images_generation()
generation.with_original_image(input_img_data, mode="INPAINTING", weight=0.1)
generation.with_mask("MASK_IMAGE_BLACK", image=black_mask_image_data)
generation.with_prompt("dog")
resp = generation.execute()
When using the ORIGINAL_IMAGE_ALPHA mask mode, you do not need to specify a mask image. The model will fill the transparent pixels of the original image.
import dataiku
IMAGE_GENERATION_MODEL_ID = "" # Fill with your image generation model id
img_gen_model = dataiku.api_client().get_default_project().get_llm(IMAGE_GENERATION_MODEL_ID)
# Your image to use as an input.
my_images_folder = dataiku.Folder("my_folder_id")
with my_images_folder.get_download_stream("cat_transparent_background.png") as img_file:
input_img_data = img_file.read()
# Create the generation query
generation = img_gen_model.new_images_generation()
generation.with_original_image(input_img_data, mode="INPAINTING", weight=0.1)
generation.with_mask("ORIGINAL_IMAGE_ALPHA")
generation.with_prompt("Beach scene at sunset, with golden sands, gentle waves at the shore.")
resp = generation.execute()
Reference documentation#
A handle to interact with a DSS-managed LLM
An item in a list of LLMs
A handle to interact with a completion query
A handle to interact with a multi-completion query
Response to a completion
A handle to interact with an embedding query
A handle to interact with an embedding query result
An item in a list of knowledge banks
A handle to interact with a DSS-managed knowledge bank
Langchain-compatible wrapper around Dataiku-mediated chat LLMs
Langchain-compatible wrapper around Dataiku-mediated LLMs
Langchain-compatible wrapper around Dataiku-mediated embedding LLMs