# Source code for matplotlib.category
"""
Module that allows plotting of string "category" data, e.g.
``plot(['d', 'f', 'a'], [1, 2, 3])`` will plot three points with x-axis
values of 'd', 'f', 'a'.
See :doc:`/gallery/lines_bars_and_markers/categorical_variables` for an
example.
The module uses Matplotlib's `matplotlib.units` mechanism to convert from
strings to integers, and provides a tick locator, a tick formatter, and the
:class:`.UnitData` class that creates and stores the string-to-integer mapping.
"""
from collections import OrderedDict
import itertools
import numpy as np
import matplotlib.units as units
import matplotlib.ticker as ticker
class StrCategoryConverter(units.ConversionInterface):
    """Unit converter that maps category strings to float positions."""

    @staticmethod
    def convert(value, unit, axis):
        """Convert strings in *value* to floats using the mapping information
        stored in the *unit* object.

        Input whose every element is number-like and not ``str``/``bytes`` is
        passed through, only cast to float.

        Parameters
        ----------
        value : str or iterable
            Value or sequence of values to be converted.
        unit : `.UnitData`
            Object holding the string-to-integer mapping for *value*; it is
            updated with any new strings found in *value*.
        axis : `~matplotlib.axis.Axis`
            Axis on which the converted value is plotted.

        Returns
        -------
        mapped_value : ndarray[float]
            Always at least 1-D, even for scalar input.

        Raises
        ------
        ValueError
            If *value* contains strings but *unit* is None.
        TypeError
            Propagated from ``unit.update`` when a non-numeric input contains
            a value that is not ``str``/``bytes``.

        Notes
        -----
        *axis* is not used in this function.
        """
        # dtype=object keeps numbers as numbers (no coercion of mixed input to
        # a string array), so the numeric pass-through check sees the
        # original element types.
        values = np.atleast_1d(np.array(value, dtype=object))
        # Pass through a sequence of non-string numbers.
        if all(units.ConversionInterface.is_numlike(v)
               and not isinstance(v, (str, bytes))
               for v in values):
            return np.asarray(values, dtype=float)
        if unit is None:
            # Without this check the failure surfaces as an opaque
            # AttributeError on ``None.update``.
            raise ValueError(
                'Missing category information for StrCategoryConverter; '
                'this might be caused by unintendedly mixing categorical and '
                'numeric data')
        # Force an update so that every value is in the mapping before the
        # lookup below; update() also does the type checking.
        unit.update(values)
        str2idx = np.vectorize(unit._mapping.__getitem__, otypes=[float])
        return str2idx(values)

    @staticmethod
    def axisinfo(unit, axis):
        """Set the default axis ticks and labels.

        Parameters
        ----------
        unit : `.UnitData`
            Object holding the string-to-integer mapping.
        axis : `~matplotlib.axis.Axis`
            Axis for which information is being set.

        Returns
        -------
        axisinfo : `~matplotlib.units.AxisInfo`
            Information to support default tick labeling.

        Notes
        -----
        *axis* is not used in this function.
        """
        # The locator and formatter receive the mapping dict itself (not a
        # copy) so that later updates to *unit* show up in the ticks.
        majloc = StrCategoryLocator(unit._mapping)
        # NOTE(review): StrCategoryFormatter is not defined in the visible
        # part of this module -- confirm it is defined elsewhere in the file.
        majfmt = StrCategoryFormatter(unit._mapping)
        return units.AxisInfo(majloc=majloc, majfmt=majfmt)

    @staticmethod
    def default_units(data, axis):
        """Set and update the units of *axis* from *data*.

        Parameters
        ----------
        data : str or iterable of str
            Category data being plotted.
        axis : `~matplotlib.axis.Axis`
            Axis on which the data is plotted.

        Returns
        -------
        `.UnitData`
            Object storing the string-to-integer mapping (``axis.units``).
        """
        # The conversion call stack is default_units -> axisinfo -> convert.
        if axis.units is None:
            axis.set_units(UnitData(data))
        else:
            axis.units.update(data)
        return axis.units
class StrCategoryLocator(ticker.Locator):
    """Locator that places a tick at every integer a category maps to."""

    def __init__(self, units_mapping):
        """
        Parameters
        ----------
        units_mapping : Dict[str, int]
            string:integer mapping; stored by reference, so the ticks track
            any later additions made to it.
        """
        self._units = units_mapping

    def __call__(self):
        # One tick per category id, in the mapping's iteration order.
        return [*self._units.values()]

    def tick_values(self, vmin, vmax):
        # vmin/vmax are ignored: the full set of category ids is returned.
        return self()
class UnitData(object):
    """Mapping from unique category values to integer ids.

    Ids are handed out in first-seen order; a value that is already mapped
    keeps its id across later updates.
    """

    def __init__(self, data=None):
        """
        Create mapping between unique categorical values and integer ids.

        Parameters
        ----------
        data : iterable, optional
            sequence of string values
        """
        self._mapping = OrderedDict()
        # Source of the next unused integer id.
        self._counter = itertools.count()
        if data is not None:
            self.update(data)

    def update(self, data):
        """Map new values to integer identifiers.

        The update is all-or-nothing: if any value is rejected, neither the
        mapping nor the id counter is modified.

        Parameters
        ----------
        data : iterable
            sequence of string values (a single scalar is also accepted)

        Raises
        ------
        TypeError
            If a value in data is not a str or bytes instance.
        """
        data = np.atleast_1d(np.array(data, dtype=object))
        # Unique values of data, in first-seen order.
        unique = list(OrderedDict.fromkeys(data))
        # Validate everything before mutating, so a bad value cannot leave
        # stray categories (and consumed ids) behind from a failed call.
        for val in unique:
            if not isinstance(val, (str, bytes)):
                raise TypeError("{val!r} is not a string".format(val=val))
        for val in unique:
            if val not in self._mapping:
                self._mapping[val] = next(self._counter)
# Connect the converter to matplotlib: each string-like scalar type (Python
# and NumPy, text and bytes) is registered with its own converter instance.
for _string_type in (str, np.str_, bytes, np.bytes_):
    units.registry[_string_type] = StrCategoryConverter()
del _string_type