import os
import string
from typing import Literal, Optional, Any
from pathlib import Path
import base64
from datetime import date, datetime
import pandas as pd
from plotly import graph_objects as go
from pyspark.sql import DataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from visualization_toolkit.exceptions import InvalidInputException
from visualization_toolkit.helpers.plotly.charts.core.axis import Axis
from visualization_toolkit.helpers.plotly.charts.core.series import Series, series
from visualization_toolkit.helpers.plotly.charts.core.annotation import (
Annotation,
)
from visualization_toolkit.helpers.plotly.charts.core.shading import (
ShadeX,
ShadeY,
)
from visualization_toolkit.helpers.plotly.colors import resolve_color
from visualization_toolkit.helpers.plotly.theme import (
HOVER_LABEL_BACKGROUND_COLOR,
TRANSPARENT,
LEGEND_Y_POSITION,
CHART_LOGO_X_POSITION,
CHART_LOGO_MEDIUM_X_POSITION,
CHART_LOGO_X_SIZE,
CHART_LOGO_Y_SIZE,
WHITE,
BACKGROUND_SECONDARY,
COLORS,
)
from visualization_toolkit.helpers.plotly.charts.core.figure import generate_base_figure
# Root directory used to locate bundled static assets (e.g. ``static/logos``); five levels above this file.
# A plain ``.parent`` chain is used on purpose: unlike ``parents[4]`` it never raises on shallow paths.
BASE_DIR = Path(__file__).resolve().parent.parent.parent.parent.parent
def chart(
    df: DataFrame | pd.DataFrame | list[dict[str, Any]],
    chart_series: list[Series],
    x_axis: Axis,
    y1_axis: Optional[Axis] = None,
    y2_axis: Optional[Axis] = None,
    annotations: Optional[list[Annotation]] = None,
    shaded_regions: Optional[list[ShadeX | ShadeY]] = None,
    include_logo: bool = False,
    theme: Optional[dict] = None,
    custom_options: Optional[dict] = None,
) -> go.Figure:
    """
    Standard function to generate a plotly-based chart with standard company styling.

    The ``chart`` function works closely with the other toolkit building blocks for charts: ``axis``, ``series``,
    and ``annotations``. The input data to the chart can be passed in as a spark dataframe, pandas dataframe, or a
    list of dictionaries. The input data is normalized to pandas before the remaining operations supply it to
    plotly to render the chart.

    :param df: Input data to visualize. The columns from this data will be used for the ``chart_series`` and various axes.
    :param chart_series: A list of ``series`` instances that will be plotted on this chart figure.
    :param x_axis: The axis configuration for the x-axis of this chart figure.
    :param y1_axis: The axis configuration for the y1-axis of this chart figure.
    :param y2_axis: The axis configuration for the y2-axis of this chart figure. By default, no y2 axis is included.
    :param annotations: A list of ``annotations`` to include on the chart.
    :param shaded_regions: A list of ``shade_x`` and/or ``shade_y`` to include on the chart to shade regions.
    :param include_logo: Optionally include the YipitData logo at the bottom right of the chart. Default is False (no logo added).
    :param theme: Optionally changes the theme used to format the chart. Defaults to the ``YD_CLASSIC_THEME`` if not set.
    :param custom_options: Optionally include any options to pass into figure.update_layout. This is an escape hatch for final adjustments to a chart.
    :raises InvalidInputException: If a required axis is missing or is not an ``axis`` instance, if a series has an
        invalid ``location``, or if the series request more than one pivot (more than one series including all
        categories, or pivoted series using different pivot columns).
    :raises ValueError: If ``df`` is not a spark dataframe, pandas dataframe, or list.
    :return: Plotly ``figure`` object that can be displayed in databricks or supplied to a dash app as a property of the ``dcc.Graph`` component.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Common example of creating multiple series on a chart. Notice that each series corresponds to a different column on the dataset.

        from visualization_toolkit.helpers.plotly import chart, axis, series, annotation

        fig = chart(
            df,
            x_axis=axis(
                column_name="year",
                label="Year",
            ),
            chart_series=[
                series(
                    column_name="australia",
                    label="Australia",
                ),
                series(
                    column_name="new_zealand",
                    label="New Zealand",
                ),
            ],
            y1_axis=axis(
                label="Life Expectancy",
                axis_type="number",
            ),
        )
        display(fig)
    """
    # validate inputs
    _validate_axes(
        chart_series=chart_series,
        x_axis=x_axis,
        y1_axis=y1_axis,
        y2_axis=y2_axis,
    )

    # initialize blank figure object
    if theme is not None:
        fig = generate_base_figure(theme=theme)
    else:
        fig = generate_base_figure()
    annotations = annotations or []
    shaded_regions = shaded_regions or []

    # standardize data in pandas, then truncate data on the x-axis if bounds are given
    pdf = _normalize_input_data_to_pandas(df)
    pdf = filter_pdf_for_x_axis_bounds(pdf, x_axis)

    # organize series for axes
    x_axis.location = "x"
    x_series_data = _sorted_x_data(pdf, x_axis)
    x_configuration = {
        "data_min": x_series_data.min(),
        "data_max": x_series_data.max(),
        "series": x_series_data,
    }
    # NOTE(review): y1_axis is dereferenced unconditionally here and when building the axis options
    # below, so in practice a y1 axis is required even if every series is placed on y2.
    y1_axis.location = "y1"
    y1_configuration = {
        "data_min": None,
        "data_max": None,
        "series": [],
        "axis": y1_axis,
    }
    if y2_axis is not None:
        y2_axis.location = "y2"
    # The y2 configuration always exists (possibly empty) so the bounds and series bookkeeping
    # below does not need to special-case charts without a y2 axis.
    y2_configuration = {
        "data_min": None,
        "data_max": None,
        "series": [],
        "axis": y2_axis,
    }

    # pivot data if needed for chart
    pivoted_series = [s for s in chart_series if s.pivot_column_name is not None]
    if pivoted_series:
        pivot_column_name = pivoted_series[0].pivot_column_name
        # pivot_pdf reshapes on a single column, so at most one series may expand to all
        # categories and every pivoted series must use that same pivot column.
        all_categories_count = sum(1 for s in pivoted_series if s.include_all_categories)
        mixed_pivot_columns = any(
            s.pivot_column_name != pivot_column_name for s in pivoted_series
        )
        if all_categories_count > 1 or mixed_pivot_columns:
            raise InvalidInputException(
                "Only one series can be pivoted per chart. "
                "Either pivot the input data ahead of time or adjust the series passed to the chart function."
            )
        pivoted_pdf = pivot_pdf(
            pdf,
            x_axis,
            pivoted_series,
            pivot_column_name,
        )
    else:
        pivoted_pdf = None

    pdf, pivoted_pdf, x_configuration = filter_for_null_data_points(
        pdf,
        x_axis,
        chart_series,
        x_configuration=x_configuration,
        pivoted_pdf=pivoted_pdf,
    )
    y1_min, y1_max, y2_min, y2_max = generate_y_axes_bounds(
        pdf,
        chart_series,
        x_axis,
        pivoted_pdf=pivoted_pdf,
    )
    y1_configuration["data_min"] = y1_min
    y1_configuration["data_max"] = y1_max
    y2_configuration["data_min"] = y2_min
    y2_configuration["data_max"] = y2_max

    # Expand pivoted series into concrete per-category series and group them by axis
    for chart_series_item in chart_series:
        series_to_add = _expand_pivoted_series(chart_series_item, pdf, pivoted_pdf)
        if chart_series_item.location == "y1":
            y1_configuration["series"].extend(series_to_add)
        elif chart_series_item.location == "y2":
            y2_configuration["series"].extend(series_to_add)

    # Add each y1 series, then each y2 series, to the figure object
    _add_axis_traces(fig, pdf, x_axis, y1_axis, y1_configuration["series"])
    _add_axis_traces(fig, pdf, x_axis, y2_axis, y2_configuration["series"])

    # Layout flags driven by the series types
    general_options = {}
    apply_offset = True
    for s in chart_series:
        if s.is_bar_type and s.is_stacked:
            # Use stacked bar charts and ensure axis is scaled correctly
            general_options = {"barmode": "relative"}
        elif s.mode == "area":
            apply_offset = False

    # Configure x-axis
    if x_axis.axis_type == "date":
        x_axis_options = x_axis.date_axis_configuration(
            x_configuration["series"],
            apply_offset=apply_offset,
        )
    elif any(s.is_bar_type for s in chart_series):
        x_axis_options = x_axis.category_axis_configuration(x_configuration["series"])
    else:
        x_axis_options = x_axis.axis_configuration(
            x_configuration["data_min"], x_configuration["data_max"]
        )

    # configure y1-axis
    y1_axis_options = y1_axis.axis_configuration(
        y1_configuration["data_min"], y1_configuration["data_max"]
    )
    # configure y2-axis only when at least one series was placed on it
    if y2_configuration["series"]:
        y2_axis_options = y2_axis.axis_configuration(
            y2_configuration["data_min"], y2_configuration["data_max"]
        )
    else:
        y2_axis_options = {}

    # Add annotations
    for annotation in annotations:
        y_options = y1_axis_options["yaxis"]
        if annotation.axis_location == "x" and "range" in y_options:
            y_range = y_options["range"]
        elif annotation.axis_location == "x":
            y_range = y_options["tickvals"][0], y_options["tickvals"][-1]
        else:
            y_range = None
        annotation.register_in_figure(
            fig,
            # pass in the y-range bounds if the annotation is on the x-axis,
            # this is to avoid a plotly bug with add_vlines
            # https://github.com/plotly/plotly.py/issues/3065
            y_range=y_range,
        )

    # Add shaded regions
    for region in shaded_regions:
        region.add_to_figure(fig)

    # finalize figure layout
    fig.update_layout(
        legend_title_text="",
        legend_y=LEGEND_Y_POSITION,
        # Controlling x-axis
        **x_axis_options,
        **y1_axis_options,
        **y2_axis_options,
        **general_options,
        # To customize hover label format
        hovermode="x unified",
        hoverlabel_bgcolor=WHITE,
        hoverlabel_bordercolor=BACKGROUND_SECONDARY,
        xaxis_spikecolor=COLORS["dark-grey"][700],
        xaxis_spikethickness=1,
    )
    # Clustered bar charts need additional customization
    if any(s.resolved_mode == "clustered_bar" for s in chart_series):
        fig.update_layout(
            hovermode="closest",
            xaxis_tickangle=0,
            showlegend=False,
        )
        fig.update_xaxes(
            showspikes=False, tickfont={"color": resolve_color("dark-grey")}
        )
        fig.update_yaxes(showspikes=False)
    if include_logo:
        _add_yipitdata_logo_to_chart(
            fig,
            y1_axis=y1_axis,
            y2_axis=y2_axis,
        )
    if custom_options is not None:
        fig.update_layout(**custom_options)
    return fig


