Source code for visualization_toolkit.helpers.dash.query

import re
import string

from visualization_toolkit.constants import JINJA_ENVIRONMENT
from visualization_toolkit.exceptions import InvalidInputException


def sql_query_from_template(
    template_string: str,
    parameters: dict,
    sources: list[dict | str],
) -> str:
    """Render a Jinja2 SQL template into a query string.

    The template is compiled with the shared ``JINJA_ENVIRONMENT`` and
    rendered with exactly two variables:

    * ``parameters`` -- the mapping passed in, accessible in the template
      as ``parameters.<key>``. Any JSON compatible data types are allowed.
    * ``sources`` -- the table sources after ``_normalize_sources`` has
      turned each one into a dict carrying a ``full_name`` key.

    :param template_string: Jinja2 template passed in as a string.
    :param parameters: Key value pairs exposed to the template as
        ``parameters``.
    :param sources: Table sources, given either as
        ``catalog.database.table`` strings or as dicts.
    :return: The rendered SQL with leading/trailing whitespace stripped.
    """
    # Normalize (and validate) sources before touching the template so that
    # invalid sources are reported ahead of any template compilation error.
    normalized_sources = _normalize_sources(sources)
    compiled = JINJA_ENVIRONMENT.from_string(template_string)
    rendered = compiled.render(parameters=parameters, sources=normalized_sources)
    return rendered.strip()
def column_to_title(column: str) -> str: return string.capwords(column.replace("_", " ")) def _normalize_sources(sources: list[str | dict]) -> list[dict]: cleaned_sources = [] for source in sources: if isinstance(source, str): _validate_source(source) catalog_name, database_name, table_name = source.split(".") cleaned_sources.append( { "catalog_name": catalog_name, "database_name": database_name, "table_name": table_name, "full_name": f"{catalog_name}.{database_name}.{table_name}", } ) elif isinstance(source, dict) and "full_name" not in source: _validate_source_dict(source) full_name = f"{source['catalog_name']}.{source['database_name']}.{source['table_name']}" cleaned_sources.append(source | {"full_name": full_name}) elif isinstance(source, dict): _validate_source_dict(source) cleaned_sources.append(source) else: raise InvalidInputException( f"sources must be provided as a list of dicts or as tables names using the 3-level namespace" ) return cleaned_sources def _validate_source(source: str): matches = re.findall(r"^([a-z0-9_]+)\.([a-z0-9_]+)\.([a-z0-9_]+)$", source) if not matches: raise InvalidInputException( "source table name must use 3-level namespace and alphanumeric and underscore characters" ) def _validate_source_dict(source: dict): if "full_name" in source: return if ( ("catalog_name" not in source) or ("database_name" not in source) or ("table_name" not in source) ): raise InvalidInputException( "source table name when specified as a dictionary must either have full_name or a combination of catalog_name, database_name, and table_name as keys" )