"""Interactive figures in the Jupyter notebook"""
import sys
import types
from warnings import warn
# Pyodide environments (e.g. JupyterLite) are detected by ``micropip`` being
# importable. Tornado cannot be imported there and would lead to errors, so an
# empty placeholder module (named 'tornadofake') is registered under the name
# 'tornado' instead. Code that genuinely needs tornado still cannot work with
# this workaround; a half-working backend is preferred over a non-working one.
try:
    import micropip  # noqa
except ModuleNotFoundError:
    pass
else:
    sys.modules['tornado'] = types.ModuleType('tornadofake')
import io
import json
from base64 import b64encode
try:
from collections.abc import Iterable
except ImportError:
# Python 2.7
from collections import Iterable
import matplotlib
import numpy as np
from IPython import get_ipython
from IPython import version_info as ipython_version_info
from IPython.display import HTML, display
from ipython_genutils.py3compat import string_types
from ipywidgets import DOMWidget, widget_serialization
from matplotlib import is_interactive, rcParams
from matplotlib._pylab_helpers import Gcf
from matplotlib.backend_bases import NavigationToolbar2, _Backend, cursors
from matplotlib.backends.backend_webagg_core import (
FigureCanvasWebAggCore,
FigureManagerWebAgg,
NavigationToolbar2WebAgg,
TimerTornado,
)
from PIL import Image
from traitlets import (
Any,
Bool,
CaselessStrEnum,
CInt,
Enum,
Float,
Instance,
List,
Tuple,
Unicode,
default,
observe,
)
from ._version import js_semver
# Translation from matplotlib cursor identifiers to the CSS cursor names that
# are stored in the canvas' ``_cursor`` trait (see ``Canvas.send_json``).
cursors_str = dict(
    [
        (cursors.HAND, 'pointer'),
        (cursors.POINTER, 'default'),
        (cursors.SELECT_REGION, 'crosshair'),
        (cursors.MOVE, 'move'),
        (cursors.WAIT, 'wait'),
    ]
)
def connection_info():
    """
    Return a string showing the figure and connection status for the backend.

    This is intended as a diagnostic tool, and not for general use.

    Returns
    -------
    str
        One line per figure manager known to pyplot, formatted as
        ``"<figure label or 'Figure N'> - <manager.web_sockets>"``. When
        matplotlib is not in interactive mode, a final line reports how many
        figures are registered with pyplot (``Figures pending show: N``).
    """
    result = []
    for manager in Gcf.get_all_fig_managers():
        label = manager.canvas.figure.get_label() or f"Figure {manager.num}"
        result.append(f"{label} - {manager.web_sockets}")
    if not is_interactive():
        # ``Gcf._activeQue`` no longer exists in matplotlib; ``Gcf.figs`` is
        # the registry of open figure managers.
        result.append(f'Figures pending show: {len(Gcf.figs)}')
    return '\n'.join(result)