def _expand_pivoted_series(
    chart_series_item: Series,
    pdf: pd.DataFrame,
    pivoted_pdf: Optional[pd.DataFrame],
) -> list[Series]:
    """
    Return the concrete series to plot for ``chart_series_item``.

    - Pivoted with ``include_all_categories``: one new series per category from
      ``chart_series_item.get_categories(pdf)``, each taking its y data from ``pivoted_pdf``.
    - Pivoted for a single ``category_name``: one new series taking that category's y data from ``pivoted_pdf``.
    - Not pivoted: a one-item list containing the series itself.
    """
    pivot_column_name = chart_series_item.pivot_column_name
    column_name = chart_series_item.column_name

    def _category_series(category, label):
        # Copy the styling of the source series but plot a single category of the pivoted data
        return series(
            column_name=category,
            location=chart_series_item.location,
            label=label,
            mode=chart_series_item.mode,
            y_data=pivoted_pdf[column_name][category],
            shape=chart_series_item.shape,
            hover_format=chart_series_item.hover_format,
            color=chart_series_item.color,
            extra_options=chart_series_item.extra_options,
        )

    if pivot_column_name is not None and chart_series_item.include_all_categories:
        return [
            _category_series(category, _resolve_label(None, category, column_name))
            for category in chart_series_item.get_categories(pdf)
        ]
    if pivot_column_name:
        category_name = chart_series_item.category_name
        return [
            _category_series(
                category_name,
                _resolve_label(chart_series_item.label, category_name, column_name),
            )
        ]
    return [chart_series_item]


