# Source code for visualization_toolkit.helpers.plotly.charts.heatmap
from typing import Any
import pandas as pd
from pyspark.sql import DataFrame
from plotly import graph_objects as go
from visualization_toolkit.helpers.plotly.charts.core.chart import (
    _normalize_input_data_to_pandas,
)
from visualization_toolkit.helpers.plotly.charts.core.figure import generate_base_figure
from visualization_toolkit.helpers.plotly.theme import (
    YD_CLASSIC_THEME,
    POSITIVE_COLORSCALE,
)
from visualization_toolkit.helpers.plotly.charts.core.axis import Axis
# [docs]
def heatmap_chart(
    df: DataFrame | pd.DataFrame | list[dict[str, Any]],
    category_column_name: str,
    z_axis: Axis,
    theme: dict = YD_CLASSIC_THEME,
    colorscale: list[list[float | str]] = POSITIVE_COLORSCALE,
):
    """
    Create a heatmap visualization to compare two sets of categories in a table-like view.

    Cells in the heatmap are colored based on their value with respect to the ``colorscale`` provided,
    and cell text is formatted with the ``z_axis`` provided. ``category_column_name`` names the column
    holding one set of categories (plotted on the y axis); every other column is treated as a category
    of the opposite set (plotted on the x axis). Each (row category, column) pair becomes one cell.

    The input data is not modified.

    :param df: Input data to visualize. Exactly one column (``category_column_name``) may hold
        categorical values; all other columns must be numeric.
    :param category_column_name: Name of the categorical column in ``df`` to plot on the y axis of the heatmap.
    :param z_axis: An ``Axis`` object that defines how values within the heatmap (cell text and
        colorbar ticks) are formatted.
    :param theme: Optionally changes the theme used to format the chart. Defaults to ``YD_CLASSIC_THEME``.
    :param colorscale: Optionally changes the colorscale used to color heatmap cells.
        Defaults to ``POSITIVE_COLORSCALE`` (various shades of blue).
    :return: The figure produced by ``generate_base_figure`` with the heatmap trace added.

    Examples
    ^^^^^^^^^^^^^
    .. code-block:: python
        :caption: Example of creating a heatmap chart. Notice the ``z_axis`` is used to format cells in the heatmap with the ``axis_type`` argument.

        from visualization_toolkit.helpers.plotly import heatmap_chart, axis

        # pdf has a "fiscal_qy" column plus one numeric column per compared category
        fig = heatmap_chart(
            pdf,
            category_column_name="fiscal_qy",
            z_axis=axis(label="Downloads", axis_type="number"),
        )
        display(fig)
    """
    fig = generate_base_figure(theme=theme)
    # Non-mutating set_index: _normalize_input_data_to_pandas may hand back the
    # caller's own pandas DataFrame (TODO confirm), and an in-place set_index would
    # move the category column into the index of the caller's frame, making a
    # second call with the same frame fail with a KeyError.
    pdf = _normalize_input_data_to_pandas(df).set_index(category_column_name)
    # Cell matrix: one row per category in ``category_column_name``, one column
    # per remaining input column. Computed once and reused for values and labels.
    z_values = pdf.values.tolist()
    # Pre-format every cell so the displayed text follows the z_axis formatting.
    labels = [[z_axis.format_value(val) for val in row] for row in z_values]
    fig.add_trace(
        go.Heatmap(
            x=pdf.columns.tolist(),
            y=pdf.index.tolist(),
            z=z_values,
            text=labels,
            texttemplate="%{text}",
            colorbar_tickformat=z_axis.resolved_tick_format,
            colorscale=colorscale,
        )
    )
    return fig