# Source code for visualization_toolkit.helpers.plotly.charts.core.chart

import os
import string
from typing import Literal, Optional, Any
from pathlib import Path
import base64
from datetime import date, datetime

import pandas as pd
from plotly import graph_objects as go
from pyspark.sql import DataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

from visualization_toolkit.exceptions import InvalidInputException
from visualization_toolkit.helpers.plotly.charts.core.axis import Axis
from visualization_toolkit.helpers.plotly.charts.core.series import Series, series
from visualization_toolkit.helpers.plotly.charts.core.annotation import (
    Annotation,
)
from visualization_toolkit.helpers.plotly.charts.core.shading import (
    ShadeX,
    ShadeY,
)
from visualization_toolkit.helpers.plotly.colors import resolve_color
from visualization_toolkit.helpers.plotly.theme import (
    HOVER_LABEL_BACKGROUND_COLOR,
    TRANSPARENT,
    LEGEND_Y_POSITION,
    CHART_LOGO_X_POSITION,
    CHART_LOGO_MEDIUM_X_POSITION,
    CHART_LOGO_X_SIZE,
    CHART_LOGO_Y_SIZE,
    WHITE,
    BACKGROUND_SECONDARY,
    COLORS,
)
from visualization_toolkit.helpers.plotly.charts.core.figure import generate_base_figure


# Directory five levels above this module file. Given the module path
# visualization_toolkit/helpers/plotly/charts/core/chart.py, that is the
# ``visualization_toolkit`` package directory; static assets (e.g. the chart
# logo) are resolved relative to it.
BASE_DIR = (
    Path(__file__)
    .resolve()
    .parent  # core/
    .parent  # charts/
    .parent  # plotly/
    .parent  # helpers/
    .parent  # visualization_toolkit/
)