def _add_axis_traces(
    fig: go.Figure,
    pdf: pd.DataFrame,
    x_axis: Axis,
    y_axis: Optional[Axis],
    axis_series: list[Series],
) -> None:
    """
    Add each series in ``axis_series`` to ``fig`` against ``y_axis``.

    A series' ``shade_series`` (when set) is added before its own trace. The series' position
    within ``axis_series`` is passed as the index to both calls.
    """
    for idx, chart_series_item in enumerate(axis_series):
        if chart_series_item.shade_series is not None:
            chart_series_item.shade_series.add_to_figure(
                fig=fig,
                pdf=pdf,
                x_axis=x_axis,
                y_axis=y_axis,
                base_series=chart_series_item,
                base_series_idx=idx,
            )
        trace = chart_series_item.resolved_trace(
            pdf,
            x_axis,
            y_axis,
            idx,
        )
        fig.add_trace(trace)
def _normalize_input_data_to_pandas(
    df: DataFrame | pd.DataFrame | list[dict[str, Any]]
) -> pd.DataFrame:
    """
    Convert any supported chart input into a pandas DataFrame.

    - pandas DataFrames are returned as-is (the same object, not a copy).
    - Lists (of row dictionaries) are passed to the ``pd.DataFrame`` constructor.
    - Spark DataFrames (classic or Spark Connect) are collected with ``toPandas()``.

    :raises ValueError: If ``df`` is none of the supported types.
    """
    if isinstance(df, pd.DataFrame):
        return df
    if isinstance(df, list):
        return pd.DataFrame(df)
    if isinstance(df, (DataFrame, ConnectDataFrame)):
        return df.toPandas()
    raise ValueError("Invalid type for dataframe")
