Source code for visualization_toolkit.helpers.plotly.charts.waterfall

from typing import Any, Optional

from plotly import graph_objs as go
from pyspark.sql import DataFrame
import pandas as pd

from visualization_toolkit.exceptions import InvalidInputException
from visualization_toolkit.helpers.plotly.charts.core.figure import generate_base_figure
from visualization_toolkit.helpers.plotly.charts.core.chart import (
    _normalize_input_data_to_pandas,
)
from visualization_toolkit.helpers.plotly.charts.core.series import Series
from visualization_toolkit.helpers.plotly.charts.core.axis import Axis
from visualization_toolkit.helpers.plotly.colors import resolve_color
from visualization_toolkit.helpers.plotly.theme import (
    YD_CLASSIC_THEME,
    ATLAS_TICK_COLOR,
)


def waterfall_chart(
    df: DataFrame | pd.DataFrame | list[dict[str, Any]],
    y1_axis: Axis,
    totals: Optional[list[Series]] = None,
    additions: Optional[list[Series]] = None,
    subtractions: Optional[list[Series]] = None,
    theme: dict = YD_CLASSIC_THEME,
) -> go.Figure:
    """
    Create a Waterfall chart that summarizes the effect of multiple positive or
    negative adjustments between a starting and an ending balance.

    The bars are drawn left to right in this order: the first ``totals`` series,
    every ``additions`` series, every ``subtractions`` series, and finally the
    second ``totals`` series. The two totals are plotted as absolute values and
    every adjustment as a relative step.

    Data must be provided as a single row where the columns hold the starting
    and ending totals along with any adjustments that are to be plotted. Only
    the first row of the input is read.

    .. note::
        ``y1_axis.axis_min`` and ``y1_axis.axis_max`` are overwritten in place
        with the range computed from the data.

    :param df: Input data to visualize. The columns from this data will be used
        for the ``totals``, ``additions``, and ``subtractions`` series and the
        ``y1_axis``.
    :param y1_axis: The ``axis`` that the balances are plotted on.
    :param totals: Exactly two ``series`` containing the starting and ending
        balance for the data. The order matters: starting balance first,
        ending balance second.
    :param additions: Any number of ``series`` indicating positive adjustments
        to the starting balance.
    :param subtractions: Any number of ``series`` indicating negative
        adjustments to the starting balance.
    :param theme: Optionally changes the theme used to format the chart.
        Defaults to the ``YD_CLASSIC_THEME`` if not set.
    :return: The figure containing the waterfall trace.
    :raises InvalidInputException: If ``totals`` does not contain exactly two
        series, or if the input data has no rows.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Example of creating a waterfall chart. Notice the shape of data that is used.
                  It is acceptable to pass data as a dataframe, pandas dataframe, or a single-item
                  list of dictionaries.

        from visualization_toolkit.helpers.plotly import waterfall_chart, axis, series

        data = [
            {
                "prior_period": 186994,
                "expansion": 36691,
                "contraction": -36489,
                "churn": -40530,
                "ending_period": 146656,
            }
        ]

        fig = waterfall_chart(
            data,
            y1_axis=axis(label="Spend", axis_type="currency"),
            totals=[
                series(
                    column_name="prior_period",
                    label="Prior Period",
                ),
                series(
                    column_name="ending_period",
                    label="Ending Period",
                ),
            ],
            additions=[
                series(
                    column_name="expansion",
                    label="Expansion",
                ),
            ],
            subtractions=[
                series(
                    column_name="contraction",
                    label="Contraction",
                ),
                series(
                    column_name="churn",
                    label="Churn",
                ),
            ],
        )

        display(fig)
    """
    totals = totals or []
    additions = additions or []
    subtractions = subtractions or []

    if len(totals) != 2:
        raise InvalidInputException("2 total series must be provided")

    pdf = _normalize_input_data_to_pandas(df)
    records = pdf.to_dict("records")
    if not records:
        # Without this guard an empty input surfaces as a bare IndexError.
        raise InvalidInputException("Input data must contain at least one row")
    data = records[0]

    opening, closing = totals
    # One (series, measure) pair per bar, in drawing order.
    steps = [
        (opening, "absolute"),
        *((adjustment, "relative") for adjustment in [*additions, *subtractions]),
        (closing, "absolute"),
    ]
    x = [step_series.resolved_label for step_series, _ in steps]
    y = [data[step_series.column_name] for step_series, _ in steps]
    measure = [step_measure for _, step_measure in steps]

    axis_min, axis_max = _get_axis_range(
        data,
        totals=totals,
        additions=additions,
        subtractions=subtractions,
    )

    fig = generate_base_figure(theme=theme)

    # The axis object is updated in place so that tick resolution below uses
    # the data-driven bounds.
    y1_axis.axis_min = axis_min
    y1_axis.axis_max = axis_max
    tick_values = y1_axis.resolved_tick_values(axis_min, axis_max)
    tick_range = y1_axis.resolved_range(tick_values)

    fig.add_trace(
        go.Waterfall(
            orientation="v",
            measure=measure,
            x=x,
            text=[y1_axis.format_value(val) for val in y],
            textposition="auto",
            y=y,
            connector={"line": {"color": ATLAS_TICK_COLOR, "width": 1}},
            increasing=dict(marker=dict(color=resolve_color("light-blue"))),
            decreasing=dict(marker=dict(color=resolve_color("orange"))),
            totals=dict(marker=dict(color=resolve_color("dark-blue"))),
        )
    )

    fig.update_layout(
        yaxis={
            "tickvals": tick_values,
            "range": tick_range,
            "tickformat": y1_axis.resolved_tick_format,
        }
    )
    return fig
def _get_axis_range(
    data: dict,
    totals: Optional[list[Series]] = None,
    additions: Optional[list[Series]] = None,
    subtractions: Optional[list[Series]] = None,
) -> tuple[int | float, int | float]:
    """
    Estimate the lower and upper bounds of the value axis for a waterfall chart.

    The bounds are conservative estimates rather than the exact running extent
    of the chart:

    * the lower bound starts at the smallest of ``0`` and the totals, and is
      lowered by every negative value found in ``subtractions``;
    * the upper bound starts at the largest of ``0`` and the totals, and is
      raised by every positive value found in ``additions``.

    Positive values in ``subtractions`` and negative values in ``additions``
    are ignored.

    :param data: Mapping of column name to value for the single row being plotted.
    :param totals: Series whose columns hold the starting and ending balances.
    :param additions: Series whose columns hold positive adjustments.
    :param subtractions: Series whose columns hold negative adjustments.
    :return: ``(axis_min, axis_max)``. ``axis_min`` is never above ``0`` and
        ``axis_max`` is never below ``0``.
    """
    # ``None`` defaults are treated as "no series" instead of failing on iteration.
    total_values = [data[item.column_name] for item in totals or []]

    # Zero is always part of the range because total bars are anchored at zero.
    axis_min = min([0, *total_values])
    axis_max = max([0, *total_values])

    for item in subtractions or []:
        value = data[item.column_name]
        if value < 0:
            # Adding a negative value moves the lower bound further down
            # (subtracting it would push the bound above zero).
            axis_min += value

    for item in additions or []:
        value = data[item.column_name]
        if value > 0:
            axis_max += value

    return axis_min, axis_max