class Toolbar(DOMWidget, NavigationToolbar2WebAgg):
    """Navigation toolbar widget for an ipympl canvas.

    Combines an ipywidgets ``DOMWidget`` (front-end classes ``ToolbarModel`` /
    ``ToolbarView`` from ``jupyter-matplotlib``) with matplotlib's WebAgg
    navigation toolbar. Custom messages received by this widget are forwarded
    to the canvas' ``_handle_message``.
    """

    _model_module = Unicode('jupyter-matplotlib').tag(sync=True)
    _model_module_version = Unicode(js_semver).tag(sync=True)
    _model_name = Unicode('ToolbarModel').tag(sync=True)
    _view_module = Unicode('jupyter-matplotlib').tag(sync=True)
    _view_module_version = Unicode(js_semver).tag(sync=True)
    _view_name = Unicode('ToolbarView').tag(sync=True)

    # List of (text, tooltip, icon, method_name) tuples, see _default_toolitems.
    toolitems = List().tag(sync=True)
    button_style = CaselessStrEnum(
        values=['primary', 'success', 'info', 'warning', 'danger', ''],
        default_value='',
        help="""Use a predefined styling for the button.""",
    ).tag(sync=True)

    #######
    # Those traits are deprecated
    orientation = Enum(['horizontal', 'vertical'], default_value='vertical').tag(
        sync=True
    )
    collapsed = Bool(True).tag(sync=True)
    #######

    # One of 'pan', 'zoom' or ''.
    _current_action = Enum(values=['pan', 'zoom', ''], default_value='').tag(sync=True)

    def __init__(self, canvas, *args, **kwargs):
        DOMWidget.__init__(self, *args, **kwargs)
        NavigationToolbar2WebAgg.__init__(self, canvas, *args, **kwargs)
        # Messages sent to the toolbar widget are handled by the canvas.
        self.on_msg(self.canvas._handle_message)

    def export(self):
        """Display a static PNG snapshot of the figure in the output area.

        The image is embedded as a base64 ``data:`` URL in an ``<img>`` tag
        whose width is the figure width in pixels divided by the canvas'
        device pixel ratio, so it matches the widget size on HiDPI monitors.
        """
        with io.BytesIO() as buf:
            self.canvas.figure.savefig(buf, format='png', dpi='figure')
            encoded = b64encode(buf.getvalue()).decode('utf-8')
        # Figure width in pixels
        pwidth = self.canvas.figure.get_figwidth() * self.canvas.figure.get_dpi()
        # Scale size to match widget on HiDPI monitors.
        if hasattr(self.canvas, 'device_pixel_ratio'):  # Matplotlib 3.5+
            width = pwidth / self.canvas.device_pixel_ratio
        else:
            width = pwidth / self.canvas._dpi_ratio
        html = f"<img src='data:image/png;base64,{encoded}' width='{width}'/>"
        display(HTML(html))

    @default('toolitems')
    def _default_toolitems(self):
        """Build the default tool list from matplotlib's toolbar items.

        Only items whose icon has a mapping below are kept (this also drops
        matplotlib's ``None`` separators); a 'Download' item calling
        ``save_figure`` is appended before filtering.
        """
        icons = {
            'home': 'home',
            'back': 'arrow-left',
            'forward': 'arrow-right',
            'zoom_to_rect': 'square-o',
            'move': 'arrows',
            'download': 'floppy-o',
            'export': 'file-picture-o',
        }

        download_item = ('Download', 'Download plot', 'download', 'save_figure')

        toolitems = NavigationToolbar2.toolitems + (download_item,)

        return [
            (text, tooltip, icons[icon_name], method_name)
            for text, tooltip, icon_name, method_name in toolitems
            if icon_name in icons
        ]

    def __getattr__(self, name):
        """Warn on deprecated names, then fail the lookup as usual.

        ``__getattr__`` is only consulted after normal attribute lookup has
        failed. NOTE(review): 'orientation' and 'collapsed' are declared as
        traits above, so normal lookup presumably succeeds for them and this
        warning is rarely reached; confirm.

        Raises
        ------
        AttributeError
            Always, unless a base class provides its own ``__getattr__``.
        """
        if name in ('orientation', 'collapsed'):
            warn(
                "The Toolbar properties 'orientation' and 'collapsed' are deprecated. "
                "Accessing them will raise an error in a coming ipympl release",
                DeprecationWarning,
                stacklevel=2,
            )

        # ``object`` defines no ``__getattr__``; calling ``super().__getattr__``
        # unconditionally would raise a confusing error about ``super``.
        parent_getattr = getattr(super(), '__getattr__', None)
        if parent_getattr is not None:
            return parent_getattr(name)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    @observe('orientation', 'collapsed')
    def _on_orientation_collapsed_changed(self, change):
        """Warn whenever a deprecated trait's value changes."""
        warn(
            "The Toolbar properties 'orientation' and 'collapsed' are deprecated.",
            DeprecationWarning,
        )