def _resolve_label(
chart_series_label: str | None, category_name: str | None, column_name: str
) -> str:
# If a specific title is passed in, use that
if chart_series_label:
return chart_series_label
elif isinstance(category_name, (date, datetime)):
return category_name.isoformat()
# For acronyms, return as-is, ex: UCAN
elif category_name.upper() == category_name:
return category_name
# Otherwise titleize the category or column name
elif category_name:
return string.capwords(category_name.replace("_", " "))
return string.capwords(column_name.replace("_", " "))
def _add_yipitdata_logo_to_chart(
    fig: go.Figure,
    y1_axis: Optional[Axis] = None,
    y2_axis: Optional[Axis] = None,
):
    """
    Place the YipitData logo at the right edge of ``fig``, vertically aligned with the legend.

    The logo is read from ``static/logos/yipitdata-logo-dark.png`` under ``BASE_DIR`` and embedded
    in the layout as a base64 ``data:`` URI.

    :param fig: Figure whose layout is updated in place.
    :param y1_axis: y1 axis configuration; a non-None ``label`` counts as a y-axis title.
    :param y2_axis: y2 axis configuration; a non-None ``label`` counts as a y-axis title.
    """
    # Determine x positioning based on a % of plot area.
    # A simple factor in determining plot area size is the number of y-axis titles:
    # more titles -> smaller plot area -> higher % for x-position so the logo sits on the plot edge.
    title_count = sum(
        1 for axis in (y1_axis, y2_axis) if axis is not None and axis.label is not None
    )
    x = CHART_LOGO_MEDIUM_X_POSITION if title_count == 2 else CHART_LOGO_X_POSITION

    # Load the logo image from the library and attach to the figure as base64 data
    logo_path = BASE_DIR / "static" / "logos" / "yipitdata-logo-dark.png"
    encoded = base64.b64encode(logo_path.read_bytes()).decode("utf-8")

    # Place logo on the right edge of x-axis and aligned with legend on the bottom.
    # Read the legend position through the public layout accessor rather than the private
    # ``fig._layout`` dict (which raises KeyError when no legend y has been set).
    fig.update_layout(
        images=[
            dict(
                source=f"data:image/png;base64,{encoded}",
                xref="paper",
                yref="paper",
                x=x,
                y=fig.layout.legend.y,
                sizex=CHART_LOGO_X_SIZE,
                sizey=CHART_LOGO_Y_SIZE,
                layer="below",
                xanchor="right",
                yanchor="bottom",
            )
        ],
    )
