import logging
import requests
from urllib.parse import urljoin
from typing import Literal
from datetime import datetime, timezone
import pandas as pd
from databricks import sql
from databricks.connect import DatabricksSession
from flask_caching import Cache
from pyspark.sql.connect.session import SparkSession
from yipit_databricks_client import (
get_spark_session as ydbc_get_spark_session,
get_setting_from_environment,
is_inside_databricks,
)
from yipit_databricks_client.helpers.oauth import get_oauth_token
from visualization_toolkit.constants import APP_ENVIRONMENT
from visualization_toolkit.exceptions import InvalidEnvironmentException
from visualization_toolkit.helpers.dash.cache import (
get_query_cache_timestamp,
load_query_results_from_cache,
cache_query,
get_query_hash,
cached_results_path,
_resolve_redis_uri,
)
from visualization_toolkit.helpers.plotly.charts.core.chart import (
_normalize_input_data_to_pandas,
)
logger = logging.getLogger("databricks.sql")
cache = Cache(
config={
"CACHE_TYPE": "simple", # Options: 'redis', 'filesystem', etc.
"CACHE_DEFAULT_TIMEOUT": 300, # Cache timeout in seconds (5 min)
}
)
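# A Redis-backed configuration is a sketch of what a shared, multi-process cache could
# look like; the "CACHE_REDIS_URL" value below is a hypothetical placeholder, not a
# setting defined elsewhere in this module:
# cache = Cache(
#     config={
#         "CACHE_TYPE": "redis",
#         "CACHE_REDIS_URL": "redis://localhost:6379/0",  # hypothetical local Redis URL
#         "CACHE_DEFAULT_TIMEOUT": 300,
#     }
# )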
def execute_query(
query: str,
    sources: list[dict | str] | None = None,
use_cache: bool = False,
overwrite_cache: bool = False,
    cache_uri: str | None = None,
return_type: Literal["list", "pandas_dataframe"] = "list",
) -> list[dict] | pd.DataFrame:
"""
Execute a SQL query and return the query results. This function includes utilities to cache query results offline
    to avoid re-running them on the warehouse, improving performance and reducing costs.
    When running the function, make sure the following environment variables are set to connect to Databricks:
- ``DATABRICKS_SERVER_HOSTNAME``: Workspace URL that the SQL Warehouse resides in
- ``DATABRICKS_HTTP_PATH``: Path for the given SQL Warehouse to use
- ``DATABRICKS_ACCESS_TOKEN``: API token for a given Service Principal that should be used to authenticate to Databricks
    :param query: SQL query string to execute
    :param sources: List of query source tables that are used in the query. Specifying this is required when caching query results, as the offline cache uses it to index the results.
    :param use_cache: When ``True``, any query results retrieved from the warehouse will be cached offline. If the query is executed again, the results are served from the offline cache instead of being executed on the SQL warehouse.
    :param overwrite_cache: When ``True``, cached results will not be fetched; the query will be executed on the warehouse and overwrite the offline cache. This should only be used for maintenance purposes to fix the cache.
    :param cache_uri: Optional Redis URI for the cache to use. Defaults to the ``ATLAS_CACHE_URI`` environment variable if not provided. If neither is set, caching will not be enabled.
    :param return_type: Optionally control whether results are returned as a ``list`` of dictionaries or as a pandas dataframe (``pandas_dataframe``). Defaults to a list of dictionaries.
    :return: Query results. When ``return_type == "list"``, results are returned as a list of dictionaries, where each dict is a row of the query. When ``return_type == "pandas_dataframe"``, a ``pd.DataFrame`` is returned.
"""
# Must specify query sources in order to use caching feature
# The sources act as identifiers to be able to clear the cache results
sources = sources or []
cache_uri = _resolve_redis_uri(cache_uri)
    cache_enabled = use_cache and bool(sources) and (cache_uri is not None)
if cache_enabled and not overwrite_cache:
logger.info("Cache enabled for query execution ..")
query_hash = get_query_hash(query)
cache_timestamp = get_query_cache_timestamp(query_hash, cache_uri)
if cache_timestamp is not None:
s3_path = cached_results_path(query_hash, cache_timestamp)
try:
logger.warning(f"Loading cached results from {s3_path} ..")
cached_pdf = load_query_results_from_cache(s3_path)
return _normalize_return_type(cached_pdf, return_type)
except FileNotFoundError:
logger.warning(f"Cached results not found at {s3_path} ..")
logger.warning("Query will be executed on warehouse as a fallback ..")
    # If running inside a Databricks notebook, use its Spark session to execute queries
if is_inside_databricks():
spark = ydbc_get_spark_session()
logger.info(f"Executing query with spark-connect session:\n {query} \n")
result = spark.sql(query).collect()
parsed = [row.asDict() for row in result]
if cache_enabled:
cache_query(
query,
parsed,
query_sources=sources,
cache_uri=cache_uri,
)
return _normalize_return_type(parsed, return_type)
_validate_databricks_credentials()
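    # Prefer OAuth client credentials when both are set; otherwise fall back to the fixed access token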
if APP_ENVIRONMENT.get("YDBC_OAUTH_CLIENT_ID") and APP_ENVIRONMENT.get(
"YDBC_OAUTH_SECRET"
):
logger.warning("Using oauth to get access token ..")
token = get_oauth_token(APP_ENVIRONMENT.get("YDBC_DEPLOYMENT_NAME"))
else:
logger.warning("Using fixed access token ..")
token = APP_ENVIRONMENT.get("DATABRICKS_ACCESS_TOKEN")
with sql.connect(
server_hostname=APP_ENVIRONMENT.get("DATABRICKS_SERVER_HOSTNAME"),
http_path=APP_ENVIRONMENT.get("DATABRICKS_HTTP_PATH"),
access_token=token,
) as conn:
with conn.cursor() as cursor:
logger.info(f"Executing query:\n {query} \n")
cursor.execute(query)
result = cursor.fetchall()
parsed = [row.asDict() for row in result]
if cache_enabled:
cache_query(
query,
parsed,
query_sources=sources,
cache_uri=cache_uri,
)
return _normalize_return_type(parsed, return_type)
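# Example usage (a sketch; the table name and Redis URI below are hypothetical and assume
# the Databricks environment variables described in the docstring are set):
#
#   rows = execute_query("SELECT 1 AS one")                      # -> [{"one": 1}]
#   pdf = execute_query(
#       "SELECT * FROM analytics.daily_metrics LIMIT 10",        # hypothetical table
#       sources=["analytics.daily_metrics"],
#       use_cache=True,
#       cache_uri="redis://localhost:6379/0",                    # hypothetical Redis URI
#       return_type="pandas_dataframe",
#   )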
def _validate_databricks_credentials():
server_hostname = APP_ENVIRONMENT.get("DATABRICKS_SERVER_HOSTNAME")
http_path = APP_ENVIRONMENT.get("DATABRICKS_HTTP_PATH")
access_token = APP_ENVIRONMENT.get("DATABRICKS_ACCESS_TOKEN")
oauth_available = APP_ENVIRONMENT.get(
"YDBC_OAUTH_CLIENT_ID"
) and APP_ENVIRONMENT.get("YDBC_OAUTH_SECRET")
    # Either a fixed access token or OAuth client credentials must be present alongside the warehouse connection details
if (
(server_hostname is None)
or (http_path is None)
or (access_token is None and not oauth_available)
):
raise InvalidEnvironmentException(
"Databricks credentials not found, "
"please check that DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_ACCESS_TOKEN are set in the environment. "
"If using oauth, please check that YDBC_OAUTH_CLIENT_ID and YDBC_OAUTH_SECRET are set in the environment instead of DATABRICKS_ACCESS_TOKEN."
)
def _normalize_return_type(
results: list[dict] | pd.DataFrame,
return_type: Literal["list", "pandas_dataframe"],
) -> pd.DataFrame | list[dict]:
match return_type:
case "list":
if isinstance(results, list):
return results
return results.to_dict("records")
        case "pandas_dataframe":
            return _normalize_input_data_to_pandas(results)
        case _:
            raise ValueError(f"Unsupported return_type: {return_type}")
def get_spark_session() -> SparkSession:
"""
    Create a ``pyspark.SparkSession`` to execute queries on Databricks via spark-connect.
    This should be used when pyspark is necessary for a data app and comes with trade-offs:
    - pyspark is more powerful for generating transformations dynamically
    - spark-connect, however, has fewer caching and concurrency capabilities than Databricks SQL
    - spark-connect can also be more costly than Databricks SQL
    To use a pyspark session locally, set ``YDBC_DEPLOYMENT_NAME``, ``YDBC_OAUTH_CLIENT_ID``, and ``YDBC_OAUTH_SECRET`` as environment variables.
    These are typically set via Infisical by the Platform Engineering team.
    :return: A spark-connect ``SparkSession`` connected to the configured Databricks workspace.
"""
if APP_ENVIRONMENT.get_bool("ATLAS_IS_DATABRICKS_APP", False):
def _get_spark_session():
# Use databricks app preset credentials to create a session
return DatabricksSession.builder.serverless(True).getOrCreate()
else:
def _get_spark_session():
# Fetch the SparkSession via Databricks Connect
# Following environment variables are required in CircleCI:
# - YDBC_DEPLOYMENT_NAME
# - YDBC_CLUSTER_ID (classic only)
# - YDBC_OAUTH_CLIENT_ID
# - YDBC_OAUTH_SECRET
return ydbc_get_spark_session(
serverless=hasattr(DatabricksSession.builder, "serverless"),
deployment_name=get_setting_from_environment("YDBC_DEPLOYMENT_NAME"),
)
    # Create the Spark session using whichever builder applies above
spark = _get_spark_session()
# Restart spark session every 600 seconds by checking
# if the cache has a timestamp marker for the last time the session was restarted
spark_timestamp = cache.get("SPARK_SESSION_TIMESTAMP")
    now = datetime.now(timezone.utc)
    if not spark_timestamp:
        logger.info("Restarting spark session after 600 seconds ..")
spark.stop()
spark = _get_spark_session()
cache.set(
"SPARK_SESSION_TIMESTAMP",
now.isoformat(),
timeout=600,
)
return spark
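# A minimal, hypothetical usage sketch (assumes the YDBC_* environment variables described
# in the get_spark_session docstring are set; the table name is made up for illustration):
if __name__ == "__main__":
    spark = get_spark_session()
    sdf = spark.sql("SELECT * FROM analytics.daily_metrics LIMIT 10")  # hypothetical table
    print(sdf.toPandas().head())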