def chart(
    df: DataFrame | pd.DataFrame | list[dict[str, Any]],
    chart_series: list[Series],
    x_axis: Axis,
    y1_axis: Optional[Axis] = None,
    y2_axis: Optional[Axis] = None,
    annotations: Optional[list[Annotation]] = None,
    shaded_regions: Optional[list[ShadeX | ShadeY]] = None,
    include_logo: bool = False,
    theme: Optional[dict] = None,
    custom_options: Optional[dict] = None,
) -> go.Figure:
    """
    Standard function to generate a plotly-based chart with standard company styling.

    The ``chart`` function works closely with the other toolkit building blocks for
    charts: ``axis``, ``series``, and ``annotations``.

    The input data to the chart can be passed in as a spark dataframe, pandas
    dataframe, or a list of dictionaries. The input data is normalized to pandas
    before being supplied to plotly to render the chart.

    Note that the axis objects passed in are modified in place: their ``location``
    attribute is set to ``"x"``, ``"y1"`` and (when given) ``"y2"``.

    :param df: Input data to visualize. The columns from this data will be used for
        the ``chart_series`` and various axes.
    :param chart_series: A list of ``series`` instances that will be plotted on this
        chart figure.
    :param x_axis: The axis configuration for the x-axis of this chart figure.
    :param y1_axis: The axis configuration for the y1-axis of this chart figure.
        This axis is always configured, so it must be supplied.
    :param y2_axis: The axis configuration for the y2-axis of this chart figure.
        By default, no y2 axis is included.
    :param annotations: A list of ``annotations`` to include on the chart.
    :param shaded_regions: A list of ``shade_x`` and/or ``shade_y`` to include on the
        chart to shade regions.
    :param include_logo: Optionally include the YipitData logo at the bottom right of
        the chart. Default is False (no logo added).
    :param theme: Optionally changes the theme used to format the chart. Defaults to
        the ``YD_CLASSIC_THEME`` if not set.
    :param custom_options: Optionally include any options to pass into
        figure.update_layout. This is an escape hatch for final adjustments to a chart.
    :raises InvalidInputException: If a required axis is missing or has the wrong
        type, a series has an invalid ``location``, more than one series pivots with
        ``include_all_categories``, or pivoted series use different pivot columns.
    :return: Plotly ``figure`` object that can be displayed in databricks or supplied
        to a dash app as a property of the ``dcc.Graph`` component.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Common example of creating multiple series on a chart. Notice that
            each series corresponds to a different column on the dataset.

        from visualization_toolkit.helpers.plotly import chart, axis, series, annotation

        fig = chart(
            df,
            x_axis=axis(
                column_name="year",
                label="Year",
            ),
            chart_series=[
                series(
                    column_name="australia",
                    label="Australia",
                ),
                series(
                    column_name="new_zealand",
                    label="New Zealand",
                ),
            ],
            y1_axis=axis(
                label="Life Expectancy",
                axis_type="number",
            ),
        )
        display(fig)
    """
    # validate inputs
    _validate_axes(
        chart_series=chart_series,
        x_axis=x_axis,
        y1_axis=y1_axis,
        y2_axis=y2_axis,
    )
    # The y1 axis is configured unconditionally below; fail with a clear message
    # instead of an AttributeError on ``None``.
    if y1_axis is None:
        raise InvalidInputException("y1_axis must be supplied")

    # initialize blank figure object
    if theme is not None:
        fig = generate_base_figure(theme=theme)
    else:
        fig = generate_base_figure()

    annotations = annotations or []
    shaded_regions = shaded_regions or []

    # standardize data in pandas and truncate data on the x-axis if bounds are given
    pdf = _normalize_input_data_to_pandas(df)
    pdf = filter_pdf_for_x_axis_bounds(pdf, x_axis)

    # organize series for axes
    x_axis.location = "x"
    x_series_data = _sorted_x_data(pdf, x_axis)
    x_configuration = {
        "data_min": x_series_data.min(),
        "data_max": x_series_data.max(),
        "series": x_series_data,
    }

    y1_axis.location = "y1"
    y1_configuration = {
        "data_min": None,
        "data_max": None,
        "series": [],
        "axis": y1_axis,
    }

    if y2_axis:
        y2_axis.location = "y2"
    y2_configuration = {
        "data_min": None,
        "data_max": None,
        "series": [],
        "axis": y2_axis,
    }

    # pivot data if needed for chart (None when no series is pivoted)
    pivoted_pdf = _pivot_chart_data(pdf, x_axis, chart_series)

    pdf, pivoted_pdf, x_configuration = filter_for_null_data_points(
        pdf,
        x_axis,
        chart_series,
        x_configuration=x_configuration,
        pivoted_pdf=pivoted_pdf,
    )

    y1_min, y1_max, y2_min, y2_max = generate_y_axes_bounds(
        pdf,
        chart_series,
        x_axis,
        pivoted_pdf=pivoted_pdf,
    )
    y1_configuration["data_min"] = y1_min
    y1_configuration["data_max"] = y1_max
    y2_configuration["data_min"] = y2_min
    y2_configuration["data_max"] = y2_max

    # Expand pivoted series into one concrete series per plotted category and
    # route each to the axis it belongs to.
    for chart_series_item in chart_series:
        series_to_add = _expand_pivoted_series(chart_series_item, pdf, pivoted_pdf)
        if chart_series_item.location == "y1":
            y1_configuration["series"].extend(series_to_add)
        elif chart_series_item.location == "y2":
            y2_configuration["series"].extend(series_to_add)

    # Add each y1 series, then each y2 series, to the figure object
    _add_axis_traces(fig, pdf, x_axis, y1_axis, y1_configuration["series"])
    _add_axis_traces(fig, pdf, x_axis, y2_axis, y2_configuration["series"])

    # Configure x-axis
    general_options = {}
    apply_offset = True
    for s in chart_series:
        if s.is_bar_type and s.is_stacked:
            # Use stacked bar charts and ensure axis is scaled correctly
            general_options = {"barmode": "relative"}
        elif s.mode == "area":
            apply_offset = False

    if x_axis.axis_type == "date":
        x_axis_options = x_axis.date_axis_configuration(
            x_configuration["series"],
            apply_offset=apply_offset,
        )
    elif any(s.is_bar_type for s in chart_series):
        x_axis_options = x_axis.category_axis_configuration(x_configuration["series"])
    else:
        x_axis_options = x_axis.axis_configuration(
            x_configuration["data_min"], x_configuration["data_max"]
        )

    # configure y1-axis
    y1_axis_options = y1_axis.axis_configuration(
        y1_configuration["data_min"], y1_configuration["data_max"]
    )

    # configure y2-axis only when at least one series is plotted on it
    if y2_configuration["series"]:
        y2_axis_options = y2_axis.axis_configuration(
            y2_configuration["data_min"], y2_configuration["data_max"]
        )
    else:
        y2_axis_options = {}

    # Add annotations
    for annotation in annotations:
        y_options = y1_axis_options["yaxis"]
        if annotation.axis_location == "x" and "range" in y_options:
            y_range = y_options["range"]
        elif annotation.axis_location == "x":
            y_range = y_options["tickvals"][0], y_options["tickvals"][-1]
        else:
            y_range = None
        annotation.register_in_figure(
            fig,
            # pass in the y-range bounds if the annotation is on the x-axis,
            # this is to avoid a plotly bug with add_vlines
            # https://github.com/plotly/plotly.py/issues/3065
            y_range=y_range,
        )

    # Add shaded regions
    for region in shaded_regions:
        region.add_to_figure(fig)

    # finalize figure layout
    fig.update_layout(
        legend_title_text="",
        legend_y=LEGEND_Y_POSITION,
        # Controlling x-axis
        **x_axis_options,
        **y1_axis_options,
        **y2_axis_options,
        **general_options,
        # To customize hover label format
        hovermode="x unified",
        hoverlabel_bgcolor=WHITE,
        hoverlabel_bordercolor=BACKGROUND_SECONDARY,
        xaxis_spikecolor=COLORS["dark-grey"][700],
        xaxis_spikethickness=1,
    )

    # Clustered bar charts need additional customization
    if any(s.resolved_mode == "clustered_bar" for s in chart_series):
        fig.update_layout(
            hovermode="closest",
            xaxis_tickangle=0,
            showlegend=False,
        )
        fig.update_xaxes(
            showspikes=False, tickfont={"color": resolve_color("dark-grey")}
        )
        fig.update_yaxes(showspikes=False)

    if include_logo:
        _add_yipitdata_logo_to_chart(
            fig,
            y1_axis=y1_axis,
            y2_axis=y2_axis,
        )

    if custom_options is not None:
        fig.update_layout(**custom_options)

    return fig