def _validate_axes(
    chart_series: list[Series],
    x_axis: Axis = None,
    y1_axis: Axis = None,
    y2_axis: Axis = None,
):
    """
    Ensure the x axis, and each y axis referenced by a series ``location``, is an ``axis`` instance.

    Series are checked in order, and the first problem found is raised.

    :param chart_series: Series whose ``location`` ('y1' or 'y2') selects the y axis they require.
    :param x_axis: Always required.
    :param y1_axis: Required only when some series has ``location == "y1"``.
    :param y2_axis: Required only when some series has ``location == "y2"``.
    :raises InvalidInputException: On a missing or wrongly-typed axis, or an unknown series location.
    """
    if x_axis is None:
        raise InvalidInputException("x_axis must be supplied")
    if not isinstance(x_axis, Axis):
        raise InvalidInputException(
            f"Invalid type, x_axis must be an ``axis`` instance, received: {type(x_axis)}"
        )

    for _series in chart_series:
        location = _series.location
        if location == "y1":
            required_axis = y1_axis
            missing_message = "y1_axis must be supplied"
            type_message = "Invalid type, y1_axis must be an ``axis`` instance"
        elif location == "y2":
            required_axis = y2_axis
            missing_message = "y2_axis must be supplied"
            type_message = "Invalid type, y2_axis must be an ``axis`` instance"
        else:
            raise InvalidInputException(
                "Invalid `location` for series, must be either 'y1' or 'y2'"
            )
        if required_axis is None:
            raise InvalidInputException(missing_message)
        if not isinstance(required_axis, Axis):
            raise InvalidInputException(type_message)
def generate_y_axes_bounds(
    pdf: pd.DataFrame,
    chart_series: list[Series],
    x_axis: Axis,
    pivoted_pdf: Optional[pd.DataFrame] = None,
) -> tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
    """
    Return the lower and upper data bounds for the y1 and y2 axes of the given series.

    If an axis is not used by any series, its bounds are ``None``. The chart passes these bounds to
    each axis' ``axis_configuration`` so that all plotted series fall within the graph's frame.

    Per series, the bounds come from:

    - stacked series: the min/max of the per-x-value totals;
    - pivoted, non-stacked series with a ``category_name``: that category's column in ``pivoted_pdf``;
    - series with a ``shade_series``: the series column plus the shade's boundary columns;
    - otherwise: the series column itself.

    ``pdf`` is not modified.

    :param pdf: Chart data in pandas form.
    :param chart_series: Series plotted on the chart; each contributes to the axis named by its ``location``.
    :param x_axis: X-axis configuration; its ``column_name`` groups stacked totals per x value.
    :param pivoted_pdf: Pivoted chart data with columns keyed by ``(column_name, category)``, or ``None``.
    :return: ``(y1_min, y1_max, y2_min, y2_max)`` as floats or ``None``.
    """
    y1_min, y1_max = None, None
    y2_min, y2_max = None, None
    for chart_series_item in chart_series:
        pivot_column_name = chart_series_item.pivot_column_name
        include_all_categories = chart_series_item.include_all_categories
        location = chart_series_item.location

        # Account for shaded region boundary columns if part of the series
        if chart_series_item.shade_series is not None:
            columns_with_boundaries = [
                chart_series_item.column_name,
                *chart_series_item.shade_series.boundary_column_names,
            ]
        else:
            columns_with_boundaries = [chart_series_item.column_name]

        if chart_series_item.is_stacked:
            if include_all_categories and pivot_column_name is not None:
                # Stacked series pivoted over all categories: sum each x-axis period across all rows
                grouped_series = pdf.groupby(x_axis.column_name)[
                    columns_with_boundaries
                ].sum()
            elif pivot_column_name is not None:
                # Stacked series pivoted for some categories: keep only the categories requested on
                # this axis, then sum each x-axis period
                included_categories = [
                    s.category_name for s in chart_series if s.location == location
                ]
                grouped_series = (
                    pdf[pdf[pivot_column_name].isin(included_categories)]
                    .groupby(x_axis.column_name)[columns_with_boundaries]
                    .sum()
                )
            else:
                # Non-pivoted stacked series: total the columns on this axis per x-axis period.
                # The total is kept as a standalone Series rather than written into ``pdf``, which
                # may be the caller's own DataFrame or a filtered slice of it.
                stacked_total = sum(
                    pdf[s.column_name] for s in chart_series if s.location == location
                )
                grouped_series = stacked_total.groupby(pdf[x_axis.column_name]).sum()
            series_min = _aggregate_to_float(grouped_series.min())
            series_max = _aggregate_to_float(grouped_series.max())
        elif pivoted_pdf is not None and chart_series_item.category_name is not None:
            # Pivoted, non-stacked: use the category's own min and max
            category_values = pivoted_pdf[chart_series_item.column_name][
                chart_series_item.category_name
            ]
            series_min = float(category_values.min())
            series_max = float(category_values.max())
        elif len(columns_with_boundaries) > 1:
            series_min = float(pdf[columns_with_boundaries].min().min())
            series_max = float(pdf[columns_with_boundaries].max().max())
        else:
            series_min = float(pdf[chart_series_item.column_name].min())
            series_max = float(pdf[chart_series_item.column_name].max())

        # Widen the running bounds of this series' axis if it exceeds them
        if location == "y1":
            if y1_min is None or series_min < y1_min:
                y1_min = series_min
            if y1_max is None or series_max > y1_max:
                y1_max = series_max
        elif location == "y2":
            if y2_min is None or series_min < y2_min:
                y2_min = series_min
            if y2_max is None or series_max > y2_max:
                y2_max = series_max
    return y1_min, y1_max, y2_min, y2_max


