Source code for visualization_toolkit.helpers.dash.query
import re
import string
from visualization_toolkit.constants import JINJA_ENVIRONMENT
from visualization_toolkit.exceptions import InvalidInputException
[docs]
def sql_query_from_template(
    template_string: str,
    parameters: dict,
    sources: list[dict | str],
) -> str:
    """
    Render a SQL query string from a Jinja2 template.

    The template is rendered with exactly two variables:

    * ``parameters`` -- the mapping passed in, accessed in the template as
      ``parameters.<key>``. Any JSON compatible data types are allowed.
    * ``sources`` -- the table sources after normalization, each a dict that
      carries a ``full_name`` entry.

    :param template_string: Jinja2 template passed in as a string.
    :param parameters: Key value pairs made available to the template.
    :param sources: Table sources, given either as ``"catalog.database.table"``
        strings or as dicts.
    :return: The rendered query with leading and trailing whitespace removed.
    :raises InvalidInputException: If any source fails normalization.
    """
    # Sources are normalized before the template is compiled, so invalid
    # sources are reported ahead of any template syntax problem.
    normalized_sources = _normalize_sources(sources)
    compiled_template = JINJA_ENVIRONMENT.from_string(template_string)
    rendered_query = compiled_template.render(
        parameters=parameters,
        sources=normalized_sources,
    )
    return rendered_query.strip()
def column_to_title(column: str) -> str:
    """
    Turn a snake_case column name into a human-readable title.

    Underscores become word separators. Runs of separators collapse and
    leading/trailing ones are dropped. Each word is capitalized, with the rest
    of the word lowercased, e.g. ``"USER_id"`` -> ``"User Id"``.

    :param column: Column name to convert.
    :return: Title-cased, space-separated label.
    """
    words = column.replace("_", " ").split()
    return " ".join(word.capitalize() for word in words)
def _normalize_sources(sources: list[str | dict]) -> list[dict]:
cleaned_sources = []
for source in sources:
if isinstance(source, str):
_validate_source(source)
catalog_name, database_name, table_name = source.split(".")
cleaned_sources.append(
{
"catalog_name": catalog_name,
"database_name": database_name,
"table_name": table_name,
"full_name": f"{catalog_name}.{database_name}.{table_name}",
}
)
elif isinstance(source, dict) and "full_name" not in source:
_validate_source_dict(source)
full_name = f"{source['catalog_name']}.{source['database_name']}.{source['table_name']}"
cleaned_sources.append(source | {"full_name": full_name})
elif isinstance(source, dict):
_validate_source_dict(source)
cleaned_sources.append(source)
else:
raise InvalidInputException(
f"sources must be provided as a list of dicts or as tables names using the 3-level namespace"
)
return cleaned_sources
def _validate_source(source: str):
matches = re.findall(r"^([a-z0-9_]+)\.([a-z0-9_]+)\.([a-z0-9_]+)$", source)
if not matches:
raise InvalidInputException(
"source table name must use 3-level namespace and alphanumeric and underscore characters"
)
def _validate_source_dict(source: dict):
if "full_name" in source:
return
if (
("catalog_name" not in source)
or ("database_name" not in source)
or ("table_name" not in source)
):
raise InvalidInputException(
"source table name when specified as a dictionary must either have full_name or a combination of catalog_name, database_name, and table_name as keys"
)