from typing import Literal, Optional
from datetime import date, datetime
from dataclasses import dataclass, field
import math
import numpy
from decimal import Decimal
from dateutil.relativedelta import relativedelta
import pandas as pd
# Number of decimal places numeric axis bounds are rounded to before tick
# values are computed, to keep floating-point noise out of the tick arithmetic.
DECIMAL_PRECISION = 6
@dataclass
class Axis:
"""
Control axis behavior by using this function. Options to control the ticks, title, tick format, and overall range is possible.
Define an axis for each relavant axis of the chart (ex: 1 x-axis and y-axis is 2 axis function calls).
Axis will attempt to look as good as possible with minimum customization needed. The function will identify
the minimum and maximum bounds of the input data to ensure all data is inside the plotted frame.
In addition, the axis will try to generate as even as possible tick values given the dataset. If specific tick steps
are preferred specify the ``tick_interval`` and ``axis_min``.
:param column_name: The column of the input data of a `chart` function call to use.
:param location: Controls whether this axis is the x, y1, or y2 axis of the chart. This is not needed to be specified, as the ``chart`` function will set this value automatically.
:param label: Set an axis title for the chart. Default is no title is added.
:param axis_type: The numerical type of the data on this axis. It is important to specify this value, as it will control the default axis behavior and formatting.
:param tick_format: The numerical format to style tick values on the axis. If not specified, a default will be chosen based on the ``axis_type`` parameter.
:param tick_interval: Tick values will be incremented by a standard value starting from the ``axis_min`` or ``axis_max``.
:param axis_min: The lowest point of the axis will be this value. Default is this value is automatically determined by the input data on this axis.
:param axis_max: The greatest point of the axis will be this value. Default is this value is automatically determined by the input data on this axis.
:param number_of_ticks: The total number of ticks will be this value. Default is 6 for y-axes and 12 for the x-axis.
:param tick_angle: Controls the angle that ticks are placed on the axies. Only applies to the x-axis.
:param currency_symbol: The currency symbol prefixed to tick labels. Only applied if the ``axis_type=='currency'`` and no custom ``tick_format`` is used.
:param start_at_non_null_values: Used for x-axes, but when set to ``True``, will filter data to the first available data point that has at least one non-null column based on the series plotted for a chart.
:param extra_options: Additional options to pass to the axis figure object in Plotly.
Examples
^^^^^^^^^^^^^
.. code-block:: python
:caption: Example of creating axes and using in a ``chart`` function. Notice each axis is assigned to x/y1/y2 position in the chart function.
from visualization_toolkit.helpers.plotly import chart, axis, series, annotation
fig = chart(
df,
x_axis=axis(
column_name="year",
label="Year",
),
chart_series=[
series(
column_name="lifeExp",
category_name="country",
location="y1",
),
],
y1_axis=axis(
label="Life Expectancy",
axis_type="number",
),
)
display(fig)
"""
column_name: str = None
location: Literal["x", "y1", "y2"] = "y1"
label: str = None
axis_type: Literal["date", "category", "number", "currency", "percent"] = None
tick_format: str = None
tick_values: list[int | float | str] = None
tick_labels: list[str] = None
tick_interval: int | float | relativedelta = None
number_of_ticks: Optional[int] = None
tick_angle: int = -45
axis_min: int | float | date = None
axis_max: int | float | date = None
currency_symbol: str = "$"
start_at_non_null_values: bool = False
extra_options: dict = field(default_factory=dict)
def category_axis_configuration(self, x_data: pd.Series) -> dict:
axis_options = dict(
title_text=self.label,
tickmode="array",
tickangle=self.tick_angle if self.location == "x" else None,
tickformat=self.resolved_tick_format,
type=self.axis_type,
tickvals=x_data,
)
match self.location:
case "x":
prefix = "xaxis"
case "y1":
prefix = "yaxis"
case "y2":
prefix = "yaxis2"
axis_options |= {
"side": "right",
}
final_options = {prefix: axis_options}
return final_options
def date_axis_configuration(
self, x_data: pd.Series, apply_offset: bool = False
) -> dict:
tick_values = _generate_date_tick_values(
date_series=x_data,
axis_range=(self.axis_min, self.axis_max),
number_of_ticks=self.number_of_ticks,
tick_interval=self.tick_interval,
)
(range_min, range_max) = self.resolved_range(tick_values)
# Apply an offset to date ranges to ensure traces are not cut off at the boundaries of the plot area
if apply_offset:
offset = _get_date_tick_offset(
range_min,
range_max,
x_data,
)
else:
offset = 0
axis_options = dict(
title_text=self.label,
tickangle=self.tick_angle if self.location == "x" else None,
tickformat=self.resolved_tick_format,
type=self.axis_type,
tickvals=tick_values,
range=(
range_min - relativedelta(days=offset),
range_max + relativedelta(days=offset),
),
)
match self.location:
case "x":
prefix = "xaxis"
case "y1":
prefix = "yaxis"
case "y2":
prefix = "yaxis2"
axis_options |= {
"side": "right",
}
final_options = {prefix: axis_options}
return final_options
def axis_configuration(self, data_min: int | float, data_max: int | float) -> dict:
# When there is only one data point then we should scale axes using plotly defaults
if data_min == data_max:
tick_values = None
tick_range = None
else:
tick_values = self.resolved_tick_values(data_min, data_max)
tick_range = self.resolved_range(tick_values)
axis_options = dict(
title_text=self.label,
tickmode="array",
tickvals=tick_values,
tickformat=self.resolved_tick_format,
range=tick_range,
tickangle=self.tick_angle if self.location == "x" else None,
)
match self.location:
case "x":
prefix = "xaxis"
case "y1":
prefix = "yaxis"
case "y2":
prefix = "yaxis2"
axis_options |= {
"side": "right",
}
final_options = {prefix: axis_options}
return final_options
def resolved_tick_values(
self,
chart_min: int | float,
chart_max: int | float,
) -> list[int | float | str]:
match self.axis_type:
case "percent" | "number" | "currency":
tick_values = _generate_numerical_tick_values(
(chart_min, chart_max),
(self.axis_min, self.axis_max),
number_of_ticks=self.resolved_number_of_ticks,
tick_interval=self.tick_interval,
location=self.location,
)
return tick_values
@property
def resolved_tick_format(self) -> str:
# Plotly does tick formatting based on D3 format
# docs: https://d3js.org/d3-format
if self.tick_format:
return self.tick_format
# If no custom format is provided,
# use a standard format given the axis type
match self.axis_type:
case "percent":
return f".1%"
case "number":
return f",.0f"
case "currency":
return f"{self.currency_symbol},.0f"
case "date":
return "%m/%d/%y"
def resolved_range(
self, tick_values: list[int | float | date] = None
) -> (int | float | date, int | float | date):
if tick_values:
data_min = tick_values[0]
data_max = tick_values[-1]
else:
data_min = None
data_max = None
is_min_pandas_timestamp_type = isinstance(
data_min, (date, pd.Timestamp)
) or isinstance(self.axis_min, (date, pd.Timestamp))
is_max_pandas_timestamp_type = isinstance(
data_max, (date, pd.Timestamp)
) or isinstance(self.axis_max, (date, pd.Timestamp))
if (
is_min_pandas_timestamp_type
and self.axis_min is not None
and data_min is not None
and pd.Timestamp(self.axis_min).date() < pd.Timestamp(data_min).date()
):
data_min = self.axis_min
elif (
not is_min_pandas_timestamp_type
and self.axis_min is not None
and data_min is not None
and self.axis_min < data_min
):
data_min = self.axis_min
if (
is_max_pandas_timestamp_type
and self.axis_max is not None
and data_max is not None
and pd.Timestamp(self.axis_max).date() > pd.Timestamp(data_max).date()
):
data_min = self.axis_min
elif (
not is_max_pandas_timestamp_type
and self.axis_max is not None
and data_max is not None
and self.axis_max > data_max
):
data_max = self.axis_max
if data_min is not None and data_max is not None:
return (data_min, data_max)
return None
@property
def resolved_number_of_ticks(self) -> int:
if self.number_of_ticks is not None:
return self.number_of_ticks
return 12 if self.location == "x" else 6
def format_value(self, value: float | int) -> str:
return (
f"{self.currency_symbol if self.axis_type == 'currency' else ''}"
f"{format(value, self.resolved_tick_format.replace(self.currency_symbol, ''))}"
)
def _generate_numerical_tick_values(
series_range: (int | float, int | float),
axis_range: (Optional[int | float], Optional[int | float]) = (None, None),
number_of_ticks: int = None,
tick_interval: Optional[int | float] = None,
location: str = "y1",
) -> list[int | float]:
(min_tick, max_tick) = series_range
(min_axis, max_axis) = axis_range
# If null values are exclusively on the axis, return None for tick values
min_tick_is_nan = min_tick is None or (
min_tick is not None and math.isnan(min_tick)
)
max_tick_is_nan = max_tick is None or (
max_tick is not None and math.isnan(max_tick)
)
if min_tick_is_nan and max_tick_is_nan:
return None
# Prefer specified min-axis when it is lower than the smallest data value observed
if min_axis is not None and min_axis <= min_tick:
min_tick = min_axis
# Prefer specified max-axis when it is greater than the largest data value ovesrved
if max_axis is not None and max_axis >= max_tick:
max_tick = max_axis
if min_tick == max_tick:
raise ValueError("Axis min and max values must be different.")
# If the axis only plots series with null values return None for tick values
if min_tick is None or max_tick is None:
return None
# To avoid floating number rounding issues when axis sizing, truncate to 6 decimals
min_tick = round(float(min_tick), DECIMAL_PRECISION)
max_tick = round(float(max_tick), DECIMAL_PRECISION)
# Step 1: Calculate the range
data_range = max_tick - min_tick
# Step 2: Calculate a rough tick interval
if number_of_ticks is None:
match location:
case "x":
number_of_ticks = 12
case "y1" | "y2":
number_of_ticks = 6
rough_interval = float(data_range / number_of_ticks)
# Step 3: Round to a "nice" number (1, 2, 5, or powers of 10)
exponent = math.floor(math.log10(rough_interval))
base = 10**exponent
precision = abs(exponent)
if tick_interval is not None:
nice_interval = tick_interval
else:
# Snap to 0.5, 1, 2, 2.5, or 5
if rough_interval / base <= 0.5:
nice_interval = base / 2
elif rough_interval / base <= 1:
nice_interval = base
elif rough_interval / base <= 2:
nice_interval = 2 * base
elif rough_interval / base <= 2.5:
nice_interval = 2.5 * base
precision = precision + 1
else:
nice_interval = 5 * base
# Step 4: Calculate the "nice" axis limits
nice_min = math.floor(min_tick / nice_interval) * nice_interval
nice_max = math.ceil(max_tick / nice_interval) * nice_interval
# Step 5: Generate ticks
ticks = []
current_tick = nice_min
i = 0
while current_tick < nice_max:
current_tick = nice_min + (i * nice_interval)
# When intervals are decimals, round to the appropriate precision
# to avoid python floating numbers
if exponent <= 0:
current_tick = round(current_tick, abs(precision))
ticks.append(current_tick)
i += 1
return ticks
def _generate_date_tick_values(
date_series: pd.Series,
axis_range: (Optional[date | datetime], Optional[date | datetime]) = None,
number_of_ticks: int = None,
tick_interval: Optional[relativedelta] = None,
) -> list[date]:
final_dates = []
sorted_dates = date_series.sort_values(ascending=True)
min_date = sorted_dates.min()
max_date = sorted_dates.max()
if axis_range is None:
axis_min, axis_max = None, None
else:
axis_min, axis_max = axis_range
# Handle pandas type errors comparing pd.Timestamp with python date values
try:
if axis_min is not None and axis_min > min_date:
min_date = axis_min
except TypeError as e:
if axis_min is not None and pd.Timestamp(axis_min) > pd.Timestamp(min_date):
min_date = axis_min
try:
if axis_max is not None and axis_max < max_date:
max_date = axis_max
except TypeError as e:
if axis_max is not None and pd.Timestamp(axis_max) < pd.Timestamp(max_date):
max_date = axis_max
# If no tick interval or number of ticks specified
# use a default interval based on the amount of time elapsed in the series data
if tick_interval is None and number_of_ticks is None:
days_elapsed = (max_date - min_date).days
if days_elapsed <= 90:
tick_interval = relativedelta(days=7)
elif days_elapsed <= 180:
tick_interval = relativedelta(months=1)
elif days_elapsed <= 360:
tick_interval = relativedelta(months=3)
elif days_elapsed <= 720:
tick_interval = relativedelta(months=4)
else:
tick_interval = relativedelta(years=1)
# If number of ticks is specified, generate a relative delta
# based on the approximate days that matches
elif tick_interval is None:
days_elapsed = (max_date - min_date).days
tick_interval = relativedelta(days=days_elapsed // (number_of_ticks - 1))
next_expected_date = None
last_date = sorted_dates.iloc[-1]
tick_cutoff = last_date - relativedelta(
days=(last_date - (last_date - tick_interval)).total_seconds() / 86400 / 2
)
i = 0
for _, val in sorted_dates.items():
if val < min_date or val > max_date:
continue
# Always include the first date
if i == 0:
next_expected_date = val + tick_interval
final_dates.append(val)
# Include a date if it is at least 1 tick_interval after the prior selected date
# and if it is at least 1 tick_interval before the last date
elif val >= next_expected_date and val <= tick_cutoff:
final_dates.append(val)
next_expected_date = val + tick_interval
i += 1
# Always include the last date
final_dates.append(max_date)
return final_dates
def _nearest_interval_floor(
number: int | float | Decimal, interval: int | float | Decimal
) -> int | float:
return (float(number) // float(interval)) * interval
def _nearest_interval_ceiling(
number: int | float | Decimal, interval: int | float | Decimal
) -> int | float:
if number % interval == 0:
return (float(number) // float(interval)) * interval
return ((float(number) // float(interval)) * interval) + interval
def _round_up(number: int | float | Decimal, decimals=0) -> int | float:
factor = 10**decimals
return math.ceil(number * factor) / factor
def _normalize_date(value: datetime | date | numpy.datetime64 | pd.Timestamp) -> date:
if isinstance(value, pd.Timestamp):
parsed = value.date()
elif isinstance(value, numpy.datetime64):
parsed = datetime.fromtimestamp(value.item() / (10**9)).date()
elif isinstance(value, datetime):
parsed = value.date()
else:
parsed = value
return parsed
def _get_date_tick_offset(
    range_min: datetime | date | numpy.datetime64 | pd.Timestamp,
    range_max: datetime | date | numpy.datetime64 | pd.Timestamp,
    x_data: pd.Series,
) -> int:
    """Return the padding (in whole days) to add at each end of a date range.

    The offset is half the gap between the first two distinct dates of
    ``x_data`` that fall inside [range_min, range_max]. With a single date in
    range it is half the whole range; with none in range it is 0.
    """
    normalized_min = _normalize_date(range_min)
    normalized_max = _normalize_date(range_max)

    # Normalize before sorting and de-duplicating: mixed date types cannot be
    # ordered directly, and same-day timestamps would otherwise yield a 0 gap.
    dates_in_range = set()
    for value in x_data.dropna().unique():
        normalized_value = _normalize_date(value)
        if normalized_min <= normalized_value <= normalized_max:
            dates_in_range.add(normalized_value)
    ordered_dates = sorted(dates_in_range)

    if len(ordered_dates) >= 2:
        return (ordered_dates[1] - ordered_dates[0]).days // 2
    if not ordered_dates:
        return 0
    return (normalized_max - normalized_min).days // 2