from typing import Any, Optional
from plotly import graph_objs as go
from pyspark.sql import DataFrame
import pandas as pd
from visualization_toolkit.exceptions import InvalidInputException
from visualization_toolkit.helpers.plotly.charts.core.figure import generate_base_figure
from visualization_toolkit.helpers.plotly.charts.core.chart import (
_normalize_input_data_to_pandas,
)
from visualization_toolkit.helpers.plotly.charts.core.series import Series
from visualization_toolkit.helpers.plotly.charts.core.axis import Axis
from visualization_toolkit.helpers.plotly.colors import resolve_color
from visualization_toolkit.helpers.plotly.theme import (
YD_CLASSIC_THEME,
ATLAS_TICK_COLOR,
)
def waterfall_chart(
    df: DataFrame | pd.DataFrame | list[dict[str, Any]],
    y1_axis: Axis,
    totals: Optional[list[Series]] = None,
    additions: Optional[list[Series]] = None,
    subtractions: Optional[list[Series]] = None,
    theme: dict = YD_CLASSIC_THEME,
) -> go.Figure:
    """
    Create a waterfall chart showing how positive and negative adjustments move a starting balance
    to an ending balance.

    Bars are plotted left to right in this order:

    1. The first ``totals`` series.
    2. Every ``additions`` series.
    3. Every ``subtractions`` series.
    4. The second ``totals`` series.

    The two totals are drawn as absolute bars. Every addition and subtraction is drawn as a relative
    bar whose value is added to the running balance, so subtraction values should be negative (as in
    the example below). Increasing bars use the ``light-blue`` color, decreasing bars ``orange`` and
    totals ``dark-blue``.

    Data must be provided as a single row whose columns hold the starting and ending totals along
    with any adjustments to plot. Only the first row of the data is used.

    :param df: Input data to visualize. The columns from this data will be used for the ``totals``,
        ``additions``, and ``subtractions`` series and the ``y1_axis``.
    :param y1_axis: The ``axis`` that the balances are plotted on. Note that its ``axis_min`` and
        ``axis_max`` attributes are overwritten with the range computed from the data.
    :param totals: Exactly two ``series`` holding the starting and ending balance. The order matters:
        pass the starting balance first and the ending balance second.
    :param additions: Any number of ``series`` indicating positive adjustments to the starting balance.
    :param subtractions: Any number of ``series`` indicating negative adjustments to the starting
        balance.
    :param theme: Optionally changes the theme used to format the chart. Defaults to the
        ``YD_CLASSIC_THEME`` if not set.
    :raises InvalidInputException: If ``totals`` does not contain exactly two series, or if the input
        data contains no rows.
    :return: A plotly ``Figure`` containing a single vertical waterfall trace.

    Examples
    ^^^^^^^^

    .. code-block:: python
        :caption: Example of creating a waterfall chart. Notice the shape of data that is used. It is acceptable to pass data as a dataframe, pandas dataframe, or a single-item list of dictionaries.

        from visualization_toolkit.helpers.plotly import waterfall_chart, axis, series

        data = [
            {
                "prior_period": 186994,
                "expansion": 36691,
                "contraction": -36489,
                "churn": -40530,
                "ending_period": 146656,
            }
        ]

        fig = waterfall_chart(
            data,
            y1_axis=axis(label="Spend", axis_type="currency"),
            totals=[
                series(
                    column_name="prior_period",
                    label="Prior Period",
                ),
                series(
                    column_name="ending_period",
                    label="Ending Period",
                ),
            ],
            additions=[
                series(
                    column_name="expansion",
                    label="Expansion",
                ),
            ],
            subtractions=[
                series(
                    column_name="contraction",
                    label="Contraction",
                ),
                series(
                    column_name="churn",
                    label="Churn",
                ),
            ],
        )
        display(fig)
    """
    totals = totals or []
    additions = additions or []
    subtractions = subtractions or []
    if len(totals) != 2:
        raise InvalidInputException("2 total series must be provided")

    pdf = _normalize_input_data_to_pandas(df)
    records = pdf.to_dict("records")
    if not records:
        # Indexing an empty record list would otherwise raise a bare IndexError.
        raise InvalidInputException(
            "waterfall_chart requires one row of input data, but the data is empty"
        )
    # Only the first row is plotted (the chart expects single-row data).
    data = records[0]

    # (series, plotly measure) pairs in plot order. "absolute" bars are the
    # starting/ending balances; "relative" bars move the running balance.
    steps = [(totals[0], "absolute")]
    steps += [(series_item, "relative") for series_item in additions]
    steps += [(series_item, "relative") for series_item in subtractions]
    steps.append((totals[1], "absolute"))

    x = [series_item.resolved_label for series_item, _ in steps]
    y = [data[series_item.column_name] for series_item, _ in steps]
    measure = [step_measure for _, step_measure in steps]

    axis_min, axis_max = _get_axis_range(
        data,
        totals=totals,
        additions=additions,
        subtractions=subtractions,
    )

    fig = generate_base_figure(theme=theme)

    # NOTE: this mutates the caller's Axis instance.
    y1_axis.axis_min = axis_min
    y1_axis.axis_max = axis_max
    tick_values = y1_axis.resolved_tick_values(axis_min, axis_max)
    tick_range = y1_axis.resolved_range(tick_values)

    fig.add_trace(
        go.Waterfall(
            orientation="v",
            measure=measure,
            x=x,
            text=[y1_axis.format_value(val) for val in y],
            textposition="auto",
            y=y,
            connector={"line": {"color": ATLAS_TICK_COLOR, "width": 1}},
            increasing=dict(marker=dict(color=resolve_color("light-blue"))),
            decreasing=dict(marker=dict(color=resolve_color("orange"))),
            totals=dict(marker=dict(color=resolve_color("dark-blue"))),
        )
    )
    fig.update_layout(
        yaxis={
            "tickvals": tick_values,
            "range": tick_range,
            "tickformat": y1_axis.resolved_tick_format,
        }
    )
    return fig
def _get_axis_range(
data: dict,
totals: list[Series] = None,
additions: list[Series] = None,
subtractions: list[Series] = None,
) -> (int | float, int | float):
axis_min = 0
for series_item in totals:
if data[series_item.column_name] < axis_min:
axis_min = data[series_item.column_name]
for series_item in subtractions:
if data[series_item.column_name] < 0:
axis_min -= data[series_item.column_name]
axis_max = 0
for series_item in totals:
if data[series_item.column_name] > axis_max:
axis_max = data[series_item.column_name]
for series_item in additions:
if data[series_item.column_name] > 0:
axis_max += data[series_item.column_name]
return axis_min, axis_max