def _pivot_chart_data(
    pdf: pd.DataFrame, x_axis: Axis, chart_series: list[Series]
) -> Optional[pd.DataFrame]:
    """Pivot ``pdf`` for the pivoted series in ``chart_series``; None if there are none."""
    pivoted_series = [s for s in chart_series if s.pivot_column_name is not None]
    if not pivoted_series:
        return None

    if sum(1 for s in pivoted_series if s.include_all_categories) > 1:
        raise InvalidInputException(
            "Only one series can be pivoted per chart. "
            "Either pivot the input data ahead of time or adjust the series passed to the chart function."
        )

    # The data is pivoted on a single column, so every pivoted series must agree on
    # it; otherwise category lookups on the pivoted frame would fail or be wrong.
    pivot_column_names = {s.pivot_column_name for s in pivoted_series}
    if len(pivot_column_names) > 1:
        raise InvalidInputException(
            "All pivoted series on a chart must use the same pivot column, "
            f"received: {sorted(str(name) for name in pivot_column_names)}"
        )

    return pivot_pdf(
        pdf,
        x_axis,
        pivoted_series,
        pivoted_series[0].pivot_column_name,
    )


def _category_series(
    base_series: Series, category: Any, label: str, y_data: pd.Series
) -> Series:
    """Build a series plotting one pivoted ``category`` with ``base_series``' styling."""
    return series(
        column_name=category,
        location=base_series.location,
        label=label,
        mode=base_series.mode,
        y_data=y_data,
        shape=base_series.shape,
        hover_format=base_series.hover_format,
        color=base_series.color,
        extra_options=base_series.extra_options,
    )