def _aggregate_to_float(aggregate) -> float:
    """
    Convert a groupby min/max result to a float.

    Aggregating a grouped DataFrame yields a Series (one value per column), so its first entry is used.
    A grouped Series yields a scalar, which is cast directly. Newer pandas versions refuse to cast
    a single-item Series to float, hence the explicit indexing.
    """
    if hasattr(aggregate, "iloc"):
        return float(aggregate.iloc[0])
    return float(aggregate)
def filter_pdf_for_x_axis_bounds(
    pdf: pd.DataFrame,
    x_axis: Axis,
) -> pd.DataFrame:
    """
    Keep only the rows whose x-axis value lies within ``x_axis.axis_min`` / ``x_axis.axis_max``.

    Each bound is inclusive and only applied when it is not ``None``. For ``"date"`` axes with a bound,
    the x column is converted with ``pd.to_datetime`` (and stays converted in the result). The
    calendar date of each value is compared against the bounds.

    The input frame is never modified. The datetime conversion is applied to a copy, so the caller's
    DataFrame is untouched and nothing is assigned into a filtered slice.

    :param pdf: Chart data in pandas form.
    :param x_axis: X-axis configuration providing ``column_name``, ``axis_type``, ``axis_min`` and ``axis_max``.
    :return: The filtered frame, or ``pdf`` itself when neither bound is set.
    """
    axis_min, axis_max = x_axis.axis_min, x_axis.axis_max
    # Only apply filters on axis min and max if specified
    if axis_min is None and axis_max is None:
        return pdf

    column = x_axis.column_name
    if x_axis.axis_type == "date":
        # Convert once, on a copy, so the caller's DataFrame is not mutated
        pdf = pdf.copy()
        pdf[column] = pd.to_datetime(pdf[column])
        x_values = pdf[column].dt.date
    else:
        x_values = pdf[column]

    keep = pd.Series(True, index=pdf.index)
    if axis_min is not None:
        keep &= x_values >= axis_min
    if axis_max is not None:
        keep &= x_values <= axis_max
    return pdf[keep]