class Canvas(DOMWidget, FigureCanvasWebAggCore):
# Front-end model/view classes (package 'jupyter-matplotlib') for this widget.
_model_module = Unicode('jupyter-matplotlib').tag(sync=True)
_model_module_version = Unicode(js_semver).tag(sync=True)
_model_name = Unicode('MPLCanvasModel').tag(sync=True)
_view_module = Unicode('jupyter-matplotlib').tag(sync=True)
_view_module_version = Unicode(js_semver).tag(sync=True)
_view_name = Unicode('MPLCanvasView').tag(sync=True)
# Optional Toolbar widget, synced as a widget reference (widget_serialization).
toolbar = Instance(Toolbar, allow_none=True).tag(sync=True, **widget_serialization)
# Accepts one of the listed strings or a plain bool.
toolbar_visible = Enum(
    ['visible', 'hidden', 'fade-in-fade-out', True, False],
    default_value='fade-in-fade-out',
).tag(sync=True)
# One of 'top', 'bottom', 'left', 'right'; defaults to 'left'.
toolbar_position = Enum(
    ['top', 'bottom', 'left', 'right'], default_value='left'
).tag(sync=True)
header_visible = Bool(True).tag(sync=True)
footer_visible = Bool(True).tag(sync=True)
resizable = Bool(True).tag(sync=True)
capture_scroll = Bool(False).tag(sync=True)
# NOTE(review): presumably a throttle interval (ms?) the front-end applies to
# pan/zoom events; confirm against the JS side.
pan_zoom_throttle = Float(33).tag(sync=True)
# This is a very special widget trait:
# We set "sync=True" because we want ipywidgets to consider this
# as part of the widget state, but we overwrite send_state so that,
# once the front-end reports it is initialized (syncing_data_url becomes
# False), the value is no longer sent; the front-end keeps its own value.
# This will still be used by ipywidgets in the case of embedding.
_data_url = Any(None).tag(sync=True)
# Synced state updated from send_json: figure size, label, status message,
# CSS cursor name and image mode.
_size = Tuple([0, 0]).tag(sync=True)
_figure_label = Unicode('Figure').tag(sync=True)
_message = Unicode().tag(sync=True)
_cursor = Unicode('pointer').tag(sync=True)
_image_mode = Unicode('full').tag(sync=True)
_rubberband_x = CInt(0).tag(sync=True)
_rubberband_y = CInt(0).tag(sync=True)
_rubberband_width = CInt(0).tag(sync=True)
_rubberband_height = CInt(0).tag(sync=True)
# Not synced; _handle_message sets it to True on a 'closing' message.
_closed = Bool(True)
# Must declare the superclass private members.
_png_is_old = Bool()
_force_full = Bool()
_current_image_mode = Unicode()
# Static as it should be the same for all canvases; _handle_message updates it
# on the class from 'set_dpi_ratio' messages.
current_dpi_ratio = 1.0
def __init__(self, figure, *args, **kwargs):
    """Initialise both base classes and start listening for front-end messages.

    The extra positional and keyword arguments are passed to both the
    ``DOMWidget`` and the ``FigureCanvasWebAggCore`` initializers.
    """
    DOMWidget.__init__(self, *args, **kwargs)
    FigureCanvasWebAggCore.__init__(self, figure, *args, **kwargs)
    # Remains True until the front-end sends 'initialized'; without a
    # front-end at all (e.g. nbconvert --execute) it never changes.
    self.syncing_data_url = True
    self.on_msg(self._handle_message)
# Overwrite ipywidgets's send_state so we don't sync the data_url
def send_state(self, key=None):
    """Send widget state to the front-end, possibly omitting ``_data_url``.

    Once ``syncing_data_url`` is False (set when the front-end reports
    'initialized'), the ``_data_url`` trait is filtered out so the front-end
    keeps its own value.

    Parameters
    ----------
    key : str, iterable of str, or None
        One trait name, several trait names, or None for all synced traits
        (``self.keys``).

    Raises
    ------
    ValueError
        If *key* is not None, a string, or an iterable.
    """
    if key is None:
        keys = self.keys
    elif isinstance(key, string_types):
        keys = [key]
    elif isinstance(key, Iterable):
        keys = key
    else:
        # Previously fell through with ``keys`` unbound (UnboundLocalError);
        # match the error ipywidgets itself raises for invalid keys.
        raise ValueError("key must be a string, an iterable of keys, or None")
    if not self.syncing_data_url:
        keys = [k for k in keys if k != '_data_url']
    DOMWidget.send_state(self, key=keys)
