.. _units-basic_units:

units example code: basic_units.py
==================================

.. plot:: /home/tcaswell/source/p/matplotlib/doc/mpl_examples/units/basic_units.py

::

    """
    A toy unit system and its matplotlib unit converter.

    ``BasicUnit`` objects (cm, inch, radians, ...) tag plain values as
    ``TaggedValue`` objects; arithmetic on tagged values converts the
    operands to a common unit where a conversion has been registered.
    ``BasicUnitConverter`` is registered in ``matplotlib.units.registry`` so
    tagged values can be handed straight to plotting functions.
    """
    import math

    import numpy as np

    import matplotlib.units as units
    import matplotlib.ticker as ticker
    from matplotlib.axes import Axes
    from matplotlib.cbook import iterable


    class ProxyDelegate(object):
        """Descriptor that builds a fresh proxy for the accessing instance.

        Looking the attribute up on an instance returns
        ``proxy_type(fn_name, instance)``; calling that proxy performs the
        operation.
        """

        def __init__(self, fn_name, proxy_type):
            self.proxy_type = proxy_type
            self.fn_name = fn_name

        def __get__(self, obj, objtype=None):
            return self.proxy_type(self.fn_name, obj)


    class TaggedValueMeta(type):
        """Metaclass that installs a ProxyDelegate for each ``_proxies`` entry.

        A delegate is installed only when looking the name up on the new
        class raises AttributeError, so methods the class already resolves
        (defined on it or inherited) are left untouched.
        """

        def __init__(cls, name, bases, dict):
            for fn_name, proxy_type in cls._proxies.items():
                try:
                    getattr(cls, fn_name)
                except AttributeError:
                    setattr(cls, fn_name, ProxyDelegate(fn_name, proxy_type))


    class PassThroughProxy(object):
        """Call ``fn_name`` on the wrapped value (``obj.proxy_target``)."""

        def __init__(self, fn_name, obj):
            self.fn_name = fn_name
            self.target = obj.proxy_target

        def __call__(self, *args):
            fn = getattr(self.target, self.fn_name)
            ret = fn(*args)
            return ret


    class ConvertArgsProxy(PassThroughProxy):
        """Convert every argument to the owner's unit, then call through
        with the raw (untagged) values."""

        def __init__(self, fn_name, obj):
            PassThroughProxy.__init__(self, fn_name, obj)
            self.unit = obj.unit

        def __call__(self, *args):
            converted_args = []
            for a in args:
                try:
                    converted_args.append(a.convert_to(self.unit))
                except AttributeError:
                    # Untagged values are treated as already being in
                    # self.unit.
                    converted_args.append(TaggedValue(a, self.unit))
            converted_args = tuple([c.get_value() for c in converted_args])
            return PassThroughProxy.__call__(self, *converted_args)


    class ConvertReturnProxy(PassThroughProxy):
        """Call through unchanged and tag the result with the owner's unit."""

        def __init__(self, fn_name, obj):
            PassThroughProxy.__init__(self, fn_name, obj)
            self.unit = obj.unit

        def __call__(self, *args):
            ret = PassThroughProxy.__call__(self, *args)
            if ret is NotImplemented:
                return NotImplemented
            return TaggedValue(ret, self.unit)


    class ConvertAllProxy(PassThroughProxy):
        """Convert arguments, call through, and resolve the result's unit.

        The result unit is chosen by ``unit_resolver`` from the owner's unit
        and the unit of every argument; ``NotImplemented`` is returned when
        the operation is not defined for that combination of units.
        """

        def __init__(self, fn_name, obj):
            PassThroughProxy.__init__(self, fn_name, obj)
            self.unit = obj.unit

        def __call__(self, *args):
            converted_args = []
            arg_units = [self.unit]
            for a in args:
                if hasattr(a, 'get_unit') and not hasattr(a, 'convert_to'):
                    # if this arg has a unit type but no conversion ability,
                    # this operation is prohibited
                    return NotImplemented

                if hasattr(a, 'convert_to'):
                    try:
                        a = a.convert_to(self.unit)
                    except Exception:
                        # Best effort: with no conversion to self.unit the
                        # argument keeps its own unit, and unit_resolver
                        # below decides whether the operation is allowed.
                        pass
                    arg_units.append(a.get_unit())
                    converted_args.append(a.get_value())
                else:
                    converted_args.append(a)
                    if hasattr(a, 'get_unit'):
                        arg_units.append(a.get_unit())
                    else:
                        arg_units.append(None)
            converted_args = tuple(converted_args)
            ret = PassThroughProxy.__call__(self, *converted_args)
            if ret is NotImplemented:
                return NotImplemented
            ret_unit = unit_resolver(self.fn_name, arg_units)
            if ret_unit is NotImplemented:
                return NotImplemented
            return TaggedValue(ret, ret_unit)


    # (cls, value type) -> generated TaggedValue_of_<type> subclass.  Caching
    # stops TaggedValue() from building, and registering with
    # matplotlib.units.registry, a brand-new class on every single call.
    _tagged_subclass_cache = {}


    class _TaggedValue(object):
        """A value tagged with a unit.

        Names in ``_proxies`` are candidates for proxying by TaggedValueMeta.
        Any other attribute that is not a dunder and is not defined in the
        instance's own class ``__dict__`` is looked up on the wrapped value
        first.
        """

        _proxies = {'__add__': ConvertAllProxy,
                    '__sub__': ConvertAllProxy,
                    '__mul__': ConvertAllProxy,
                    '__rmul__': ConvertAllProxy,
                    '__cmp__': ConvertAllProxy,
                    '__lt__': ConvertAllProxy,
                    '__gt__': ConvertAllProxy,
                    '__len__': PassThroughProxy}

        def __new__(cls, value, unit):
            # Try to use a subclass that also inherits from the value's type.
            # If that class cannot be built, or object.__new__ refuses it
            # (e.g. for builtin value types), fall back to cls itself.
            value_class = type(value)
            try:
                key = (cls, value_class)
                subcls = _tagged_subclass_cache.get(key)
                if subcls is None:
                    subcls = type('TaggedValue_of_%s' % (value_class.__name__),
                                  tuple([cls, value_class]),
                                  {})
                    _tagged_subclass_cache[key] = subcls
                if subcls not in units.registry:
                    units.registry[subcls] = basicConverter
                return object.__new__(subcls)
            except TypeError:
                if cls not in units.registry:
                    units.registry[cls] = basicConverter
                return object.__new__(cls)

        def __init__(self, value, unit):
            self.value = value
            self.unit = unit
            # Proxies call operations on this object (see PassThroughProxy).
            self.proxy_target = self.value

        def __getattribute__(self, name):
            if name.startswith('__'):
                return object.__getattribute__(self, name)
            variable = object.__getattribute__(self, 'value')
            if hasattr(variable, name) and name not in self.__class__.__dict__:
                return getattr(variable, name)
            return object.__getattribute__(self, name)

        def __array__(self, t=None, context=None):
            if t is not None:
                return np.asarray(self.value).astype(t)
            else:
                return np.asarray(self.value, 'O')

        def __array_wrap__(self, array, context):
            return TaggedValue(array, self.unit)

        def __repr__(self):
            return 'TaggedValue(' + repr(self.value) + ', ' + repr(self.unit) + ')'

        def __str__(self):
            return str(self.value) + ' in ' + str(self.unit)

        def __len__(self):
            return len(self.value)

        def __iter__(self):
            """Iterate over the wrapped value, tagging each item with the unit.

            A generator is a complete iterator (it has both ``__iter__`` and
            ``__next__``) on Python 2 and 3, so the result can itself be
            passed to ``iter()``, ``itertools`` and friends.
            """
            unit = self.unit
            return (TaggedValue(item, unit) for item in self.value)

        def get_compressed_copy(self, mask):
            """Return a TaggedValue holding only the unmasked elements."""
            new_value = np.ma.masked_array(self.value, mask=mask).compressed()
            return TaggedValue(new_value, self.unit)

        def convert_to(self, unit):
            """Return this value expressed in *unit*.

            Returns self unchanged when *unit* equals the current unit or is
            falsy; otherwise uses the current unit's registered conversion
            (a KeyError propagates if none is registered).
            """
            if unit == self.unit or not unit:
                return self
            new_value = self.unit.convert_value_to(self.value, unit)
            return TaggedValue(new_value, unit)

        def get_value(self):
            return self.value

        def get_unit(self):
            return self.unit


    TaggedValue = TaggedValueMeta('TaggedValue', (_TaggedValue, ), {})


    class BasicUnit(object):
        """A named unit with a table of conversions to other units.

        Calling or multiplying a unit by a value produces a TaggedValue.
        """

        def __init__(self, name, fullname=None):
            self.name = name
            if fullname is None:
                fullname = name
            self.fullname = fullname
            # target BasicUnit -> callable mapping a value into that unit
            self.conversions = dict()

        def __repr__(self):
            return 'BasicUnit(%s)' % self.name

        def __str__(self):
            return self.fullname

        def __call__(self, value):
            return TaggedValue(value, self)

        def __mul__(self, rhs):
            value = rhs
            unit = self
            if hasattr(rhs, 'get_unit'):
                value = rhs.get_value()
                unit = rhs.get_unit()
                unit = unit_resolver('__mul__', (self, unit))
            if unit is NotImplemented:
                return NotImplemented
            return TaggedValue(value, unit)

        def __rmul__(self, lhs):
            return self*lhs

        def __array_wrap__(self, array, context):
            return TaggedValue(array, self)

        def __array__(self, t=None, context=None):
            ret = np.array([1])
            if t is not None:
                return ret.astype(t)
            else:
                return ret

        def add_conversion_factor(self, unit, factor):
            """Register conversion to *unit* as multiplication by *factor*."""
            def convert(x):
                return x*factor
            self.conversions[unit] = convert

        def add_conversion_fn(self, unit, fn):
            """Register an arbitrary callable converting values to *unit*."""
            self.conversions[unit] = fn

        def get_conversion_fn(self, unit):
            return self.conversions[unit]

        def convert_value_to(self, value, unit):
            """Convert *value* to *unit*; KeyError if no conversion exists."""
            conversion_fn = self.conversions[unit]
            ret = conversion_fn(value)
            return ret

        def get_unit(self):
            return self


    class UnitResolver(object):
        """Decide the unit of an operation's result from its operands' units."""

        def addition_rule(self, units):
            # Addition/subtraction require every operand to share one unit.
            for unit_1, unit_2 in zip(units[:-1], units[1:]):
                if unit_1 != unit_2:
                    return NotImplemented
            return units[0]

        def multiplication_rule(self, units):
            # At most one operand may carry a unit (no compound units here).
            non_null = [u for u in units if u]
            if len(non_null) > 1:
                return NotImplemented
            if not non_null:
                # Every operand is unitless, so the product is too.
                return None
            return non_null[0]

        op_dict = {
            '__mul__': multiplication_rule,
            '__rmul__': multiplication_rule,
            '__add__': addition_rule,
            '__radd__': addition_rule,
            '__sub__': addition_rule,
            '__rsub__': addition_rule}

        def __call__(self, operation, units):
            if operation not in self.op_dict:
                return NotImplemented

            return self.op_dict[operation](self, units)


    unit_resolver = UnitResolver()

    cm = BasicUnit('cm', 'centimeters')
    inch = BasicUnit('inch', 'inches')
    inch.add_conversion_factor(cm, 2.54)
    cm.add_conversion_factor(inch, 1/2.54)

    radians = BasicUnit('rad', 'radians')
    degrees = BasicUnit('deg', 'degrees')
    radians.add_conversion_factor(degrees, 180.0/np.pi)
    degrees.add_conversion_factor(radians, np.pi/180.0)

    secs = BasicUnit('s', 'seconds')
    hertz = BasicUnit('Hz', 'Hertz')
    minutes = BasicUnit('min', 'minutes')

    secs.add_conversion_fn(hertz, lambda x: 1./x)
    secs.add_conversion_factor(minutes, 1/60.0)


    # radians formatting
    def rad_fn(x, pos=None):
        """Tick formatter labelling *x* (radians) as a multiple of pi/2.

        The +0.25 bias absorbs floating-point error just below a multiple of
        pi/2.  ``floor`` (rather than ``int``, which truncates toward zero)
        keeps that rounding correct for negative values, which the
        ``MultipleLocator`` in ``BasicUnitConverter.axisinfo`` produces for
        axes extending below zero.
        """
        n = int(math.floor((x / np.pi) * 2.0 + 0.25))
        if n == 0:
            return '0'
        sign = '-' if n < 0 else ''
        m = abs(n)
        if m == 1:
            return r'$%s\pi/2$' % sign
        elif m == 2:
            return r'$%s\pi$' % sign
        elif m % 2 == 0:
            return r'$%s%d\pi$' % (sign, m // 2)
        else:
            return r'$%s%d\pi/2$' % (sign, m)


    class BasicUnitConverter(units.ConversionInterface):
        """matplotlib converter for BasicUnit / TaggedValue data.

        Strips units to plain numbers for plotting and supplies axis labels,
        plus pi/2 ticks for radians and degree-sign labels for degrees.
        """

        @staticmethod
        def axisinfo(unit, axis):
            'return AxisInfo instance for x and unit'

            if unit == radians:
                return units.AxisInfo(
                    majloc=ticker.MultipleLocator(base=np.pi/2),
                    majfmt=ticker.FuncFormatter(rad_fn),
                    label=unit.fullname,
                )
            elif unit == degrees:
                return units.AxisInfo(
                    majloc=ticker.AutoLocator(),
                    majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                    label=unit.fullname,
                )
            elif unit is not None:
                if hasattr(unit, 'fullname'):
                    return units.AxisInfo(label=unit.fullname)
                elif hasattr(unit, 'unit'):
                    return units.AxisInfo(label=unit.unit.fullname)
            return None

        @staticmethod
        def convert(val, unit, axis):
            # Plain numbers are passed through untouched.
            if units.ConversionInterface.is_numlike(val):
                return val
            if iterable(val):
                return [thisval.convert_to(unit).get_value() for thisval in val]
            else:
                return val.convert_to(unit).get_value()

        @staticmethod
        def default_units(x, axis):
            'return the default unit for x or None'
            # Use the first element's unit; an empty iterable (or an object
            # without a ``unit`` attribute) falls back to x's own unit, or
            # None when it has none.
            if iterable(x):
                for thisx in x:
                    return thisx.unit
            return getattr(x, 'unit', None)


    def cos(x):
        """Cosine of an angle TaggedValue (or an iterable of them)."""
        if iterable(x):
            return [math.cos(val.convert_to(radians).get_value()) for val in x]
        else:
            return math.cos(x.convert_to(radians).get_value())


    basicConverter = BasicUnitConverter()
    units.registry[BasicUnit] = basicConverter
    units.registry[TaggedValue] = basicConverter

Keywords: python, matplotlib, pylab, example, codex (see :ref:`how-to-search-examples`)