Source code for ml_toolkit.functions.llm.query

import inspect
import json
from typing import Callable, List, Literal, Optional, Union
import uuid

import openai
from pyspark.sql.connect.dataframe import DataFrame as cDataFrame
from pyspark.sql.dataframe import DataFrame
from yipit_databricks_client.helpers.telemetry import track_usage

from ml_toolkit.functions.llm.constants import DEFAULT_MODEL_QUERY, MAX_FILE_SIZE_MB
from ml_toolkit.functions.llm.helpers.authentication import create_llm_client
from ml_toolkit.functions.llm.helpers.dataframes import (
    convert_dataframe_to_json,
    get_dataframe_context,
)
from ml_toolkit.functions.llm.helpers.files import (
    format_context_for_prompt,
    read_context_files,
)
from ml_toolkit.functions.llm.helpers.logging import log_model_usage
from ml_toolkit.ops.helpers.logger import get_logger


class LLMResponse:
    """
    Class that defines the LLM response of the ``query`` function.

    It exposes the following attributes:

    * ``.text``: returns the text response
    * ``.response``: returns the raw LLM response class (``openai.ChatCompletion``)

    And also the following methods:

    * ``.json()``: tries to parse the output into a python ``dict`` if possible
    """

    def __init__(self, raw_response: openai.ChatCompletion):
        """
        :meta private:
        """
        self.response = raw_response
        self.text = raw_response.choices[0].message.content or ""
    def __new__(cls, *args, **kwargs):
        """
        Only added here to avoid having this method appear on the docs.

        :meta private:
        """
        return super(LLMResponse, cls).__new__(cls)
    def json(self, strict: bool = True) -> Optional[dict]:
        """
        Parse the response text as JSON.

        :param strict: If True, raise JSONDecodeError for invalid JSON.
            If False, return None for invalid JSON.
        """
        if not self.text.strip():
            if strict:
                raise json.JSONDecodeError("Empty response", "", 0)
            return None

        # Try to extract JSON from the response if it's embedded in other text
        text = self.text.strip()
        try:
            # First try parsing the entire response
            return json.loads(text)
        except json.JSONDecodeError:
            if strict:
                raise

        # If that fails and we're not in strict mode, try to find JSON in the text
        try:
            # Look for text between curly braces
            start = text.find("{")
            end = text.rfind("}")
            if start >= 0 and end > start:
                return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass

        return None
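
# Illustrative sketch (not part of the original module): how ``LLMResponse`` behaves,
# assuming ``raw_response`` is an ``openai.ChatCompletion`` whose first choice has the
# message content '{"answer": 42}'.
#
#     resp = LLMResponse(raw_response)
#     resp.text                # '{"answer": 42}'
#     resp.json()              # {"answer": 42}
#     resp.json(strict=False)  # same dict; returns None instead of raising on invalid JSON
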
def python_function_to_tool(fn: Callable) -> dict:
    """
    Convert a Python function into a tool definition compatible with OpenAI's
    tool format.
    """
    signature = inspect.signature(fn)
    doc = fn.__doc__ or ""

    properties = {}
    required = []
    for name, param in signature.parameters.items():
        annotation = param.annotation
        param_schema = {}
        if annotation is str:
            param_schema["type"] = "string"
        elif annotation is int:
            param_schema["type"] = "integer"
        elif annotation is float:
            param_schema["type"] = "number"
        elif annotation is bool:
            param_schema["type"] = "boolean"
        else:
            param_schema["type"] = "string"  # default fallback

        # Parameters without a default value are required
        if param.default is inspect.Parameter.empty:
            required.append(name)

        properties[name] = param_schema

    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": doc.strip(),
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }
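
# Illustrative sketch (not part of the original module): the tool definition that
# ``python_function_to_tool`` produces for a simple annotated function. The function
# name below is hypothetical.
#
#     def sum_two_numbers(a: float, b: float) -> float:
#         """Sums two numbers and returns the result."""
#         return a + b
#
#     python_function_to_tool(sum_two_numbers)
#     # -> {
#     #        "type": "function",
#     #        "function": {
#     #            "name": "sum_two_numbers",
#     #            "description": "Sums two numbers and returns the result.",
#     #            "parameters": {
#     #                "type": "object",
#     #                "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
#     #                "required": ["a", "b"],
#     #            },
#     #        },
#     #    }
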
@track_usage
def query(
    prompt: str,
    inputs: Optional[Union[str, dict, DataFrame]] = None,
    context_files: Optional[Union[str, List[str]]] = None,
    model: str = DEFAULT_MODEL_QUERY,
    response_format: Literal["text", "json"] = "text",
    tools: Optional[list[Union[Callable, dict]]] = None,
    max_tool_calls: Optional[int] = None,
    max_output_tokens: Optional[int] = 2048,
    system_prompt: Optional[str] = None,
    cost_component_name: Optional[str] = None,
    model_extra_options: dict = {},
) -> LLMResponse:
    """
    Query an LLM and get its response.

    This function lets you prompt an LLM and provides several features to streamline
    usage, such as input data (including dataframes), context files, tools, and web
    search. Check the examples below for advanced usage of these features.

    :param prompt: The main prompt text to send to the LLM model.
    :param inputs: Optional input data that can be a string, dictionary, or DataFrame
        (or a list of these) to provide context.
    :param context_files: Optional file path(s) to include as context. Can be a single
        string or list of strings.
    :param model: Optional name of the LLM model to use.
    :param response_format: Format of the response - either "text" or "json".
        Defaults to "text".
    :param tools: Optional list of functions or dicts that the LLM can use as tools.
    :param max_tool_calls: Maximum number of tool calls allowed.
    :param max_output_tokens: Maximum number of tokens in the response. Defaults to 2048.
    :param system_prompt: Optional system prompt to set the behavior of the model.
    :param cost_component_name: Name for tracking costs. Required.
    :param model_extra_options: Additional model-specific options as a dictionary.
    :returns: :py:class:`LLMResponse` object containing the model's response.
    :raises ValueError: If ``cost_component_name`` is not provided or if there are
        errors reading context files.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Querying a search-enabled model.

        from ml_toolkit.functions.llm import query

        resp = query(
            "What is the latest info on the Trump Tariffs?",
            model="gpt-4o-mini-search-preview",
            model_extra_options={
                "web_search_options": {
                    "user_location": {
                        "type": "approximate",
                        "approximate": {
                            "country": "US",
                        },
                    },
                },
            },
            cost_component_name=...  # use your team's cost component here!
        )
        print(resp.text)

    .. code-block:: python
        :caption: Passing a dataframe as context and asking the model questions about our data.

        from ml_toolkit.functions.llm import query

        df = spark.table(...).limit(10)

        response = query(
            "How can I query this table to know the usage rates by hour over the last 21 days?",
            inputs=[df],
            response_format="text",
            cost_component_name=...  # use your team's cost component here!
        )
        print(response.text)

    .. code-block:: python
        :caption: Passing functions to increase the model's capabilities.

        from ml_toolkit.functions.llm import query

        def sum_two_numbers(a: float, b: float) -> float:
            \"""
            Sums two numbers and returns the result.
            \"""
            return a + b

        response = query(
            "What's the sum between 1.1234 and 4.4321?",
            tools=[sum_two_numbers],
            response_format="json",
            cost_component_name=...  # use your team's cost component here!
        )
        print(response.json())
    """
    logger = get_logger()
    source_uuid = str(uuid.uuid4())
    client = create_llm_client(model)
    logger.info(f"Using model: {model}")

    if cost_component_name is None:
        raise ValueError("`cost_component_name` is required.")
    # 1. Construct the prompt
    # Accept a single input as well as a list of inputs
    if inputs is not None and not isinstance(inputs, (list, tuple)):
        inputs = [inputs]

    user_prompt = prompt
    if inputs:
        for input in inputs:
            if isinstance(input, dict):
                user_prompt += "\nInput data as JSON:\n" + json.dumps(input)
            elif isinstance(input, (DataFrame, cDataFrame)):
                user_prompt += "\nDataframe context:\n" + get_dataframe_context(input)
                user_prompt += (
                    "\nDataframe data as JSON:\n" + convert_dataframe_to_json(input)
                )
            else:
                user_prompt += "\n" + input

    if context_files:
        try:
            file_context = read_context_files(context_files, MAX_FILE_SIZE_MB)
            user_prompt = format_context_for_prompt(file_context, user_prompt)
        except (FileNotFoundError, ValueError) as e:
            raise ValueError(f"Error reading context files: {e}")

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if response_format == "json":
        messages.append(
            {
                "role": "system",
                "content": """
                You are a helpful assistant that always responds in valid JSON format.
                - Your response should contain ONLY the JSON object, no other text
                - Use proper JSON syntax with double quotes for keys and string values
                - Follow the requested JSON structure exactly
                """,
            }
        )
    messages.append({"role": "user", "content": user_prompt})

    # 2. Prepare tools if provided
    tools_payload = []
    tool_lookup = {}
    if tools:
        for tool in tools:
            if callable(tool):
                tool_def = python_function_to_tool(tool)
                tools_payload.append(tool_def)
                tool_lookup[tool.__name__] = tool
            elif isinstance(tool, dict):
                tools_payload.append(tool)
        logger.info(f"Using tools: {list(tool_lookup.keys())}")

    # 3. Main tool-calling loop
    num_tool_calls = 0
    should_call = True
    while should_call:
        api_params = {
            "model": model,
            "messages": messages,
            "max_tokens": max_output_tokens,
            **model_extra_options,
        }

        # Add tools if provided
        if tools_payload:
            api_params["tools"] = tools_payload
            api_params["tool_choice"] = "auto"

        if "databricks" in model:
            api_params["extra_body"] = {
                "usage_context": {
                    "model": model,
                    "source_uuid": source_uuid,
                    "cost_component_name": cost_component_name,
                }
            }

        response = client.chat.completions.create(**api_params)
        log_model_usage(
            source_uuid,
            model=model,
            prompt=prompt,
            actual_input_tokens=response.usage.prompt_tokens,
            actual_output_tokens=response.usage.completion_tokens,
            cost_component_name=cost_component_name,
        )

        tool_calls = (
            response.choices[0].message.tool_calls
            if hasattr(response.choices[0].message, "tool_calls")
            else []
        ) or []

        if not tool_calls:
            should_call = False
        if isinstance(max_tool_calls, int) and num_tool_calls > max_tool_calls:
            should_call = False

        for call in tool_calls:
            # Count every tool call requested by the model so that
            # ``max_tool_calls`` is enforced on the next iteration
            num_tool_calls += 1
            tool_name = call.function.name
            tool_args = json.loads(call.function.arguments)
            tool_fn = tool_lookup.get(tool_name)
            if tool_fn:
                logger.debug(f"Calling tool: {tool_name} with args: {tool_args}")
                result = tool_fn(**tool_args)
                messages.append(
                    {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "id": call.id,
                                "type": "function",
                                "function": {
                                    "name": tool_name,
                                    "arguments": json.dumps(tool_args),
                                },
                            }
                        ],
                    }
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": tool_name,
                        "content": json.dumps(result),
                    }
                )

    return LLMResponse(response)
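
# Illustrative sketch (not part of the original module): the two messages appended
# during one tool round-trip in the loop above, assuming the model requested a call
# to the hypothetical tool ``sum_two_numbers`` with ``call.id == "call_1"``.
#
#     {"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function",
#         "function": {"name": "sum_two_numbers", "arguments": '{"a": 1.1234, "b": 4.4321}'}}]}
#     {"role": "tool", "tool_call_id": "call_1", "name": "sum_two_numbers",
#         "content": "5.5555"}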