# Source code for visualization_toolkit.helpers.dash.databricks

import logging
import requests
from urllib.parse import urljoin
from typing import Literal
from datetime import datetime

import pandas as pd
from databricks import sql
from databricks.connect import DatabricksSession
from flask_caching import Cache
from pyspark.sql.connect.session import SparkSession

from yipit_databricks_client import (
    get_spark_session as ydbc_get_spark_session,
    get_setting_from_environment,
    is_inside_databricks,
)
from yipit_databricks_client.helpers.oauth import get_oauth_token

from visualization_toolkit.constants import APP_ENVIRONMENT
from visualization_toolkit.exceptions import InvalidEnvironmentException
from visualization_toolkit.helpers.dash.cache import (
    get_query_cache_timestamp,
    load_query_results_from_cache,
    cache_query,
    get_query_hash,
    cached_results_path,
    _resolve_redis_uri,
)
from visualization_toolkit.helpers.plotly.charts.core.chart import (
    _normalize_input_data_to_pandas,
)

# Log under the ``databricks.sql`` namespace (the package imported above as ``sql``),
# so this module's messages share that logger's configuration.
logger = logging.getLogger("databricks.sql")

# In-process flask_caching store. ``get_spark_session`` uses it to remember when the
# Spark session was last restarted. Other backends ('redis', 'filesystem', ...) can be
# selected through CACHE_TYPE.
# NOTE(review): no Flask app is passed here; presumably ``cache.init_app(app)`` is
# called elsewhere — confirm, otherwise lookups may fail outside an app context.
cache = Cache(
    config=dict(
        CACHE_TYPE="simple",
        CACHE_DEFAULT_TIMEOUT=5 * 60,  # seconds
    )
)


def execute_query(
    query: str,
    sources: list[dict | str] | None = None,
    use_cache: bool = False,
    overwrite_cache: bool = False,
    cache_uri: str | None = None,
    return_type: Literal["list", "pandas_dataframe"] = "list",
) -> list[dict] | pd.DataFrame:
    """
    Execute a SQL query and return the query results.

    This function includes utilities to cache query results offline to avoid re-running them on the
    warehouse and improve performance / reduce costs.

    Inside a Databricks notebook the notebook's Spark session is used to run the query. Otherwise the
    query runs on a Databricks SQL warehouse, and the following environment variables must be set:

    - ``DATABRICKS_SERVER_HOSTNAME``: Workspace URL that the SQL Warehouse resides in
    - ``DATABRICKS_HTTP_PATH``: Path for the given SQL Warehouse to use
    - ``DATABRICKS_ACCESS_TOKEN``: API token for a given Service Principal that should be used to
      authenticate to Databricks. Alternatively, set ``YDBC_OAUTH_CLIENT_ID`` and
      ``YDBC_OAUTH_SECRET`` (with ``YDBC_DEPLOYMENT_NAME``) to authenticate via OAuth instead.

    :param query: SQL query string to execute
    :param sources: List of query source tables that are used in the query. Specifying this is
        required to cache query results, as the offline cache uses it to index the results.
    :param use_cache: When ``True``, any query results retrieved from the warehouse will be cached
        offline. If the query is executed again, the offline cache will retrieve results instead of
        executing on the SQL warehouse.
    :param overwrite_cache: When ``True``, cached results will not be fetched; the query will be
        executed on the warehouse and overwrite the offline cache. This should only be used for
        maintenance purposes to fix the cache.
    :param cache_uri: Optional redis URI for the cache to use. Will default to the
        ``ATLAS_CACHE_URI`` environment variable if not specified. If neither is set, caching will
        not be enabled.
    :param return_type: Optionally control whether results are returned as a ``list`` of
        dictionaries or as a pandas dataframe (``pandas_dataframe``). Default is a list of
        dictionaries.
    :raises InvalidEnvironmentException: When running outside Databricks and the warehouse
        credentials are missing from the environment.
    :return: Query results. When ``return_type == "list"``, results are returned as a list of
        dictionaries, one dict per row. When ``return_type == "pandas_dataframe"``, a
        ``pd.DataFrame`` is returned.
    """
    # Must specify query sources in order to use the caching feature.
    # The sources act as identifiers to be able to clear the cached results.
    sources = sources or []
    cache_uri = _resolve_redis_uri(cache_uri)
    cache_enabled = bool(use_cache and sources and cache_uri is not None)

    if use_cache and not cache_enabled:
        # Surface the misconfiguration instead of silently skipping the cache.
        logger.warning(
            "use_cache=True but caching is disabled: "
            "both query sources and a cache URI are required .."
        )

    if cache_enabled and not overwrite_cache:
        logger.info("Cache enabled for query execution ..")
        query_hash = get_query_hash(query)
        cache_timestamp = get_query_cache_timestamp(query_hash, cache_uri)
        if cache_timestamp is not None:
            s3_path = cached_results_path(query_hash, cache_timestamp)
            logger.warning(f"Loading cached results from {s3_path} ..")
            try:
                cached_pdf = load_query_results_from_cache(s3_path)
            except FileNotFoundError:
                # Cache index points at missing data: fall through to the warehouse.
                logger.warning(f"Cached results not found at {s3_path} ..")
                logger.warning("Query will be executed on warehouse as a fallback ..")
            else:
                return _normalize_return_type(cached_pdf, return_type)

    # If inside databricks notebooks use its spark session to execute queries,
    # otherwise go through the Databricks SQL connector.
    if is_inside_databricks():
        parsed = _execute_query_with_spark(query)
    else:
        parsed = _execute_query_on_warehouse(query)

    if cache_enabled:
        cache_query(
            query,
            parsed,
            query_sources=sources,
            cache_uri=cache_uri,
        )
    return _normalize_return_type(parsed, return_type)


