You are reading an old version of the documentation (v1.4.3). For the latest version see https://matplotlib.org/stable/

We're updating the default styles for Matplotlib 2.0

Learn what to expect in the new updates

matplotlib

Previous topic

api example code: sankey_demo_links.py

Next topic

api example code: sankey_demo_rankine.py

This Page

api example code: sankey_demo_old.py¶

(Source code, png, hires.png, pdf)

../../_images/sankey_demo_old.png
#!/usr/bin/env python

from __future__ import print_function

# Module metadata: author credit and a time-stamp string used as the version.
# NOTE(review): the e-mail addresses read "[email protected]" -- presumably
# redacted by the page this file was extracted from; confirm against upstream.
__author__ = "Yannick Copin <[email protected]>"
__version__ = "Time-stamp: <10/02/2010 16:49 [email protected]>"

import numpy as np


def sankey(ax,
           outputs=[100.], outlabels=None,
           inputs=[100.], inlabels='',
           dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
    """Draw a Sankey diagram on *ax*.

    The main flow is a horizontal band spanning y = 0..100.  Outputs are
    chained to its right and inputs to its left; the whole outline is added
    to *ax* as a single patch and the axes limits are adjusted to fit it.

    ax: axes-like object to draw on (must provide add_patch, text,
    set_xlim, set_ylim and set_aspect)
    outputs: array of signed outputs; absolute values should sum up to 100%.
    A positive value branches off upwards, a negative one downwards; the
    last output is always drawn as the final horizontal arrow.
    outlabels: output labels (same length as outputs),
    or None (use default 'NN%' labels) or '' (no labels)
    inputs and inlabels: similar for inputs (positive values join from the
    top, negative ones from the bottom; the first input is always the
    horizontal one)
    dx: horizontal elongation
    dy: vertical elongation
    outangle: output arrow angle [deg]
    w: output arrow shoulder (extra width of the arrow head on each side)
    inangle: input dip angle [deg]
    offset: text offset between an arrow tip/dip and its label
    **kwargs: propagated to Patch (e.g., fill=False)

    Return (patch,[intexts,outtexts]).

    Raise AssertionError if outputs or inputs do not sum up to 100% (within
    a small floating-point tolerance), or if a label list does not have the
    same length as its values.
    """
    # NOTE: the list defaults in the signature are only ever read; NumPy
    # builds new arrays from them, so they are never mutated between calls.
    outs = np.absolute(outputs)
    outsigns = np.sign(outputs)
    outsigns[-1] = 0  # Last output: horizontal arrow

    ins = np.absolute(inputs)
    insigns = np.sign(inputs)
    insigns[0] = 0  # First input: horizontal dip

    # Compare with an absolute tolerance: an exact ``== 100`` test rejects
    # valid splits such as seven equal shares of 100/7 because of rounding.
    assert np.isclose(outs.sum(), 100., rtol=0., atol=1e-6), \
        "Outputs don't sum up to 100%"
    assert np.isclose(ins.sum(), 100., rtol=0., atol=1e-6), \
        "Inputs don't sum up to 100%"

    # Imported after validation so that bad input fails fast.
    import matplotlib.patches as mpatches
    from matplotlib.path import Path

    outtips = []  # Output arrow tip dir. and positions
    indips = []  # Input arrow dip dir. and positions

    def add_output(path, loss, sign=1):
        """Append an output arrow of width *loss* to *path*; record its tip."""
        h = (loss / 2. + w) * np.tan(outangle / 180. * np.pi)  # Tip height
        _, (x, y) = path[-1]  # Use last point as reference
        if sign == 0:  # Final loss (horizontal)
            path.extend([(Path.LINETO, [x + dx, y]),
                         (Path.LINETO, [x + dx, y + w]),
                         (Path.LINETO, [x + dx + h, y - loss / 2.]),  # Tip
                         (Path.LINETO, [x + dx, y - loss - w]),
                         (Path.LINETO, [x + dx, y - loss])])
            outtips.append((sign, path[-3][1]))
        else:  # Intermediate loss (vertical)
            path.extend([(Path.CURVE4, [x + dx / 2., y]),
                         (Path.CURVE4, [x + dx, y]),
                         (Path.CURVE4, [x + dx, y + sign * dy]),
                         (Path.LINETO, [x + dx - w, y + sign * dy]),
                         (Path.LINETO,
                          [x + dx + loss / 2., y + sign * (dy + h)]),  # Tip
                         (Path.LINETO, [x + dx + loss + w, y + sign * dy]),
                         (Path.LINETO, [x + dx + loss, y + sign * dy]),
                         (Path.CURVE3, [x + dx + loss, y - sign * loss]),
                         (Path.CURVE3,
                          [x + dx / 2. + loss, y - sign * loss])])
            outtips.append((sign, path[-5][1]))

    def add_input(path, gain, sign=1):
        """Append an input notch of width *gain* to *path*; record its dip."""
        h = (gain / 2.) * np.tan(inangle / 180. * np.pi)  # Dip depth
        _, (x, y) = path[-1]  # Use last point as reference
        if sign == 0:  # First gain (horizontal)
            path.extend([(Path.LINETO, [x - dx, y]),
                         (Path.LINETO, [x - dx + h, y + gain / 2.]),  # Dip
                         (Path.LINETO, [x - dx, y + gain])])
            xd, yd = path[-2][1]  # Dip position
            indips.append((sign, [xd - h, yd]))
        else:  # Intermediate gain (vertical)
            path.extend([(Path.CURVE4, [x - dx / 2., y]),
                         (Path.CURVE4, [x - dx, y]),
                         (Path.CURVE4, [x - dx, y + sign * dy]),
                         (Path.LINETO,
                          [x - dx - gain / 2., y + sign * (dy - h)]),  # Dip
                         (Path.LINETO, [x - dx - gain, y + sign * dy]),
                         (Path.CURVE3, [x - dx - gain, y - sign * gain]),
                         (Path.CURVE3,
                          [x - dx / 2. - gain, y - sign * gain])])
            xd, yd = path[-4][1]  # Dip position
            indips.append((sign, [xd, yd + sign * h]))

    def revert(path):
        """Return *path* traversed backwards.

        A path is not just revertable by path[::-1] because of Bezier
        curves: each code describes the segment arriving at its vertex, so
        on reversal every vertex takes the code of the vertex that
        originally followed it (and the new first vertex gets a LINETO).
        """
        rpath = []
        nextmove = Path.LINETO
        for move, pos in path[::-1]:
            rpath.append((nextmove, pos))
            nextmove = move
        return rpath

    # Right-hand side: non-negative signs grow the upper path, negative
    # ones the lower path.
    urpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper right path
    lrpath = [(Path.LINETO, [0, 0])]  # 1st point of lower right path
    for loss, sign in zip(outs, outsigns):
        add_output(urpath if sign >= 0 else lrpath, loss, sign=sign)

    # Left-hand side, walked from the last input back to the first one.
    llpath = [(Path.LINETO, [0, 0])]  # 1st point of lower left path
    ulpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper left path
    for gain, sign in reversed(list(zip(ins, insigns))):
        add_input(llpath if sign <= 0 else ulpath, gain, sign=sign)

    # Concatenate subpathes in correct order
    outline = urpath + revert(lrpath) + llpath + revert(ulpath)

    codes, verts = zip(*outline)
    verts = np.array(verts)

    # Path patch
    patch = mpatches.PathPatch(Path(verts, codes), **kwargs)
    ax.add_patch(patch)

    # Labels

    def set_labels(labels, values):
        """Set or check labels according to values."""
        if labels == '':  # No labels
            return labels
        if labels is None:  # Default labels
            return ['%2d%%' % val for val in values]
        assert len(labels) == len(values), \
            "Number of labels differs from number of values"
        return labels

    def put_labels(labels, positions, output=True):
        """Put labels to positions."""
        # Input positions were recorded from the last input to the first,
        # so input labels are walked in reverse to match.
        ordered = labels if output else labels[::-1]
        texts = []
        for label, (s, (x, y)) in zip(ordered, positions):
            if s == 0:  # Horizontal arrow: label beside it
                t = ax.text(x + offset, y, label,
                            ha='left' if output else 'right', va='center')
            elif s > 0:  # Upward arrow: label above
                t = ax.text(x, y + offset, label, ha='center', va='bottom')
            else:  # Downward arrow: label below
                t = ax.text(x, y - offset, label, ha='center', va='top')
            texts.append(t)
        return texts

    outlabels = set_labels(outlabels, outs)
    outtexts = put_labels(outlabels, outtips, output=True)

    inlabels = set_labels(inlabels, ins)
    intexts = put_labels(inlabels, indips, output=False)

    # Axes management
    ax.set_xlim(verts[:, 0].min() - dx, verts[:, 0].max() + dx)
    ax.set_ylim(verts[:, 1].min() - dy, verts[:, 1].max() + dy)
    ax.set_aspect('equal', adjustable='datalim')

    return patch, [intexts, outtexts]


if __name__ == '__main__':

    # Demo: draw one diagram with three inputs and six outputs, then
    # restyle two of the output labels through the returned Text objects.
    import matplotlib.pyplot as plt

    # Signed percentages; the sign picks the side an arrow branches to.
    inputs = [60., -25., 15.]
    outputs = [10., -20., 5., 15., -10., 40.]

    # Each output label is a name followed by its share on a second line.
    outlabels = ['First', 'Second', 'Third', 'Fourth', 'Fifth', 'Hurray!']
    outlabels = [name + '\n%d%%' % abs(share)
                 for share, name in zip(outputs, outlabels)]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1,
                         xticks=[], yticks=[], title="Sankey diagram")

    # inlabels=None asks for the default percentage labels on the inputs.
    patch, (intexts, outtexts) = sankey(
        ax,
        outputs=outputs, outlabels=outlabels,
        inputs=inputs, inlabels=None,
        fc='g', alpha=0.2)

    outtexts[1].set_color('r')  # Second output
    outtexts[-1].set_fontweight('bold')  # Final output

    plt.show()

Keywords: python, matplotlib, pylab, example, codex (see Search examples)