import inspect
import json
from typing import Callable, List, Literal, Optional, Union
import uuid
import openai
from pyspark.sql.connect.dataframe import DataFrame as cDataFrame
from pyspark.sql.dataframe import DataFrame
from yipit_databricks_client.helpers.telemetry import track_usage
from ml_toolkit.functions.llm.constants import DEFAULT_MODEL_QUERY, MAX_FILE_SIZE_MB
from ml_toolkit.functions.llm.helpers.authentication import create_llm_client
from ml_toolkit.functions.llm.helpers.dataframes import (
convert_dataframe_to_json,
get_dataframe_context,
)
from ml_toolkit.functions.llm.helpers.files import (
format_context_for_prompt,
read_context_files,
)
from ml_toolkit.functions.llm.helpers.logging import log_model_usage
from ml_toolkit.ops.helpers.logger import get_logger
class LLMResponse:
"""
    Class that defines the LLM response of the ``query`` function. It exposes the following attributes:

    * ``.text``: returns the text response
    * ``.response``: returns the raw LLM response object (``openai.ChatCompletion``)

    And the following method:

    * ``.json()``: tries to parse the output into a Python ``dict`` if possible
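
    Example (illustrative; assumes ``resp`` was returned by the ``query`` function):

    .. code-block:: python

        resp = query("Summarize this quarter's KPIs", cost_component_name=...)
        print(resp.text)                   # plain-text answer
        parsed = resp.json(strict=False)   # dict if the answer is valid JSON, otherwise None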
"""
def __init__(self, raw_response: openai.ChatCompletion):
"""
:meta private:
"""
self.response = raw_response
self.text = raw_response.choices[0].message.content or ""
def __new__(cls, *args, **kwargs):
"""
Only added here to avoid having this method appear on the docs.
:meta private:
"""
return super(LLMResponse, cls).__new__(cls)
def json(self, strict: bool = True) -> Optional[dict]:
"""
Parse the response text as JSON.
:param strict: If True, raise JSONDecodeError for invalid JSON. If False, return None for invalid JSON.
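
        Example (illustrative; assumes the model replied with ``'{"a": 1}'``):

        .. code-block:: python

            response.json()              # -> {"a": 1}
            response.json(strict=False)  # -> same dict; with invalid JSON it tries to extract an embedded object and otherwise returns None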
"""
if not self.text.strip():
if strict:
raise json.JSONDecodeError("Empty response", "", 0)
return None
# Try to extract JSON from the response if it's embedded in other text
text = self.text.strip()
try:
# First try parsing the entire response
return json.loads(text)
except json.JSONDecodeError:
if strict:
raise
# If that fails and we're not in strict mode, try to find JSON in the text
try:
# Look for text between curly braces
start = text.find("{")
end = text.rfind("}")
if start >= 0 and end > start:
return json.loads(text[start : end + 1])
except json.JSONDecodeError:
pass
return None
def python_function_to_tool(fn: Callable) -> dict:
"""
Convert a Python function into a tool definition compatible with OpenAI's tool format.
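
    For example, a function like ``def add(a: int, b: int = 0) -> int`` would map (roughly) to:

    .. code-block:: python

        {
            "type": "function",
            "function": {
                "name": "add",
                "description": "...",
                "parameters": {
                    "type": "object",
                    "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
                    "required": ["a"],
                },
            },
        }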
"""
signature = inspect.signature(fn)
doc = fn.__doc__ or ""
properties = {}
required = []
for name, param in signature.parameters.items():
annotation = param.annotation
param_schema = {}
if annotation is str:
param_schema["type"] = "string"
elif annotation is int:
param_schema["type"] = "integer"
elif annotation is float:
param_schema["type"] = "number"
elif annotation is bool:
param_schema["type"] = "boolean"
else:
param_schema["type"] = "string" # default fallback
        # Parameters without a default value are required
        if param.default is inspect.Parameter.empty:
            required.append(name)
properties[name] = param_schema
return {
"type": "function",
"function": {
"name": fn.__name__,
"description": doc.strip(),
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
}
@track_usage
def query(
prompt: str,
    inputs: Optional[Union[str, dict, DataFrame, List[Union[str, dict, DataFrame]]]] = None,
context_files: Optional[Union[str, List[str]]] = None,
model: str = DEFAULT_MODEL_QUERY,
response_format: Literal["text", "json"] = "text",
tools: Optional[list[Union[Callable, dict]]] = None,
    max_tool_calls: Optional[int] = None,
max_output_tokens: Optional[int] = 2048,
system_prompt: Optional[str] = None,
cost_component_name: Optional[str] = None,
    model_extra_options: Optional[dict] = None,
) -> LLMResponse:
"""
    Query an LLM and get its response.

    This function lets you prompt an LLM and provides several conveniences to streamline usage, such as
    input data (including dataframes), context files, tools, and web search. Check the examples below for
    advanced usage of these features.
:param prompt: The main prompt text to send to the LLM model.
    :param inputs: Optional input data used as context: a string, dictionary, or DataFrame, or a list of these.
:param context_files: Optional file path(s) to include as context. Can be a single string or list of strings.
:param model: Optional name of the LLM model to use.
:param response_format: Format of the response - either "text" or "json". Defaults to "text".
:param tools: Optional list of functions or dicts that the LLM can use as tools.
    :param max_tool_calls: Maximum number of tool calls allowed. Unlimited if ``None``.
:param max_output_tokens: Maximum number of tokens in the response. Defaults to 2048.
:param system_prompt: Optional system prompt to set the behavior of the model.
    :param cost_component_name: Name of the cost component used for cost tracking. Required; a ``ValueError`` is raised if not provided.
:param model_extra_options: Additional model-specific options as a dictionary.
:returns: :py:class:`LLMResponse` object containing the model's response.
    :raises ValueError: If ``cost_component_name`` is not provided or there are errors reading context files.
Examples
^^^^^^^^^^
.. code-block:: python
:caption: Querying a search-enabled model.
from ml_toolkit.functions.llm import query
resp = query(
"What are the latest info on the Trump Tariffs?",
model="gpt-4o-mini-search-preview",
model_extra_options={
"web_search_options":{
"user_location": {
"type": "approximate",
"approximate": {
"country": "US",
}
},
}
},
cost_component_name=... # use your team's cost component here!
)
print(resp.text)
.. code-block:: python
:caption: Passing a dataframe as context and asking the model questions about our data.
from ml_toolkit.functions.llm import query
df = spark.table(...).limit(10)
response = query(
"How can I query this table to know the usage rates by hour over the last 21 days?",
inputs=[df],
response_format="text",
cost_component_name=... # use your team's cost component here!
)
print(response.text)
.. code-block:: python
        :caption: Passing functions to extend the model's capabilities.
from ml_toolkit.functions.llm import query
        def sum_two_numbers(a: float, b: float) -> float:
            \"""
            Sums two numbers and returns the result.
            \"""
            return a + b
response = query(
"What's the sum between 1.1234 and 4.4321?",
tools=[sum_two_numbers],
response_format="json",
cost_component_name=... # use your team's cost component here!
)
print(response.json())
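
    .. code-block:: python
        :caption: Requesting structured JSON output and parsing it (illustrative prompt and keys).

        from ml_toolkit.functions.llm import query

        response = query(
            "Classify the sentiment of 'I love this product'. Return a JSON object with keys 'label' and 'confidence'.",
            response_format="json",
            cost_component_name=... # use your team's cost component here!
        )
        print(response.json())  # e.g. {"label": "positive", "confidence": 0.97}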
"""
logger = get_logger()
source_uuid = str(uuid.uuid4())
client = create_llm_client(model)
logger.info(f"Using model: {model}")
    if cost_component_name is None:
        raise ValueError("`cost_component_name` is required.")
# 1. Construct the prompt
user_prompt = prompt
    if inputs:
        # Accept a single value or a list/tuple of values
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        for item in inputs:
            if isinstance(item, dict):
                user_prompt += "\nInput data as JSON:\n" + json.dumps(item)
            elif isinstance(item, (DataFrame, cDataFrame)):
                user_prompt += "\nDataframe context:\n" + get_dataframe_context(item)
                user_prompt += (
                    "\nDataframe data as JSON:\n" + convert_dataframe_to_json(item)
                )
            else:
                user_prompt += "\n" + item
if context_files:
try:
file_context = read_context_files(context_files, MAX_FILE_SIZE_MB)
user_prompt = format_context_for_prompt(file_context, user_prompt)
except (FileNotFoundError, ValueError) as e:
raise ValueError(f"Error reading context files: {e}")
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if response_format == "json":
messages.append(
{
"role": "system",
"content": """
You are a helpful assistant that always responds in valid JSON format.
- Your response should contain ONLY the JSON object, no other text
- Use proper JSON syntax with double quotes for keys and string values
- Follow the requested JSON structure exactly
""",
}
)
messages.append({"role": "user", "content": user_prompt})
# 2. Prepare tools if provided
tools_payload = []
tool_lookup = {}
    if tools:
        for tool in tools:
            if callable(tool):
                tool_def = python_function_to_tool(tool)
                tools_payload.append(tool_def)
                tool_lookup[tool.__name__] = tool
            elif isinstance(tool, dict):
                tools_payload.append(tool)
        logger.info(f"Using tools: {list(tool_lookup.keys())}")
# 3. Main tool-calling loop
num_tool_calls = 0
should_call = True
while should_call is True:
api_params = {
"model": model,
"messages": messages,
"max_tokens": max_output_tokens,
            **(model_extra_options or {}),
}
# Add tools if provided
if tools_payload:
api_params["tools"] = tools_payload
api_params["tool_choice"] = "auto"
if "databricks" in model:
api_params["extra_body"] = {
"usage_context": {
"model": model,
"source_uuid": source_uuid,
"cost_component_name": cost_component_name,
}
}
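        # Send the prompt (plus any accumulated tool results) to the model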
response = client.chat.completions.create(**api_params)
log_model_usage(
source_uuid,
model=model,
prompt=prompt,
actual_input_tokens=response.usage.prompt_tokens,
actual_output_tokens=response.usage.completion_tokens,
cost_component_name=cost_component_name,
)
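        # Collect any tool calls the model requested (empty list if none)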
tool_calls = (
response.choices[0].message.tool_calls
if hasattr(response.choices[0].message, "tool_calls")
else []
) or []
if not tool_calls:
should_call = False
        if isinstance(max_tool_calls, int) and num_tool_calls >= max_tool_calls:
should_call = False
        for call in tool_calls:
            num_tool_calls += 1
            tool_name = call.function.name
tool_args = json.loads(call.function.arguments)
tool_fn = tool_lookup.get(tool_name)
if tool_fn:
logger.debug(f"Calling tool: {tool_name} with args: {tool_args}")
result = tool_fn(**tool_args)
messages.append(
{
"role": "assistant",
"tool_calls": [
{
"id": call.id,
"type": "function",
"function": {
"name": tool_name,
"arguments": json.dumps(tool_args),
},
}
],
}
)
messages.append(
{
"role": "tool",
"tool_call_id": call.id,
"name": tool_name,
"content": json.dumps(result),
}
)
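    # The loop exits once the model responds without tool calls (or the tool-call limit is reached)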
return LLMResponse(response)