def _execute_query_with_spark(query: str) -> list[dict]:
    """Run ``query`` with the notebook's spark-connect session; return rows as dicts."""
    spark = ydbc_get_spark_session()
    logger.info(f"Executing query with spark-connect session:\n {query} \n")
    result = spark.sql(query).collect()
    return [row.asDict() for row in result]


def _execute_query_on_warehouse(query: str) -> list[dict]:
    """
    Run ``query`` on the Databricks SQL warehouse configured via environment variables.

    :raises InvalidEnvironmentException: When the warehouse credentials are missing.
    :return: Rows as a list of dicts.
    """
    _validate_databricks_credentials()
    # Prefer OAuth when both client id and secret are configured; otherwise fall back
    # to the static access token.
    if APP_ENVIRONMENT.get("YDBC_OAUTH_CLIENT_ID") and APP_ENVIRONMENT.get(
        "YDBC_OAUTH_SECRET"
    ):
        logger.warning("Using oauth to get access token ..")
        token = get_oauth_token(APP_ENVIRONMENT.get("YDBC_DEPLOYMENT_NAME"))
    else:
        logger.warning("Using fixed access token ..")
        token = APP_ENVIRONMENT.get("DATABRICKS_ACCESS_TOKEN")

    with sql.connect(
        server_hostname=APP_ENVIRONMENT.get("DATABRICKS_SERVER_HOSTNAME"),
        http_path=APP_ENVIRONMENT.get("DATABRICKS_HTTP_PATH"),
        access_token=token,
    ) as conn, conn.cursor() as cursor:
        logger.info(f"Executing query:\n {query} \n")
        cursor.execute(query)
        result = cursor.fetchall()
    # Convert after the connection is closed so caching does not hold it open.
    return [row.asDict() for row in result]
