from typing import List, Literal, Optional, Union
import uuid
from pyspark.sql import DataFrame
from yipit_databricks_client import get_spark_session
from yipit_databricks_client.helpers.telemetry import track_usage
from ml_toolkit.functions.llm.constants import MAX_NUM_ROWS_INTERACTIVE_MODE
from ml_toolkit.functions.llm.run_vector_search.function import (
vector_search_batch,
vector_search_interactive,
)
from ml_toolkit.ops.helpers.exceptions import MLOpsToolkitTooManyRowsForInteractiveUsage
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import (
assert_columns_exists_in_dataframe,
assert_columns_exists_in_table,
assert_table_exists,
)


def resolve_dry_run(data_source: DataFrame, dry_run: Optional[bool] = None) -> bool:
    """Resolve the effective ``dry_run`` flag from the explicit argument and the input row count."""
num_rows = data_source.count()
if dry_run is True and num_rows > MAX_NUM_ROWS_INTERACTIVE_MODE:
raise MLOpsToolkitTooManyRowsForInteractiveUsage(
num_rows=num_rows, max_num_rows=MAX_NUM_ROWS_INTERACTIVE_MODE
)
if dry_run is None and num_rows <= MAX_NUM_ROWS_INTERACTIVE_MODE:
dry_run = True
elif dry_run is None and num_rows > MAX_NUM_ROWS_INTERACTIVE_MODE:
dry_run = False
return dry_run
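
# Illustrative behavior of resolve_dry_run (assuming MAX_NUM_ROWS_INTERACTIVE_MODE == 1000):
#   resolve_dry_run(df_with_100_rows)               -> True  (small input: interactive mode)
#   resolve_dry_run(df_with_1m_rows)                -> False (large input: batch mode)
#   resolve_dry_run(df_with_1m_rows, dry_run=True)  -> raises MLOpsToolkitTooManyRowsForInteractiveUsage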


@track_usage
def run_vector_search(
data_source: Union[DataFrame, str],
index_name: str,
search_column: str,
output_columns: Optional[List[str]] = None,
output_table_name: Optional[str] = None,
num_results: int = 10,
    dry_run: Optional[bool] = None,
wait_for_completion: bool = True,
table_operation: Literal["overwrite", "append"] = "overwrite",
primary_key_columns: Optional[List[str]] = None,
    cost_component_name: Optional[str] = None,
query_type: Literal["nearest", "hybrid"] = "nearest",
) -> dict:
"""
    Performs a vector search on the input data (a Spark ``DataFrame`` or a table name) using the specified
    search index. The search runs against the column given by ``search_column``; in batch mode, results are
    written to ``output_table_name``.

    The function supports two query types:

    - ``nearest``: pure vector similarity search (default); best for semantic similarity matching.
    - ``hybrid``: combined vector and keyword search; useful when you want both semantic and exact keyword matching.

    There are two operation modes:

    1. ``dry_run=True``: good for quick experimentation and testing; only supports up to 1k rows of data.
    2. ``dry_run=False``: intended for production pipelines or large inputs. Writes results to
       ``output_table_name`` and can either ``overwrite`` or ``append`` them (set via ``table_operation``).

    .. caution:: ``dry_run=True`` still incurs usage and costs. It is a mode designed to allow faster and
       cheaper experimentation.

    .. attention:: You **must** pass a ``cost_component_name``, otherwise this function will raise an exception.

    Parameters
    ^^^^^^^^^^

:param data_source: DataFrame or Delta table name to run vector search on.
:param index_name: Fully qualified name of the vector search index to use.
:param search_column: Name of the column containing text to search.
:param output_columns: Optional list of column names from the base DataFrame to return. If None, returns all columns.
    :param output_table_name: Optional Delta table to write results to. Required when ``dry_run=False``.
:param num_results: Number of results to return per search (default: 10).
    :param dry_run: Whether to run the processing job interactively (locally) or trigger a remote batch run.
        If ``None``, the mode is inferred from the number of input rows.
:param wait_for_completion: Whether to wait for job completion (only applies in batch mode).
:param table_operation: Operation to perform on the output table.
:param primary_key_columns: Primary key columns for the output table.
:param cost_component_name: Name of the cost component.
:param query_type: Type of query to run. Must be either 'nearest' (default) for approximate nearest neighbor or 'hybrid' for combined vector and keyword search.
    :returns: Dictionary with job metadata (``source_uuid``, the parameters used) and the results ``DataFrame`` under the ``"df"`` key.
:raises ValueError: If output_table_name is not provided in batch mode, cost_component_name is missing, or query_type is invalid.
:raises MLOpsToolkitTooManyRowsForInteractiveUsage: If dry_run=True with too many rows.

    Examples
    ^^^^^^^^^^

    .. code-block:: python
        :caption: Running vector search on a table.

        from ml_toolkit.functions.llm import run_vector_search

run_vector_search(
data_source="catalog.schema.input_table",
index_name="catalog.schema.search_index",
search_column="text_to_search",
return_columns=["id", "name", "description"],
output_table_name="catalog.schema.output_table",
table_operation="overwrite",
query_type="nearest", # Use pure vector similarity search
cost_component_name=... # use your team's cost component here!
)

    .. code-block:: python
        :caption: Running vector search interactively on a DataFrame.

        from ml_toolkit.functions.llm import run_vector_search

df = spark.createDataFrame([
("apple iphone 13",),
("samsung galaxy s21",),
], ["product_name"])
res = run_vector_search(
data_source=df,
index_name="catalog.schema.products_index",
search_column="product_name",
return_columns=["product_name", "category"],
num_results=5,
query_type="hybrid", # Use combined vector and keyword search
cost_component_name=..., # use your team's cost component here!
)
df_results = res["df"]
display(df_results)
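
    A ``dry_run=True`` call raises :class:`MLOpsToolkitTooManyRowsForInteractiveUsage` when the input exceeds
    the interactive row limit. A minimal sketch of falling back to batch mode in that case (``large_df`` and
    the table names here are hypothetical):

    .. code-block:: python
        :caption: Falling back to batch mode when the input is too large for interactive usage.

        from ml_toolkit.functions.llm import run_vector_search
        from ml_toolkit.ops.helpers.exceptions import MLOpsToolkitTooManyRowsForInteractiveUsage

        try:
            res = run_vector_search(
                data_source=large_df,
                index_name="catalog.schema.products_index",
                search_column="product_name",
                dry_run=True,
                cost_component_name=...,  # use your team's cost component here!
            )
        except MLOpsToolkitTooManyRowsForInteractiveUsage:
            # Too many rows for interactive mode: run as a batch job instead
            run_vector_search(
                data_source=large_df,
                index_name="catalog.schema.products_index",
                search_column="product_name",
                output_table_name="catalog.schema.product_matches",
                dry_run=False,
                cost_component_name=...,  # use your team's cost component here!
            )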

    .. code-block:: python
        :caption: Using different query types for different use cases.

        # For semantic similarity search (recommended for most use cases)
run_vector_search(
data_source="catalog.schema.products",
index_name="catalog.schema.product_embeddings",
search_column="product_description",
query_type="nearest", # Pure vector similarity
output_table_name="catalog.schema.similar_products",
cost_component_name=...
)

        # For a search that combines semantic and keyword matching
run_vector_search(
data_source="catalog.schema.queries",
index_name="catalog.schema.document_embeddings",
search_column="search_query",
query_type="hybrid", # Vector + keyword search
output_table_name="catalog.schema.search_results",
cost_component_name=...
)
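
    Batch mode can also append to an existing results table. A minimal sketch, assuming the output table
    already exists and ``query_id`` is a hypothetical primary key column in the input:

    .. code-block:: python
        :caption: Appending batch results to an existing table.

        from ml_toolkit.functions.llm import run_vector_search

        run_vector_search(
            data_source="catalog.schema.new_queries",
            index_name="catalog.schema.document_embeddings",
            search_column="search_query",
            output_table_name="catalog.schema.search_results",
            table_operation="append",
            primary_key_columns=["query_id"],  # hypothetical key column
            dry_run=False,
            wait_for_completion=True,
            cost_component_name=...,  # use your team's cost component here!
        )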
"""
logger = get_logger()
# Generate a unique ID for this processing job
source_uuid = str(uuid.uuid4())
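    # Normalize the input: resolve a table name to a DataFrame and validate the requested columns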
if isinstance(data_source, str):
spark = get_spark_session()
assert_table_exists(data_source)
assert_columns_exists_in_table(data_source, columns=[search_column])
if output_columns:
assert_columns_exists_in_table(data_source, columns=output_columns)
data_source = spark.table(data_source)
else:
assert_columns_exists_in_dataframe(
data_source, columns=[search_column], table_name="dataframe"
)
if output_columns:
assert_columns_exists_in_dataframe(
data_source, columns=output_columns, table_name="dataframe"
)
dry_run = resolve_dry_run(data_source, dry_run=dry_run)
    if dry_run:
        logger.info(
            f"Running job {source_uuid} with dry_run={dry_run} and query_type={query_type}."
        )
    else:
        logger.info(
            f"Running job {source_uuid} with dry_run={dry_run}, "
            f"table_operation={table_operation}, and query_type={query_type}."
        )
# Validate parameters based on mode
if not dry_run:
if output_table_name is None or output_table_name == "":
raise ValueError(
"`output_table_name` is required for batch processing mode."
)
if cost_component_name is None or cost_component_name == "":
raise ValueError("`cost_component_name` is obligatory.")
df_results = None
if dry_run is True:
df_results = vector_search_interactive(
df=data_source,
index_name=index_name,
search_column=search_column,
return_columns=output_columns,
num_results=num_results,
query_type=query_type,
)
elif dry_run is False:
df_results = vector_search_batch(
df=data_source,
index_name=index_name,
search_column=search_column,
return_columns=output_columns,
output_table_name=output_table_name,
num_results=num_results,
table_operation=table_operation,
primary_key_columns=primary_key_columns,
query_type=query_type,
)
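    # Assemble the job metadata and results for the caller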
result = {
"source_uuid": source_uuid,
"index_name": index_name,
"search_column": search_column,
"return_columns": output_columns,
"num_results": num_results,
"query_type": query_type,
"cost_component_name": cost_component_name,
"output_table_name": output_table_name,
"df": df_results,
"error_message": None,
}
return result