from typing import Optional
import string
import jmespath
from plotly import graph_objects as go
import pandas as pd
import numpy as np
import base64
from visualization_toolkit.helpers.plotly.theme import YD_CLASSIC_THEME
def decode_binary_data(bdata_str, dtype):
    """
    Turn a base64-encoded buffer back into a NumPy array.

    :param bdata_str: Base64 text (``str`` or ``bytes``) holding the raw array bytes.
    :param dtype: NumPy dtype used to reinterpret the decoded bytes.
    :return: 1-D ``np.ndarray`` that views the decoded bytes (read-only, since it
        is backed by an immutable ``bytes`` object).
    """
    return np.frombuffer(base64.b64decode(bdata_str), dtype=dtype)
def _decode_trace_values(values):
    """
    Return trace data in a form ``pd.DataFrame`` can use as a column.

    Plotly may serialize numeric arrays in its base64 typed-array form: a dict
    holding a ``"bdata"`` payload and a ``"dtype"`` code such as ``"f8"`` or
    ``"i2"``. Such dicts are decoded with the dtype they declare. If no dtype is
    present, float64 is assumed, which was this module's previous behaviour.
    Any other value is returned unchanged.
    """
    # Check for a real dict rather than using ``"bdata" in values``. Membership
    # tests on numpy arrays are elementwise comparisons, which can warn or match
    # a string element, and they raise TypeError on None.
    if isinstance(values, dict) and "bdata" in values:
        # Decoding integer payloads as float64 would read the bytes with the
        # wrong item width, giving garbage values or a buffer-size error.
        return decode_binary_data(
            values["bdata"], dtype=values.get("dtype", np.float64)
        )
    return values


def generate_pdf_from_figure(fig: go.Figure) -> pd.DataFrame:
    """
    Iterate over a plotly figure and construct a pandas dataframe from its x-axis and traces.

    Each trace becomes a separate column, named after the trace, and all traces
    share a single x-axis column. The x-axis column is named after the x-axis
    title. If there is no title it falls back to the axis type, and then to
    ``"x"``. Underscores and hyphens in the name are replaced by spaces and the
    result is title-cased.

    Trace data serialized in plotly's base64 typed-array form
    (``{"bdata": ..., "dtype": ...}``) is decoded for both ``x`` and ``y``,
    using the dtype it declares.

    :param fig: Plotly figure (or its dict representation) whose traces should be tabulated.
    :return: Dataframe with one row per distinct x value. The first column is the
        x-axis column, followed by one column per trace name. A figure with no
        traces yields an empty dataframe containing only the x-axis column.
    :raises ValueError: If the same x value appears more than once for a trace
        name. This error comes from ``DataFrame.pivot``.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Generate the pandas dataframe from a chart function.

        from visualization_toolkit.helpers.plotly import generate_pdf_from_figure, chart

        # Assume a chart is created with series
        fig = chart(...)

        # Return the chart data as a pdf to use elsewhere
        pdf = generate_pdf_from_figure(fig)
        display(pdf)
    """
    # Column name for the shared x-axis: axis title, then axis type, then literal "x".
    x_column_name = jmespath.search(
        "layout.xaxis.title.text || layout.xaxis.type || `x`", fig
    )
    x_column_name = string.capwords(
        str(x_column_name).replace("_", " ").replace("-", " ")
    )

    # Build one long-format frame per trace. The static trace_name column is
    # what the final pivot spreads into separate columns.
    pdfs = []
    for trace in fig["data"]:
        trace_name = (
            trace.get("name") if isinstance(trace, dict) else getattr(trace, "name")
        )
        pdfs.append(
            pd.DataFrame(
                {
                    x_column_name: _decode_trace_values(trace["x"]),
                    "y": _decode_trace_values(trace["y"]),
                    "trace_name": trace_name,
                }
            )
        )

    # pd.concat raises on an empty list, so a figure without traces returns
    # an empty frame that still has the x-axis column.
    if not pdfs:
        return pd.DataFrame(columns=[x_column_name])

    combined_pdf = pd.concat(pdfs, ignore_index=True)
    # Switch to wide format: one row per x value and one column per trace.
    return combined_pdf.pivot(
        index=x_column_name, columns="trace_name", values="y"
    ).reset_index()
def disable_hover_labels(fig: go.Figure):
    """
    Disable hover labels for a plotly figure, modifying it in place.

    Spikelines are turned off on every x and y axis. Every trace gets
    ``hoverinfo="none"``, a fully transparent hover label, and a blank hover
    template, so nothing visible is drawn on hover. The hoverData parameter
    will still trigger dash callbacks.

    .. warning::
        When disabling hover labels, you will also need to set the following css attributes on a parent dash element:

        .. code-block:: css

            & .hovertext {
                fill-opacity: 0;
                stroke-opacity: 0;
            }

    :param fig: Plotly figure to disable hover labels for.
    :return: None. The figure is updated in place.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Disable plotly hover labels for a figure

        from visualization_toolkit.helpers.plotly import disable_hover_labels, chart

        # Assume a chart is created with series
        fig = chart(...)

        # Disable hover labels for the figure
        disable_hover_labels(fig)
        display(fig)
    """
    fig.update_xaxes(showspikes=False)
    fig.update_yaxes(showspikes=False)
    fig.update_traces(
        # "none" hides the label. Unlike "skip", it keeps hover events firing,
        # which the dash hoverData callbacks rely on.
        hoverinfo="none",
        # If a label is still rendered, make its background, border and text transparent.
        hoverlabel=dict(
            namelength=0,
            bgcolor="rgba(255,255,255,0)",
            bordercolor="rgba(255,255,255,0)",
            font_color="rgba(255,255,255,0)",
        ),
        # A whitespace-only template renders no text. Presumably a space is used
        # because an empty string would fall back to plotly's default template.
        hovertemplate=" ",
    )
def highlight_trace(
    fig: go.Figure,
    trace_index: int,
    selected_opacity: float = 1,
    unselected_opacity: float = 0.4,
):
    """
    Highlight one trace of a plotly figure by dimming all the others, modifying it in place.

    The trace at position ``trace_index`` in ``fig["data"]`` is given
    ``selected_opacity``. Every other trace is given ``unselected_opacity``. The
    defaults are 1 and 0.4. The index is compared literally against each
    trace's position, so a negative or out-of-range index matches no trace and
    every trace is dimmed.

    :param fig: Plotly figure whose traces support item assignment, such as a
        ``go.Figure`` or its dict form.
    :param trace_index: Zero-based position in ``fig["data"]`` of the trace to highlight.
    :param selected_opacity: Opacity applied to the highlighted trace.
    :param unselected_opacity: Opacity applied to every other trace.
    :return: None. The traces are updated in place.

    Examples
    ^^^^^^^^^^^^^

    .. code-block:: python
        :caption: Emphasise the second trace of a chart

        from visualization_toolkit.helpers.plotly import highlight_trace, chart

        # Assume a chart is created with several series
        fig = chart(...)

        # Keep the second trace fully opaque and fade the rest
        highlight_trace(fig, trace_index=1)
        display(fig)
    """
    for position, trace in enumerate(fig["data"]):
        trace["opacity"] = (
            selected_opacity if position == trace_index else unselected_opacity
        )