def _validate_databricks_credentials(): server_hostname = APP_ENVIRONMENT.get("DATABRICKS_SERVER_HOSTNAME") http_path = APP_ENVIRONMENT.get("DATABRICKS_HTTP_PATH") access_token = APP_ENVIRONMENT.get("DATABRICKS_ACCESS_TOKEN") oauth_available = APP_ENVIRONMENT.get( "YDBC_OAUTH_CLIENT_ID" ) and APP_ENVIRONMENT.get("YDBC_OAUTH_SECRET") # If not inside databricks notebooks, use databricks SQL to execute the query if ( (server_hostname is None) or (http_path is None) or (access_token is None and not oauth_available) ): raise InvalidEnvironmentException( "Databricks credentials not found, " "please check that DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_ACCESS_TOKEN are set in the environment. " "If using oauth, please check that YDBC_OAUTH_CLIENT_ID and YDBC_OAUTH_SECRET are set in the environment instead of DATABRICKS_ACCESS_TOKEN." ) def _normalize_return_type( results: list[dict] | pd.DataFrame, return_type: Literal["list", "pandas_dataframe"], ) -> pd.DataFrame | list[dict]: match return_type: case "list": if isinstance(results, list): return results return results.to_dict("records") case "pandas_dataframe": return _normalize_input_data_to_pandas(results)
def get_spark_session(restart_interval: int = 600) -> SparkSession:
    """
    Create a ``pyspark.SparkSession`` to execute queries on databricks via spark-connect.

    This should be used when pyspark is necessary for a data app and comes with trade-offs:

    - pyspark is more powerful in generating transformations dynamically
    - spark-connect however has less caching and concurrency capabilities than databricks SQL.
    - spark-connect can also be more costly than databricks SQL.

    To use a pyspark session locally, set ``YDBC_DEPLOYMENT_NAME``, ``YDBC_OAUTH_CLIENT_ID``, and
    ``YDBC_OAUTH_SECRET`` as environment variables. These are typically set via infisical by the
    Platform Engineering team.

    The session is restarted periodically. A marker is stored in the module-level ``cache`` with a
    timeout of ``restart_interval`` seconds. Whenever the marker is absent (the first call, or after
    it expires), the current session is stopped and a new one is created.

    :param restart_interval: Seconds between forced session restarts. Defaults to 600.
    :return: The Spark session. When ``ATLAS_IS_DATABRICKS_APP`` is set, this is a serverless
        ``DatabricksSession``; otherwise it is the session returned by
        ``yipit_databricks_client.get_spark_session``.
    """
    from datetime import timezone

    if APP_ENVIRONMENT.get_bool("ATLAS_IS_DATABRICKS_APP", False):

        def _get_spark_session():
            # Use databricks app preset credentials to create a session
            return DatabricksSession.builder.serverless(True).getOrCreate()

    else:

        def _get_spark_session():
            # Fetch the SparkSession via Databricks Connect
            # Following environment variables are required in CircleCI:
            # - YDBC_DEPLOYMENT_NAME
            # - YDBC_CLUSTER_ID (classic only)
            # - YDBC_OAUTH_CLIENT_ID
            # - YDBC_OAUTH_SECRET
            return ydbc_get_spark_session(
                # Request serverless only if this databricks-connect version supports it.
                serverless=hasattr(DatabricksSession.builder, "serverless"),
                deployment_name=get_setting_from_environment("YDBC_DEPLOYMENT_NAME"),
            )

    spark = _get_spark_session()

    # Restart the spark session every ``restart_interval`` seconds. The cache entry's
    # timeout acts as the timer: once it expires, the next call restarts the session.
    # NOTE(review): ``cache`` is a flask_caching Cache built without an app; presumably
    # ``init_app`` is called elsewhere — confirm, since lookups may need an app context.
    if not cache.get("SPARK_SESSION_TIMESTAMP"):
        logger.info(f"Restarting spark session after {restart_interval} seconds ..")
        spark.stop()
        spark = _get_spark_session()
        cache.set(
            "SPARK_SESSION_TIMESTAMP",
            datetime.now(timezone.utc).isoformat(),
            timeout=restart_interval,
        )
    return spark