from typing import Literal, Optional
from datetime import datetime, date
from dataclasses import dataclass, field
from copy import deepcopy
import pandas as pd
from dateutil import parser
from plotly import graph_objects as go
from visualization_toolkit.exceptions import InvalidInputException
from visualization_toolkit.helpers.plotly.charts.core.axis import Axis
from visualization_toolkit.helpers.plotly.colors import resolve_color, hex_to_rgba
from visualization_toolkit.helpers.plotly.theme import (
DEFAULT_COLORS,
SERIES_LINE_WIDTH,
TRANSPARENT,
SHADE_OPACITY,
)
@dataclass
class Series:
"""
A series is used to define a line, bar, area plot on a graph. Each series represents one
column of the dataset. If multiple pivoted "series" exist on a column, then the `category_name` argument can be used
to generate or specify each pivoted series.
Each series will be colored based on the company colors if not otherwise specified. Series
are by default line plots but can be customized via the `mode` attribute.
Series will always be on the x-axis and one of the y-axes, specified by `location`.
:param column_name: The column name for the series y-values based on the input data for the ``chart`` function.
:param category_name: The column name for a categorical value within pivoted chart data to be used for this specific series. By default, this value is ``None``, and should only be used if pivoting is enabled via ``pivot_column_name`` and ``include_all_categories=False``.
:param color: The color of the plot for this series. Default is that the color is automatically determined based on company colors.
:param location: The Y-axis location of the series. Default is the Y1 axis.
:param label: The series legend label. Default is the column_name for the series. Legends are only displayed when multiple series exit on a chart.
:param mode: The type of plot for plotly (ex: lines, bar, area, lines+markers, markers) used for the series. Default is "lines" for a line chart.
:param shape: The shape (ex: dashed, dotted, striped fill) of the plot which is behaves differently based on the ``mode``. Default is a solid series plot.
:param hover_format: Control the data label format when hovering over this series. Default is a standard format based on the corresponding axis' ``axis_type``.
:param is_stacked: Flag to control if pivoted chart data shoud be used to generate multiple series dynamically in the ``chart`` function. (Default is False and should not be used unless pivoting is used in the ``chart`` function)
:param connect_gaps: If `True`, then the series will be plotted for null or missing data. Default is `False`, i.e. series is not plotted for null values.
:param shade_series: If specified, a shaded region will be applied to the series
:param pivot_column_name: Column name of the input ``df`` to pivot, where each category in the pivot column can then be used as a seperate ``series`` of the chart.
:param include_all_categories: If ``True``, all categories will be expanded as separate ``series`` automatically instead of needing to specify ``series`` individually. This is meant to be a convenience flag, but will limit the format customization options available. The default is ``False``.
:param category_sort_column_name: If specified, the column name will be used to aggregate the unique categorie and sort them in the legend. To specify a descending order, prefix with a ``-``, ex: ``category_sort_column_name=-yd_income_bucket``. Default behavior is alphabetical sorting based on category name. This field is only used when ``include_all_categories=True``. When False, the series are ordered by how their position in the chart_series list passed to the ``chart`` function.
:param color_mapping: Optional color mapping of key names to color labels, hexcodes, or rgba values. The keys should be either the ``column_name`` or ``category_name`` of the series.
:param extra_options: Additional options to be passed to the Plotly figure object to format this series.
Examples
^^^^^^^^^^^^^
.. code-block:: python
:caption: Common example of creating multiple series on a chart. Notice that each series corresponds to a different column on the dataset.
from visualization_toolkit.helpers.plotly import chart, axis, series, annotation
fig = chart(
df,
x_axis=axis(
column_name="year",
label="Year",
),
chart_series=[
series(
column_name="australia",
label="Australia",
),
series(
column_name="new_zealand",
label="New Zealand",
),
],
y1_axis=axis(
label="Life Expectancy",
axis_type="number",
),
)
display(fig)
.. code-block:: python
:caption: Alternative example of creating series dynamically using the `category_column_name`. In this case each unique country of the input data will be plotted on a parallel series.
from visualization_toolkit.helpers.plotly import chart, axis, series, annotation
fig = chart(
df,
x_axis=axis(
column_name="year",
label="Year",
),
chart_series=[
series(
column_name="lifeExp",
category_name="New Zealand",
),
series(
column_name="lifeExp",
category_name="Australia",
),
],
y1_axis=axis(
label="Life Expectancy",
axis_type="number",
),
pivot_column_name="country",
)
display(fig)
"""
column_name: str = None
category_name: str = None
color: str = None
color_scale: Literal[
"default", 50, 100, 200, 300, 400, 500, 600, 700, 800, 900
] = "default"
location: Literal["y1", "y2"] = "y1"
label: str = None
mode: Literal[
"lines",
"bar",
"area",
"lines+markers",
"markers",
"clustered_bar",
"line",
"line+marker",
"marker",
"scatter",
] = "lines"
shape: Optional[Literal["dash", "spline", "dot", "stripe"]] = None
hover_format: str = None
y_data: pd.Series = None
is_stacked: bool = False
connect_gaps: bool = False
show_in_legend: bool = True
shade_series: Optional["ShadeSeries"] = None
pivot_column_name: str = None
include_all_categories: bool = False
category_sort_column_name: Optional[str] = None
color_mapping: Optional[dict[str, str]] = None
extra_options: dict = field(default_factory=dict)
def __post_init__(self):
self._validate_shape()
def _validate_shape(self):
if self.shape is None:
return
match self.shape:
case "dash" | "spline" | "dot":
if self.resolved_mode not in ("lines", "lines+markers"):
raise ValueError(
"Line series-specific shapes can only be used with mode=lines"
)
case "stripe":
if self.resolved_mode not in ("bar", "clustered_bar"):
raise ValueError(
"Bar series-specific shapes can only be used with mode=bar"
)
@property
def is_bar_type(self) -> bool:
return self.resolved_mode in ("bar", "clustered_bar")
@property
def resolved_mode(
self,
) -> Literal["lines", "bar", "area", "lines+markers", "markers", "clustered_bar"]:
# Normalize mode input in case of common typos
match self.mode:
case "lines" | "line":
return "lines"
case "lines+markers" | "line+marker" | "lines+marker" | "line+markers":
return "lines+markers"
case "markers" | "marker" | "scatter":
return "markers"
case "bar":
return "bar"
case "area":
return "area"
case "clustered_bar":
return "clustered_bar"
case _:
raise InvalidInputException(f"Unexpected mode provided: {self.mode}")
def resolved_color(self, series_idx: int) -> str:
if self.color is not None:
return resolve_color(self.color)
if self.color_mapping is not None:
color = self.color_mapping.get(
self.category_name
) or self.color_mapping.get(self.column_name)
if color is not None:
return resolve_color(color)
# Loop through default colors based on series index
return DEFAULT_COLORS[series_idx % len(DEFAULT_COLORS)]
@property
def resolved_label(self) -> str:
return self.label or self.column_name
def resolved_hover_format(self, y_axis: Axis) -> str:
# Plotly does tick formatting based on D3 format
# docs: https://d3js.org/d3-format
if self.hover_format:
return self.hover_format
match y_axis.axis_type:
case "percent":
return f".1%"
case "number":
return f",.0f"
case "currency":
return f"{y_axis.currency_symbol},.0f"
case "date":
return "%m/%d/%y"
def resolved_x_data(self, pdf: pd.DataFrame, x_column_name: str) -> pd.Series:
if self.y_data is not None:
return pd.Series(self.y_data.index)
return pdf[x_column_name]
def resolved_y_data(self, pdf: pd.DataFrame, column_name: str = None) -> pd.Series:
if self.y_data is not None:
return self.y_data
return pdf[column_name or self.column_name]
def resolved_trace(
self,
chart_data: pd.DataFrame,
x_axis: Axis,
y_axis: Axis,
series_position: int = 0,
column_name: str = None,
extra_options: dict = None,
) -> go.Scatter:
extra_options = deepcopy(self.extra_options | (extra_options or {}))
y_data = self.resolved_y_data(chart_data, column_name=column_name)
template = f"%{{y:{self.resolved_hover_format(y_axis)}}}"
extra_options |= {
"hovertemplate": extra_options.get("hovertemplate", template),
"showlegend": extra_options.get("showlegend", self.show_in_legend),
}
if self.connect_gaps:
extra_options["connectgaps"] = True
# Add chart-type-specific styling
# ex: https://plotly.com/python/line-charts/#style-line-plots
match self.shape:
case "dash" | "dot":
extra_options["line"] = {"dash": self.shape}
case "spline":
extra_options["line_shape"] = "spline"
match self.resolved_mode:
case "bar" | "clustered_bar":
x_data = self.resolved_x_data(chart_data, x_axis.column_name)
graph_class = go.Bar
graph_options = {
"x": x_data,
"y": y_data,
"yaxis": y_axis.location,
"marker_color": (
# For clustered_bar charts default to
# having each x-value ("cluster") use a different color
[self.resolved_color(idx) for idx in range(len(x_data))]
if self.resolved_mode == "clustered_bar"
# otherwise assign a color to each bar series in the chart
else self.resolved_color(series_position)
),
"name": self.resolved_label,
} | extra_options
case "area":
graph_class = go.Scatter
graph_options = {
"x": self.resolved_x_data(chart_data, x_axis.column_name),
"y": y_data,
"yaxis": y_axis.location,
"line_color": self.resolved_color(series_position),
"fillcolor": self.resolved_color(series_position),
"stackgroup": y_axis.location,
"name": self.resolved_label,
"line": {"width": SERIES_LINE_WIDTH},
} | extra_options
case _:
graph_class = go.Scatter
graph_options = {
"x": self.resolved_x_data(chart_data, x_axis.column_name),
"y": y_data,
"yaxis": y_axis.location,
"mode": self.resolved_mode,
"line_color": self.resolved_color(series_position),
"name": self.resolved_label,
"line": {"width": SERIES_LINE_WIDTH},
} | extra_options
trace = graph_class(**graph_options)
# Clustered bars should also store actual x-values in customdata
# Create a list of the same parsed date value for each x point
if self.resolved_mode == "clustered_bar":
try:
parsed_value = parser.parse(trace.name)
customdata = pd.DataFrame([parsed_value] * len(trace.x))
except TypeError:
# For non-date types, use the string representation
customdata = pd.DataFrame([str(trace.name)] * len(trace.x))
trace.update(customdata=customdata)
return trace
def get_categories(self, pdf: pd.DataFrame) -> list[str]:
sort_column_name = self.category_sort_column_name or self.pivot_column_name
ascending = True
if sort_column_name[0] == "-":
ascending = False
sort_column_name = sort_column_name[1:]
return (
pdf.groupby(self.pivot_column_name)[sort_column_name]
.min()
.sort_values(ascending=ascending)
.index.tolist()
)
@dataclass
class ShadeSeries:
    """
    Adds shaded area for an existing ``series`` instance for a ``chart``.
    The shaded area is based on two columns that need to be present in the input dataframe
    and are specified via the ``boundary_column_names`` argument as a tuple.

    :param boundary_column_names: Tuple of two columns that indicate the shaded boundary for the series. Should be in the format of ``(lower column, upper column)``.
    :param label: Label for the shaded range that should be used in the legend and while hovering. When hovering, the label will be suffixed with ``Lower Bound`` and ``Upper Bound`` to indicate the range of the shaded area.
    :param color: Optional fill color of the shaded region. Defaults to parent series color.
    :param opacity: Optional opacity of the shaded region fill color. Defaults to standard theme opacity.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Example of adding a shaded area around a line chart. The columns ``yy``, ``lower_bound`` and ``upper_bound`` must all exist on the input dataframe.

        from visualization_toolkit.helpers.plotly import chart, axis, series, shade_series

        fig = chart(
            df,
            x_axis=axis(column_name="fiscal_qy", label="Fiscal Quarter"),
            y1_axis=axis(label="Downloads Growth Rate", axis_type="percent"),
            chart_series=[
                series(
                    column_name="yy",
                    label="Y/Y Growth",
                    color="dark-blue",
                    shade_series=shade_series(
                        boundary_column_names=("lower_bound", "upper_bound"),
                        label="Margin of Error",
                        color="light-blue",
                        opacity=.3,
                    ),
                ),
            ],
        )
        display(fig)
    """

    boundary_column_names: tuple[str, str]
    label: Optional[str] = None
    color: Optional[str] = None
    opacity: float = SHADE_OPACITY

    def add_to_figure(
        self,
        fig: go.Figure,
        pdf: pd.DataFrame,
        x_axis: "Axis",
        y_axis: "Axis",
        base_series: "Series",
        base_series_idx: int = 0,
    ) -> None:
        """
        Add the traces of the shaded region to ``fig``.

        Three traces are added in order: the lower boundary, the upper boundary
        (filled down to the lower boundary) and an empty placeholder that provides
        the legend entry.

        :param fig: Figure the traces are added to.
        :param pdf: Input data containing both boundary columns.
        :param x_axis: The x-axis of the chart.
        :param y_axis: The y-axis the base series is plotted on.
        :param base_series: Series the shaded region belongs to; its ``resolved_trace`` generates every trace.
        :param base_series_idx: Position of the base series in the chart, used to resolve its default color.
        """
        # Base color off of base series if not specified
        if self.color is not None:
            fill_color = resolve_color(self.color)
        else:
            fill_color = base_series.resolved_color(base_series_idx)

        # If only a hex color is used, apply opacity to create a shaded area effect
        if fill_color.startswith("#"):
            fill_color = hex_to_rgba(fill_color, self.opacity)

        shade_label = self.resolved_label(base_series)
        lower_column_name = self.boundary_column_names[0]
        upper_column_name = self.boundary_column_names[1]
        # Every trace is generated from the base series with the same positional
        # arguments to inherit most formatting settings
        trace_args = (pdf, x_axis, y_axis, base_series_idx)

        # Use 2 line traces to generate the shaded effect in Plotly
        # 1 trace is for the lower boundary and 1 trace is for the upper boundary
        # make line color transparent to remove borders of the shaded area
        # The lower bound entry will be shown in the hoverinfo but not the legend
        lower_bound_trace = base_series.resolved_trace(
            *trace_args,
            column_name=lower_column_name,
            extra_options={
                "name": f"{shade_label} Lower Bound",
                "line_color": TRANSPARENT,
                "showlegend": False,
            },
        )
        fig.add_trace(lower_bound_trace)

        # Add an upper bound series to create the shaded effect from the lower bound
        # to the upper bound. This will not be shown on the legend but indicated in the hoverinfo
        upper_bound_trace = base_series.resolved_trace(
            *trace_args,
            column_name=upper_column_name,
            extra_options={
                "name": f"{shade_label} Upper Bound",
                "showlegend": False,
                "line_color": TRANSPARENT,
                "fill": "tonexty",
                "fillcolor": fill_color,
            },
        )
        fig.add_trace(upper_bound_trace)

        # Add a custom plot with no data as a placeholder
        # So there is only 1 legend entry for the shaded bounds using the provided label
        legend_placeholder_trace = base_series.resolved_trace(
            *trace_args,
            column_name=upper_column_name,
            extra_options={
                "name": f"{shade_label}",
                "line_color": TRANSPARENT,
                "fill": "tonexty",
                "fillcolor": fill_color,
                "hoverinfo": "skip",
            },
        )
        # Blank out the data so the placeholder only contributes its legend entry
        legend_placeholder_trace.x = [None]
        legend_placeholder_trace.y = [None]
        fig.add_trace(legend_placeholder_trace)

    def resolved_label(self, base_series: "Series") -> str:
        """Return ``label`` when set, otherwise the resolved label of ``base_series``."""
        if self.label is not None:
            return self.label
        return base_series.resolved_label
# Lowercase alias so that a shaded region can be declared as ``shade_series(...)``
shade_series = ShadeSeries