"""Axis configuration and tick/range helpers for Plotly charts (``visualization_toolkit.helpers.plotly.charts.core.axis``)."""

from typing import Literal, Optional
from datetime import date, datetime
from dataclasses import dataclass, field
import math
import numpy

from decimal import Decimal
from dateutil.relativedelta import relativedelta
import pandas as pd

DECIMAL_PRECISION = 6


@dataclass
class Axis:
    """
    Control axis behavior by using this function. Options to control the ticks, title, tick format, and overall range is possible.
    Define an axis for each relavant axis of the chart (ex: 1 x-axis and y-axis is 2 axis function calls).

    Axis will attempt to look as good as possible with minimum customization needed. The function will identify
    the minimum and maximum bounds of the input data to ensure all data is inside the plotted frame.

    In addition, the axis will try to generate as even as possible tick values given the dataset. If specific tick steps
    are preferred specify the ``tick_interval`` and ``axis_min``.

    :param column_name: The column of the input data of a `chart` function call to use.
    :param location: Controls whether this axis is the x, y1, or y2 axis of the chart. This is not needed to be specified, as the ``chart`` function will set this value automatically.
    :param label: Set an axis title for the chart. Default is no title is added.
    :param axis_type: The numerical type of the data on this axis. It is important to specify this value, as it will control the default axis behavior and formatting.
    :param tick_format: The numerical format to style tick values on the axis. If not specified, a default will be chosen based on the ``axis_type`` parameter.
    :param tick_interval: Tick values will be incremented by a standard value starting from the ``axis_min`` or ``axis_max``.
    :param axis_min: The lowest point of the axis will be this value. Default is this value is automatically determined by the input data on this axis.
    :param axis_max: The greatest point of the axis will be this value. Default is this value is automatically determined by the input data on this axis.
    :param number_of_ticks: The total number of ticks will be this value. Default is 6 for y-axes and 12 for the x-axis.
    :param tick_angle: Controls the angle that ticks are placed on the axies. Only applies to the x-axis.
    :param currency_symbol: The currency symbol prefixed to tick labels. Only applied if the ``axis_type=='currency'`` and no custom ``tick_format`` is used.
    :param start_at_non_null_values: Used for x-axes, but when set to ``True``, will filter data to the first available data point that has at least one non-null column based on the series plotted for a chart.
    :param extra_options: Additional options to pass to the axis figure object in Plotly.

    Examples
    ^^^^^^^^^^^^^
    .. code-block:: python
        :caption: Example of creating axes and using in a ``chart`` function. Notice each axis is assigned to x/y1/y2 position in the chart function.

        from visualization_toolkit.helpers.plotly import chart, axis, series, annotation

        fig = chart(
            df,
            x_axis=axis(
                column_name="year",
                label="Year",
            ),
            chart_series=[
                series(
                    column_name="lifeExp",
                    category_name="country",
                    location="y1",
                ),
            ],
            y1_axis=axis(
                label="Life Expectancy",
                axis_type="number",
            ),
        )

        display(fig)

    """

    column_name: str = None
    location: Literal["x", "y1", "y2"] = "y1"
    label: str = None
    axis_type: Literal["date", "category", "number", "currency", "percent"] = None
    tick_format: str = None
    tick_values: list[int | float | str] = None
    tick_labels: list[str] = None
    tick_interval: int | float | relativedelta = None
    number_of_ticks: Optional[int] = None
    tick_angle: int = -45
    axis_min: int | float | date = None
    axis_max: int | float | date = None
    currency_symbol: str = "$"
    start_at_non_null_values: bool = False
    extra_options: dict = field(default_factory=dict)

    def category_axis_configuration(self, x_data: pd.Series) -> dict:
        axis_options = dict(
            title_text=self.label,
            tickmode="array",
            tickangle=self.tick_angle if self.location == "x" else None,
            tickformat=self.resolved_tick_format,
            type=self.axis_type,
            tickvals=x_data,
        )

        match self.location:
            case "x":
                prefix = "xaxis"

            case "y1":
                prefix = "yaxis"

            case "y2":
                prefix = "yaxis2"
                axis_options |= {
                    "side": "right",
                }

        final_options = {prefix: axis_options}

        return final_options

    def date_axis_configuration(
        self, x_data: pd.Series, apply_offset: bool = False
    ) -> dict:
        tick_values = _generate_date_tick_values(
            date_series=x_data,
            axis_range=(self.axis_min, self.axis_max),
            number_of_ticks=self.number_of_ticks,
            tick_interval=self.tick_interval,
        )
        (range_min, range_max) = self.resolved_range(tick_values)

        # Apply an offset to date ranges to ensure traces are not cut off at the boundaries of the plot area
        if apply_offset:
            offset = _get_date_tick_offset(
                range_min,
                range_max,
                x_data,
            )
        else:
            offset = 0

        axis_options = dict(
            title_text=self.label,
            tickangle=self.tick_angle if self.location == "x" else None,
            tickformat=self.resolved_tick_format,
            type=self.axis_type,
            tickvals=tick_values,
            range=(
                range_min - relativedelta(days=offset),
                range_max + relativedelta(days=offset),
            ),
        )

        match self.location:
            case "x":
                prefix = "xaxis"

            case "y1":
                prefix = "yaxis"

            case "y2":
                prefix = "yaxis2"
                axis_options |= {
                    "side": "right",
                }

        final_options = {prefix: axis_options}

        return final_options

    def axis_configuration(self, data_min: int | float, data_max: int | float) -> dict:
        # When there is only one data point then we should scale axes using plotly defaults
        if data_min == data_max:
            tick_values = None
            tick_range = None
        else:
            tick_values = self.resolved_tick_values(data_min, data_max)
            tick_range = self.resolved_range(tick_values)

        axis_options = dict(
            title_text=self.label,
            tickmode="array",
            tickvals=tick_values,
            tickformat=self.resolved_tick_format,
            range=tick_range,
            tickangle=self.tick_angle if self.location == "x" else None,
        )

        match self.location:
            case "x":
                prefix = "xaxis"

            case "y1":
                prefix = "yaxis"

            case "y2":
                prefix = "yaxis2"
                axis_options |= {
                    "side": "right",
                }

        final_options = {prefix: axis_options}
        return final_options

    def resolved_tick_values(
        self,
        chart_min: int | float,
        chart_max: int | float,
    ) -> list[int | float | str]:
        match self.axis_type:
            case "percent" | "number" | "currency":
                tick_values = _generate_numerical_tick_values(
                    (chart_min, chart_max),
                    (self.axis_min, self.axis_max),
                    number_of_ticks=self.resolved_number_of_ticks,
                    tick_interval=self.tick_interval,
                    location=self.location,
                )

                return tick_values

    @property
    def resolved_tick_format(self) -> str:
        # Plotly does tick formatting based on D3 format
        # docs: https://d3js.org/d3-format
        if self.tick_format:
            return self.tick_format

        # If no custom format is provided,
        # use a standard format given the axis type
        match self.axis_type:
            case "percent":
                return f".1%"

            case "number":
                return f",.0f"

            case "currency":
                return f"{self.currency_symbol},.0f"

            case "date":
                return "%m/%d/%y"

    def resolved_range(
        self, tick_values: list[int | float | date] = None
    ) -> (int | float | date, int | float | date):
        if tick_values:
            data_min = tick_values[0]
            data_max = tick_values[-1]
        else:
            data_min = None
            data_max = None

        is_min_pandas_timestamp_type = isinstance(
            data_min, (date, pd.Timestamp)
        ) or isinstance(self.axis_min, (date, pd.Timestamp))

        is_max_pandas_timestamp_type = isinstance(
            data_max, (date, pd.Timestamp)
        ) or isinstance(self.axis_max, (date, pd.Timestamp))

        if (
            is_min_pandas_timestamp_type
            and self.axis_min is not None
            and data_min is not None
            and pd.Timestamp(self.axis_min).date() < pd.Timestamp(data_min).date()
        ):
            data_min = self.axis_min

        elif (
            not is_min_pandas_timestamp_type
            and self.axis_min is not None
            and data_min is not None
            and self.axis_min < data_min
        ):
            data_min = self.axis_min

        if (
            is_max_pandas_timestamp_type
            and self.axis_max is not None
            and data_max is not None
            and pd.Timestamp(self.axis_max).date() > pd.Timestamp(data_max).date()
        ):
            data_min = self.axis_min

        elif (
            not is_max_pandas_timestamp_type
            and self.axis_max is not None
            and data_max is not None
            and self.axis_max > data_max
        ):
            data_max = self.axis_max

        if data_min is not None and data_max is not None:
            return (data_min, data_max)

        return None

    @property
    def resolved_number_of_ticks(self) -> int:
        if self.number_of_ticks is not None:
            return self.number_of_ticks

        return 12 if self.location == "x" else 6

    def format_value(self, value: float | int) -> str:
        return (
            f"{self.currency_symbol if self.axis_type == 'currency' else ''}"
            f"{format(value, self.resolved_tick_format.replace(self.currency_symbol, ''))}"
        )