def _expand_pivoted_series(
    chart_series_item: Series,
    pdf: pd.DataFrame,
    pivoted_pdf: Optional[pd.DataFrame],
) -> list[Series]:
    """Return the concrete series to plot for ``chart_series_item``.

    Non-pivoted series are returned unchanged. A series pivoted with
    ``include_all_categories`` expands to one series per category; otherwise a
    pivoted series is replaced by a single series for its ``category_name``.
    """
    if chart_series_item.pivot_column_name is None:
        return [chart_series_item]

    value_column = chart_series_item.column_name
    if chart_series_item.include_all_categories:
        return [
            _category_series(
                chart_series_item,
                category,
                label=_resolve_label(None, category, value_column),
                y_data=pivoted_pdf[value_column][category],
            )
            for category in chart_series_item.get_categories(pdf)
        ]

    category = chart_series_item.category_name
    return [
        _category_series(
            chart_series_item,
            category,
            label=_resolve_label(chart_series_item.label, category, value_column),
            y_data=pivoted_pdf[value_column][category],
        )
    ]


def _add_axis_traces(
    fig: go.Figure,
    pdf: pd.DataFrame,
    x_axis: Axis,
    y_axis: Optional[Axis],
    axis_series: list[Series],
) -> None:
    """Add each series (and its shading, if any) in ``axis_series`` to ``fig`` against ``y_axis``."""
    for idx, chart_series_item in enumerate(axis_series):
        # Shading is added before its base series' trace
        if chart_series_item.shade_series is not None:
            chart_series_item.shade_series.add_to_figure(
                fig=fig,
                pdf=pdf,
                x_axis=x_axis,
                y_axis=y_axis,
                base_series=chart_series_item,
                base_series_idx=idx,
            )
        fig.add_trace(chart_series_item.resolved_trace(pdf, x_axis, y_axis, idx))
