from __future__ import (absolute_import, division, print_function,
unicode_literals)
import six
import numpy as np
from .axes_divider import make_axes_locatable, Size, locatable_axes_factory
import sys
from .mpl_axes import Axes
[docs]def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True):
"""
pad : fraction of the axes height.
"""
divider = make_axes_locatable(ax)
pad_size = Size.Fraction(pad, Size.AxesY(ax))
xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))
divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
ax_rgb = []
if axes_class is None:
try:
axes_class = locatable_axes_factory(ax._axes_class)
except AttributeError:
axes_class = locatable_axes_factory(type(ax))
for ny in [4, 2, 0]:
ax1 = axes_class(ax.get_figure(),
ax.get_position(original=True),
sharex=ax, sharey=ax)
locator = divider.new_locator(nx=2, ny=ny)
ax1.set_axes_locator(locator)
for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels():
t.set_visible(False)
try:
for axis in ax1.axis.values():
axis.major_ticklabels.set_visible(False)
except AttributeError:
pass
ax_rgb.append(ax1)
if add_all:
fig = ax.get_figure()
for ax1 in ax_rgb:
fig.add_axes(ax1)
return ax_rgb
[docs]def imshow_rgb(ax, r, g, b, **kwargs):
ny, nx = r.shape
R = np.zeros([ny, nx, 3], dtype="d")
R[:,:,0] = r
G = np.zeros_like(R)
G[:,:,1] = g
B = np.zeros_like(R)
B[:,:,2] = b
RGB = R + G + B
im_rgb = ax.imshow(RGB, **kwargs)
return im_rgb
[docs]class RGBAxesBase(object):
"""base class for a 4-panel imshow (RGB, R, G, B)
Layout:
+---------------+-----+
| | R |
+ +-----+
| RGB | G |
+ +-----+
| | B |
+---------------+-----+
Attributes
----------
_defaultAxesClass : matplotlib.axes.Axes
defaults to 'Axes' in RGBAxes child class.
No default in abstract base class
RGB : _defaultAxesClass
The axes object for the three-channel imshow
R : _defaultAxesClass
The axes object for the red channel imshow
G : _defaultAxesClass
The axes object for the green channel imshow
B : _defaultAxesClass
The axes object for the blue channel imshow
"""
def __init__(self, *kl, **kwargs):
"""
Parameters
----------
pad : float
fraction of the axes height to put as padding.
defaults to 0.0
add_all : bool
True: Add the {rgb, r, g, b} axes to the figure
defaults to True.
axes_class : matplotlib.axes.Axes
kl :
Unpacked into axes_class() init for RGB
kwargs :
Unpacked into axes_class() init for RGB, R, G, B axes
"""
pad = kwargs.pop("pad", 0.0)
add_all = kwargs.pop("add_all", True)
try:
axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
except AttributeError:
new_msg = ("A subclass of RGBAxesBase must have a "
"_defaultAxesClass attribute. If you are not sure which "
"axes class to use, consider using "
"mpl_toolkits.axes_grid1.mpl_axes.Axes.")
six.reraise(AttributeError, AttributeError(new_msg),
sys.exc_info()[2])
ax = axes_class(*kl, **kwargs)
divider = make_axes_locatable(ax)
pad_size = Size.Fraction(pad, Size.AxesY(ax))
xsize = Size.Fraction((1.-2.*pad)/3., Size.AxesX(ax))
ysize = Size.Fraction((1.-2.*pad)/3., Size.AxesY(ax))
divider.set_horizontal([Size.AxesX(ax), pad_size, xsize])
divider.set_vertical([ysize, pad_size, ysize, pad_size, ysize])
ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))
ax_rgb = []
for ny in [4, 2, 0]:
ax1 = axes_class(ax.get_figure(),
ax.get_position(original=True),
sharex=ax, sharey=ax, **kwargs)
locator = divider.new_locator(nx=2, ny=ny)
ax1.set_axes_locator(locator)
ax1.axis[:].toggle(ticklabels=False)
ax_rgb.append(ax1)
self.RGB = ax
self.R, self.G, self.B = ax_rgb
if add_all:
fig = ax.get_figure()
fig.add_axes(ax)
self.add_RGB_to_figure()
self._config_axes()
def _config_axes(self, line_color='w', marker_edge_color='w'):
"""Set the line color and ticks for the axes
Parameters
----------
line_color : any matplotlib color
marker_edge_color : any matplotlib color
"""
for ax1 in [self.RGB, self.R, self.G, self.B]:
ax1.axis[:].line.set_color(line_color)
ax1.axis[:].major_ticks.set_markeredgecolor(marker_edge_color)
[docs] def imshow_rgb(self, r, g, b, **kwargs):
"""Create the four images {rgb, r, g, b}
Parameters
----------
r : array-like
The red array
g : array-like
The green array
b : array-like
The blue array
kwargs : imshow kwargs
kwargs get unpacked into the imshow calls for the four images
Returns
-------
rgb : matplotlib.image.AxesImage
r : matplotlib.image.AxesImage
g : matplotlib.image.AxesImage
b : matplotlib.image.AxesImage
"""
if not (r.shape == g.shape == b.shape):
raise ValueError('Input shapes do not match.'
'\nr.shape = {}'
'\ng.shape = {}'
'\nb.shape = {}'
.format(r.shape, g.shape, b.shape))
RGB = np.dstack([r, g, b])
R = np.zeros_like(RGB)
R[:,:,0] = r
G = np.zeros_like(RGB)
G[:,:,1] = g
B = np.zeros_like(RGB)
B[:,:,2] = b
im_rgb = self.RGB.imshow(RGB, **kwargs)
im_r = self.R.imshow(R, **kwargs)
im_g = self.G.imshow(G, **kwargs)
im_b = self.B.imshow(B, **kwargs)
return im_rgb, im_r, im_g, im_b
[docs]class RGBAxes(RGBAxesBase):
_defaultAxesClass = Axes