# Lowercase alias so axes can be created in function-call style, e.g. ``axis(label="Year")``.
axis = Axis
def _generate_numerical_tick_values( series_range: (int | float, int | float), axis_range: (Optional[int | float], Optional[int | float]) = (None, None), number_of_ticks: int = None, tick_interval: Optional[int | float] = None, location: str = "y1", ) -> list[int | float]: (min_tick, max_tick) = series_range (min_axis, max_axis) = axis_range # If null values are exclusively on the axis, return None for tick values min_tick_is_nan = min_tick is None or ( min_tick is not None and math.isnan(min_tick) ) max_tick_is_nan = max_tick is None or ( max_tick is not None and math.isnan(max_tick) ) if min_tick_is_nan and max_tick_is_nan: return None # Prefer specified min-axis when it is lower than the smallest data value observed if min_axis is not None and min_axis <= min_tick: min_tick = min_axis # Prefer specified max-axis when it is greater than the largest data value ovesrved if max_axis is not None and max_axis >= max_tick: max_tick = max_axis if min_tick == max_tick: raise ValueError("Axis min and max values must be different.") # If the axis only plots series with null values return None for tick values if min_tick is None or max_tick is None: return None # To avoid floating number rounding issues when axis sizing, truncate to 6 decimals min_tick = round(float(min_tick), DECIMAL_PRECISION) max_tick = round(float(max_tick), DECIMAL_PRECISION) # Step 1: Calculate the range data_range = max_tick - min_tick # Step 2: Calculate a rough tick interval if number_of_ticks is None: match location: case "x": number_of_ticks = 12 case "y1" | "y2": number_of_ticks = 6 rough_interval = float(data_range / number_of_ticks) # Step 3: Round to a "nice" number (1, 2, 5, or powers of 10) exponent = math.floor(math.log10(rough_interval)) base = 10**exponent precision = abs(exponent) if tick_interval is not None: nice_interval = tick_interval else: # Snap to 0.5, 1, 2, 2.5, or 5 if rough_interval / base <= 0.5: nice_interval = base / 2 elif rough_interval / base <= 1: nice_interval = base 
elif rough_interval / base <= 2: nice_interval = 2 * base elif rough_interval / base <= 2.5: nice_interval = 2.5 * base precision = precision + 1 else: nice_interval = 5 * base # Step 4: Calculate the "nice" axis limits nice_min = math.floor(min_tick / nice_interval) * nice_interval nice_max = math.ceil(max_tick / nice_interval) * nice_interval # Step 5: Generate ticks ticks = [] current_tick = nice_min i = 0 while current_tick < nice_max: current_tick = nice_min + (i * nice_interval) # When intervals are decimals, round to the appropriate precision # to avoid python floating numbers if exponent <= 0: current_tick = round(current_tick, abs(precision)) ticks.append(current_tick) i += 1 return ticks def _generate_date_tick_values( date_series: pd.Series, axis_range: (Optional[date | datetime], Optional[date | datetime]) = None, number_of_ticks: int = None, tick_interval: Optional[relativedelta] = None, ) -> list[date]: final_dates = [] sorted_dates = date_series.sort_values(ascending=True) min_date = sorted_dates.min() max_date = sorted_dates.max() if axis_range is None: axis_min, axis_max = None, None else: axis_min, axis_max = axis_range # Handle pandas type errors comparing pd.Timestamp with python date values try: if axis_min is not None and axis_min > min_date: min_date = axis_min except TypeError as e: if axis_min is not None and pd.Timestamp(axis_min) > pd.Timestamp(min_date): min_date = axis_min try: if axis_max is not None and axis_max < max_date: max_date = axis_max except TypeError as e: if axis_max is not None and pd.Timestamp(axis_max) < pd.Timestamp(max_date): max_date = axis_max # If no tick interval or number of ticks specified # use a default interval based on the amount of time elapsed in the series data if tick_interval is None and number_of_ticks is None: days_elapsed = (max_date - min_date).days if days_elapsed <= 90: tick_interval = relativedelta(days=7) elif days_elapsed <= 180: tick_interval = relativedelta(months=1) elif days_elapsed <= 360: 
tick_interval = relativedelta(months=3) elif days_elapsed <= 720: tick_interval = relativedelta(months=4) else: tick_interval = relativedelta(years=1) # If number of ticks is specified, generate a relative delta # based on the approximate days that matches elif tick_interval is None: days_elapsed = (max_date - min_date).days tick_interval = relativedelta(days=days_elapsed // (number_of_ticks - 1)) next_expected_date = None last_date = sorted_dates.iloc[-1] tick_cutoff = last_date - relativedelta( days=(last_date - (last_date - tick_interval)).total_seconds() / 86400 / 2 ) i = 0 for _, val in sorted_dates.items(): if val < min_date or val > max_date: continue # Always include the first date if i == 0: next_expected_date = val + tick_interval final_dates.append(val) # Include a date if it is at least 1 tick_interval after the prior selected date # and if it is at least 1 tick_interval before the last date elif val >= next_expected_date and val <= tick_cutoff: final_dates.append(val) next_expected_date = val + tick_interval i += 1 # Always include the last date final_dates.append(max_date) return final_dates def _nearest_interval_floor( number: int | float | Decimal, interval: int | float | Decimal ) -> int | float: return (float(number) // float(interval)) * interval def _nearest_interval_ceiling( number: int | float | Decimal, interval: int | float | Decimal ) -> int | float: if number % interval == 0: return (float(number) // float(interval)) * interval return ((float(number) // float(interval)) * interval) + interval def _round_up(number: int | float | Decimal, decimals=0) -> int | float: factor = 10**decimals return math.ceil(number * factor) / factor def _normalize_date(value: datetime | date | numpy.datetime64 | pd.Timestamp) -> date: if isinstance(value, pd.Timestamp): parsed = value.date() elif isinstance(value, numpy.datetime64): parsed = datetime.fromtimestamp(value.item() / (10**9)).date() elif isinstance(value, datetime): parsed = value.date() else: 
parsed = value return parsed def _get_date_tick_offset( range_min: datetime | date | numpy.datetime64 | pd.Timestamp, range_max: datetime | date | numpy.datetime64 | pd.Timestamp, x_data: pd.Series, ) -> int: # Determine offset by taking the first two tick values in plot range # and getting the delta between them. Otherwise, calculate by using min/max and number of ticks # the offset is equal to half the tick value tick_values_in_range = [] normalized_min = _normalize_date(range_min) normalized_max = _normalize_date(range_max) for value in sorted(x_data.unique()): normalized_value = _normalize_date(value) if normalized_value >= normalized_min and normalized_value <= normalized_max: tick_values_in_range.append(normalized_value) if len(tick_values_in_range) >= 2: offset = (tick_values_in_range[1] - tick_values_in_range[0]).days // 2 else: offset = ( (_normalize_date(range_max) - _normalize_date(range_min)).days / len(tick_values_in_range) // 2 ) return offset