def _normalize_input_data_to_pandas( df: DataFrame | pd.DataFrame | list[dict[str, Any]] ) -> pd.DataFrame: if isinstance(df, (DataFrame, ConnectDataFrame)): return df.toPandas() elif isinstance(df, pd.DataFrame): return df elif isinstance(df, list): return pd.DataFrame(df) raise ValueError("Invalid type for dataframe") def _resolve_label( chart_series_label: str | None, category_name: str | None, column_name: str ) -> str: # If a specific title is passed in, use that if chart_series_label: return chart_series_label elif isinstance(category_name, (date, datetime)): return category_name.isoformat() # For acronyms, return as-is, ex: UCAN elif category_name.upper() == category_name: return category_name # Otherwise titleize the category or column name elif category_name: return string.capwords(category_name.replace("_", " ")) return string.capwords(column_name.replace("_", " ")) def _add_yipitdata_logo_to_chart( fig: go.Figure, y1_axis: Axis = None, y2_axis: Axis = None, ): # Determine x positioning based on a % of plot area # simple factor in determining plot area size is the number of y-axis titles # more titles, smaller plot area, requires higher % for x-position so logo is placed on edge of plot title_count = 0 if y1_axis is not None and y1_axis.label is not None: title_count += 1 if y2_axis is not None and y2_axis.label is not None: title_count += 1 match title_count: case 0 | 1: x = CHART_LOGO_X_POSITION case 2: x = CHART_LOGO_MEDIUM_X_POSITION # Load the logo image from the library and attach to the figure as base64 data with open( os.path.join(BASE_DIR, "static/logos/yipitdata-logo-dark.png"), "rb" ) as f: _input = f.read() binary_data = base64.b64encode(_input) encoded = binary_data.decode("utf-8") # Place logo on the right edge of x-axis and aligned with legend on the bottom fig.update_layout( images=[ dict( source=f"data:image/png;base64,{encoded}", xref="paper", yref="paper", x=x, y=fig._layout["legend"]["y"], sizex=CHART_LOGO_X_SIZE, 
sizey=CHART_LOGO_Y_SIZE, layer="below", xanchor="right", yanchor="bottom", ) ], ) def _validate_axes( chart_series: list[Series], x_axis: Axis = None, y1_axis: Axis = None, y2_axis: Axis = None, ): if x_axis is None: raise InvalidInputException("x_axis must be supplied") if not isinstance(x_axis, Axis): raise InvalidInputException( f"Invalid type, x_axis must be an ``axis`` instance, received: {type(x_axis)}" ) for _series in chart_series: match _series.location: case "y1": if y1_axis is None: raise InvalidInputException("y1_axis must be supplied") if not isinstance(y1_axis, Axis): raise InvalidInputException( "Invalid type, y1_axis must be an ``axis`` instance" ) case "y2": if y2_axis is None: raise InvalidInputException("y2_axis must be supplied") if not isinstance(y2_axis, Axis): raise InvalidInputException( "Invalid type, y2_axis must be an ``axis`` instance" ) case _: raise InvalidInputException( "Invalid `location` for series, must be either 'y1' or 'y2'" ) def generate_y_axes_bounds( pdf: pd.DataFrame, chart_series: list[Series], x_axis: Axis, pivoted_pdf: pd.DataFrame = None, ) -> (Optional[float], Optional[float], Optional[float], Optional[float]): """ Returns the upper and lower bounds for y1 and y2 axes of the given series data. If an axis is not used by any series, the bounds will be None. These bounds are used to position the axes min and max with more human-readable values so that all plotted series are within the graph's frame. 
:param pdf: :param chart_series: :param x_axis: :param include_all_categories: :param pivoted_pdf: :param pivot_column_name: :return: """ y1_min, y1_max = None, None y2_min, y2_max = None, None for chart_series_item in chart_series: # Account for shaded region boundary columns if part of the series pivot_column_name = chart_series_item.pivot_column_name include_all_categories = chart_series_item.include_all_categories if chart_series_item.shade_series is not None: columns_with_boundaries = [ chart_series_item.column_name, *chart_series_item.shade_series.boundary_column_names, ] else: columns_with_boundaries = [chart_series_item.column_name] if chart_series_item.is_stacked: # If this is a stacked series that needs to be pivoted, take the sum for each x-axis period across all series # and use the min / max of those period totals to generate bounds for the series on the y-axis if include_all_categories and pivot_column_name is not None: grouped_series = pdf.groupby(x_axis.column_name)[ columns_with_boundaries ].sum() elif pivot_column_name is not None: # If a stacked series that needs to be pivoted with some categories included, filter for the specified categories, # and then take the sum for each x-axis period to generate bounds for the series on y-axis grouped_series = ( pdf[ pdf[pivot_column_name].isin( [ s.category_name for s in chart_series if s.location == chart_series_item.location ] ) ] .groupby(x_axis.column_name)[columns_with_boundaries] .sum() ) else: # Otherwise for a stacked set of series, calculate the totals along the x-axis # for the series columns specified and use those totals to set the bound pdf["__stacked_total__"] = sum( [ pdf[s.column_name] for s in chart_series if s.location == chart_series_item.location ] ) grouped_series = pdf.groupby(x_axis.column_name)[ "__stacked_total__" ].sum() # Newer versions of Pandas do not support converting a single-item series to float # so need to normalize to a numerical value by indexing into the series if 
possible # Older versions return a numpy.int64 value so that can be casted directly agg_min = grouped_series.min() if hasattr(agg_min, "iloc"): series_min = float(agg_min.iloc[0]) else: series_min = float(agg_min) agg_max = grouped_series.max() if hasattr(agg_max, "iloc"): series_max = float(agg_max.iloc[0]) else: series_max = float(agg_max) # If working with a pivoted dataset and not stacked # use the category min and max value to generate bounds for the series on the y-axis elif pivoted_pdf is not None and chart_series_item.category_name is not None: grouped_series = pivoted_pdf[chart_series_item.column_name][ chart_series_item.category_name ] series_min = float(grouped_series.min()) series_max = float(grouped_series.max()) elif len(columns_with_boundaries) > 1: series_min = float(pdf[columns_with_boundaries].min().min()) series_max = float(pdf[columns_with_boundaries].max().max()) else: series_min = float(pdf[chart_series_item.column_name].min()) series_max = float(pdf[chart_series_item.column_name].max()) # Update global min/max bounds if it exceeds existing bounds from prior series if chart_series_item.location == "y1": if y1_min is None or series_min < y1_min: y1_min = series_min if y1_max is None or series_max > y1_max: y1_max = series_max elif chart_series_item.location == "y2": if y2_min is None or series_min < y2_min: y2_min = series_min if y2_max is None or series_max > y2_max: y2_max = series_max return y1_min, y1_max, y2_min, y2_max def filter_pdf_for_x_axis_bounds( pdf: pd.DataFrame, x_axis: Axis, ) -> pd.DataFrame: # If axis is a date type, handle datetime casting before filtering # Only apply filters on axis min and max if specified if x_axis.axis_min is not None and x_axis.axis_type == "date": pdf[x_axis.column_name] = pd.to_datetime(pdf[x_axis.column_name]) pdf = pdf[pdf[x_axis.column_name].dt.date >= x_axis.axis_min] elif x_axis.axis_min is not None: pdf = pdf[pdf[x_axis.column_name] >= x_axis.axis_min] if x_axis.axis_max is not None and 
x_axis.axis_type == "date": pdf[x_axis.column_name] = pd.to_datetime(pdf[x_axis.column_name]) pdf = pdf[pdf[x_axis.column_name].dt.date <= x_axis.axis_max] elif x_axis.axis_max is not None: pdf = pdf[pdf[x_axis.column_name] <= x_axis.axis_max] return pdf def pivot_pdf( pdf: pd.DataFrame, x_axis: Axis, chart_series: list[Series], pivot_column_name: str, ) -> pd.DataFrame: pivoted_pdf = pdf.pivot( index=x_axis.column_name, columns=pivot_column_name, values=list( set([chart_series_item.column_name for chart_series_item in chart_series]) ), ) return pivoted_pdf def filter_for_null_data_points( pdf: pd.DataFrame, x_axis: Axis, chart_series: list[Series], x_configuration: dict, pivoted_pdf: pd.DataFrame = None, ): # Truncate null data points along the x-axis if this flag is specified # Get the first data point with any non-null values for the selected columns of the dataframe # and filter the data to start at that point onwards possible_start_points = [] if x_axis.start_at_non_null_values: for chart_series_item in chart_series: # If pivoting and using all categories, then get the first point with any non-null value # across all categories, should_pivot = chart_series_item.pivot_column_name is not None if should_pivot and chart_series_item.include_all_categories: any_column_not_null = ~pivoted_pdf.isnull().all(axis=1) first_non_null_data_point = pivoted_pdf[any_column_not_null].index.min() # If pivoting with some categories, # then get the first point with any non-null data for those categories specifically elif should_pivot: any_column_not_null = ( ~pivoted_pdf[ [ (s.column_name, s.category_name) for s in chart_series if s.category_name is not None ] ] .isnull() .all(axis=1) ) first_non_null_data_point = pivoted_pdf[any_column_not_null].index.min() # otherwise perform filter across all selected columns on the base input dataframe else: any_column_not_null = ( ~pdf[[s.column_name for s in chart_series]].isnull().all(axis=1) ) first_non_null_data_point = 
pdf[any_column_not_null][ x_axis.column_name ].min() possible_start_points.append(first_non_null_data_point) # modify x-configuration to update chart bounds # taking the lowest value of all possible_start_points first_non_null_data_point = min(possible_start_points) pdf = pdf[pdf[x_axis.column_name] >= first_non_null_data_point] if pivoted_pdf is not None: pivoted_pdf = pivoted_pdf[pivoted_pdf.index >= first_non_null_data_point] x_data = _sorted_x_data(pdf, x_axis) x_configuration |= { "data_min": x_data.min(), "series": x_data, } return pdf, pivoted_pdf, x_configuration def _sorted_x_data(pdf: pd.DataFrame, x_axis: Axis) -> pd.Series: return pdf[x_axis.column_name].sort_values(ascending=True)