def _handle_message(self, object, content, buffers):
    """Handle a custom message coming from the front-end.

    Parameters
    ----------
    object
        Presumably the widget that received the message (unused here).
    content : dict
        Message payload; it must contain a ``'type'`` key.
    buffers
        Binary buffers attached to the message (unused here).

    'closing' marks the canvas as closed. 'initialized' stops syncing
    ``_data_url`` and resizes the manager to the figure bbox. Every other
    type ('set_dpi_ratio' included, after recording the ratio) is passed
    to ``self.manager.handle_json``.
    """
    msg_type = content['type']

    if msg_type == 'closing':
        self._closed = True
        return

    if msg_type == 'initialized':
        # The front-end is up and ready to receive diffs: stop syncing the
        # data url.
        self.syncing_data_url = False
        _, _, width, height = self.figure.bbox.bounds
        self.manager.resize(width, height)
        return

    if msg_type == 'set_dpi_ratio':
        # Shared by all canvases, so stored on the class itself.
        Canvas.current_dpi_ratio = content['dpi_ratio']

    self.manager.handle_json(content)
def send_json(self, content):
    """Route a JSON message from the WebAgg core to the front-end.

    'cursor', 'message', 'figure_label' and 'image_mode' messages only update
    the matching synced trait. 'resize' updates ``_size`` and is also sent as
    a raw message; any other type is sent as a raw message.
    """
    msg_type = content['type']

    if msg_type == 'cursor':
        requested = content['cursor']
        # Known matplotlib cursors become CSS names; others pass through.
        self._cursor = cursors_str.get(requested, requested)
        return
    if msg_type == 'message':
        self._message = content['message']
        return
    if msg_type == 'figure_label':
        self._figure_label = content['label']
        return
    if msg_type == 'image_mode':
        self._image_mode = content['mode']
        return

    if msg_type == 'resize':
        self._size = content['size']
        # The resize message must still be sent explicitly: if the front-end
        # only listened to `_size` changes, ipywidgets might squash several
        # changes together and the resizing protocol would be broken.

    self.send({'data': json.dumps(content)})
def send_binary(self, data):
    """Send a rendered frame to the front-end as a binary message.

    Parameters
    ----------
    data : bytes-like
        Frame produced by the WebAgg core (presumably the PNG from
        ``get_diff_image``; confirm). It is forwarded unchanged as the only
        buffer of a ``{"type": "binary"}`` message.

    While ``syncing_data_url`` is True (no initialized front-end yet), the
    full current frame is also re-encoded from ``self._last_buff`` into a PNG
    ``data:`` URL stored in ``_data_url``, so that output without a live
    front-end still shows the figure.
    """
    # TODO we should maybe rework the FigureCanvasWebAggCore implementation
    # so that it has a "refresh" method that we can overwrite
    if self.syncing_data_url:
        # Kept separate from ``data`` so the outgoing payload is not replaced
        # by raw RGBA pixels.
        pixels = self._last_buff.view(dtype=np.uint8).reshape(
            (*self._last_buff.shape, 4)
        )
        with io.BytesIO() as png:
            Image.fromarray(pixels).save(png, format="png")
            encoded = b64encode(png.getvalue()).decode('utf-8')
        # Same data-URL format that _repr_mimebundle_ stores in this trait.
        self._data_url = f'data:image/png;base64,{encoded}'

    # Actually send the data
    self.send({'data': '{"type": "binary"}'}, buffers=[data])
def new_timer(self, *args, **kwargs):
    """Create a timer for this canvas.

    All arguments are forwarded unchanged to ``TimerTornado``.
    """
    timer = TimerTornado(*args, **kwargs)
    return timer
def _repr_mimebundle_(self, **kwargs):
# now happens before the actual display call.
if hasattr(self, '_handle_displayed'):
self._handle_displayed(**kwargs)
plaintext = repr(self)
if len(plaintext) > 110:
plaintext = plaintext[:110] + '…'
buf = io.BytesIO()
self.figure.savefig(buf, format='png', dpi='figure')
base64_image = b64encode(buf.getvalue()).decode('utf-8')
self._data_url = f'data:image/png;base64,{base64_image}'
# Figure size in pixels
pwidth = self.figure.get_figwidth() * self.figure.get_dpi()
pheight = self.figure.get_figheight() * self.figure.get_dpi()
# Scale size to match widget on HiDPI monitors.
if hasattr(self, 'device_pixel_ratio'): # Matplotlib 3.5+
width = pwidth / self.device_pixel_ratio
height = pheight / self.device_pixel_ratio
else:
width = pwidth / self._dpi_ratio
height = pheight / self._dpi_ratio
html = """