def pivot_pdf(
    pdf: pd.DataFrame,
    x_axis: Axis,
    chart_series: list[Series],
    pivot_column_name: str,
) -> pd.DataFrame:
    """
    Pivot long-form chart data so each category of ``pivot_column_name`` becomes its own column.

    The result is indexed by the x-axis column. Because ``values`` is always passed as a list, the
    columns form a ``(value_column, category)`` MultiIndex, which is how callers address them.
    ``DataFrame.pivot`` raises ``ValueError`` if an (x value, category) pair appears more than once.

    :param pdf: Long-form chart data.
    :param x_axis: X-axis configuration; its ``column_name`` becomes the pivoted index.
    :param chart_series: Pivoted series; their distinct ``column_name`` values become the pivot values.
    :param pivot_column_name: Column whose categories become the new columns.
    :return: The pivoted DataFrame.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the value-column order is
    # deterministic (set iteration order over strings varies between interpreter runs).
    value_columns = list(
        dict.fromkeys(chart_series_item.column_name for chart_series_item in chart_series)
    )
    return pdf.pivot(
        index=x_axis.column_name,
        columns=pivot_column_name,
        values=value_columns,
    )
def filter_for_null_data_points(
    pdf: pd.DataFrame,
    x_axis: Axis,
    chart_series: list[Series],
    x_configuration: dict,
    pivoted_pdf: pd.DataFrame = None,
):
    """
    Drop leading x-axis points where every plotted value is null.

    This only applies when ``x_axis.start_at_non_null_values`` is set; otherwise the inputs are returned
    unchanged. For each series, the first x value with any non-null data is computed as follows:

    - pivoted over all categories: across all columns of ``pivoted_pdf``;
    - pivoted for some categories: across the ``(column_name, category_name)`` columns of ``pivoted_pdf``;
    - not pivoted: across all series columns of ``pdf``.

    The data is then truncated to start at the earliest of those points. Series with no non-null data
    produce no start point. If no series has any data, nothing is truncated.

    :param pdf: Chart data in pandas form.
    :param x_axis: X-axis configuration (``column_name``, ``start_at_non_null_values``).
    :param chart_series: Series plotted on the chart.
    :param x_configuration: X configuration dict; ``data_min`` and ``series`` are updated in place on truncation.
    :param pivoted_pdf: Pivoted chart data indexed by x value, or ``None``.
    :return: ``(pdf, pivoted_pdf, x_configuration)`` after truncation.
    """
    if not x_axis.start_at_non_null_values:
        return pdf, pivoted_pdf, x_configuration

    possible_start_points = []
    for chart_series_item in chart_series:
        should_pivot = chart_series_item.pivot_column_name is not None
        if should_pivot and chart_series_item.include_all_categories:
            # Pivoting with all categories: first point with any non-null value across all categories
            any_column_not_null = ~pivoted_pdf.isnull().all(axis=1)
            first_non_null_data_point = pivoted_pdf[any_column_not_null].index.min()
        elif should_pivot:
            # Pivoting with some categories: first point with non-null data for those categories only
            selected_columns = [
                (s.column_name, s.category_name)
                for s in chart_series
                if s.category_name is not None
            ]
            any_column_not_null = ~pivoted_pdf[selected_columns].isnull().all(axis=1)
            first_non_null_data_point = pivoted_pdf[any_column_not_null].index.min()
        else:
            # Otherwise check all selected columns on the base input dataframe
            any_column_not_null = (
                ~pdf[[s.column_name for s in chart_series]].isnull().all(axis=1)
            )
            first_non_null_data_point = pdf[any_column_not_null][
                x_axis.column_name
            ].min()
        possible_start_points.append(first_non_null_data_point)

    # A series without any non-null data yields NaN/NaT above. Comparisons against NaN are always
    # False, so min() could return it and the filter below would then drop every row. Ignore those.
    valid_start_points = [p for p in possible_start_points if not pd.isna(p)]
    if not valid_start_points:
        return pdf, pivoted_pdf, x_configuration

    # Start from the earliest point at which any series has data
    first_non_null_data_point = min(valid_start_points)
    pdf = pdf[pdf[x_axis.column_name] >= first_non_null_data_point]
    if pivoted_pdf is not None:
        pivoted_pdf = pivoted_pdf[pivoted_pdf.index >= first_non_null_data_point]

    # modify x-configuration to update chart bounds (the maximum is unaffected by truncating the start)
    x_data = _sorted_x_data(pdf, x_axis)
    x_configuration |= {
        "data_min": x_data.min(),
        "series": x_data,
    }
    return pdf, pivoted_pdf, x_configuration
def _sorted_x_data(pdf: pd.DataFrame, x_axis: Axis) -> pd.Series:
    """Return the x-axis column of ``pdf`` sorted in ascending order (original index labels are kept)."""
    x_values = pdf[x_axis.column_name]
    return x_values.sort_values(ascending=True)