Source code for ml_toolkit.functions.llm.run_vector_search.interface

from typing import List, Literal, Optional, Union
import uuid

from pyspark.sql import DataFrame
from yipit_databricks_client import get_spark_session
from yipit_databricks_client.helpers.telemetry import track_usage

from ml_toolkit.functions.llm.constants import MAX_NUM_ROWS_INTERACTIVE_MODE
from ml_toolkit.functions.llm.run_vector_search.function import (
    vector_search_batch,
    vector_search_interactive,
)
from ml_toolkit.ops.helpers.exceptions import MLOpsToolkitTooManyRowsForInteractiveUsage
from ml_toolkit.ops.helpers.logger import get_logger
from ml_toolkit.ops.helpers.validation import (
    assert_columns_exists_in_dataframe,
    assert_columns_exists_in_table,
    assert_table_exists,
)


def resolve_dry_run(data_source: DataFrame, dry_run: Optional[bool] = None) -> bool:
    """Decide whether a run should use interactive (dry-run) mode.

    Args:
        data_source: DataFrame whose row count determines whether interactive
            mode is allowed / chosen.
        dry_run: ``True`` forces interactive mode, ``False`` forces batch mode,
            ``None`` (default) picks interactive mode only when the input has
            at most ``MAX_NUM_ROWS_INTERACTIVE_MODE`` rows.

    Returns:
        The resolved mode: ``True`` for interactive, ``False`` for batch.
        Any explicitly passed value is returned unchanged (after the row-limit
        check when it is ``True``).

    Raises:
        MLOpsToolkitTooManyRowsForInteractiveUsage: if ``dry_run`` is ``True``
            but ``data_source`` has more than ``MAX_NUM_ROWS_INTERACTIVE_MODE``
            rows.
    """
    if dry_run is False:
        # Batch mode was explicitly requested: the row count cannot change the
        # outcome, so avoid running ``count()`` (a Spark action) at all.
        return False

    num_rows = data_source.count()
    fits_interactive = num_rows <= MAX_NUM_ROWS_INTERACTIVE_MODE

    if dry_run is None:
        # No explicit choice: interactive only for small enough inputs.
        return fits_interactive

    if dry_run is True and not fits_interactive:
        raise MLOpsToolkitTooManyRowsForInteractiveUsage(
            num_rows=num_rows, max_num_rows=MAX_NUM_ROWS_INTERACTIVE_MODE
        )

    return dry_run


@track_usage