forked from matplotlib/matplotlib
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcategory.py
More file actions
121 lines (97 loc) · 3.46 KB
/
category.py
File metadata and controls
121 lines (97 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
catch all for categorical functions
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import six
import numpy as np
import matplotlib.units as units
import matplotlib.ticker as ticker
# np 1.6/1.7 support
from distutils.version import LooseVersion
import collections
# NumPy 1.6/1.7 get a hand-rolled path that stringifies every element first.
# Only (major, minor) matters for that decision, so it is parsed directly from
# the version string; this avoids distutils.version, which Python 3.12 removed
# (PEP 632).
_NUMPY_MAJOR_MINOR = tuple(int(part) for part in np.__version__.split('.')[:2])

# ``np.unicode`` was only an alias of the builtin text type (deprecated in
# NumPy 1.20 and absent from current releases). ``type(u'')`` is that same
# type on both Python 2 (``unicode``) and Python 3 (``str``).
_TEXT_TYPE = type(u'')

if _NUMPY_MAJOR_MINOR >= (1, 8):
    def shim_array(data):
        """Return *data* as a NumPy array of unicode strings.

        Parameters
        ----------
        data : object
            A scalar or (nested) sequence; it is handed to ``np.array`` with
            a unicode dtype, so non-string entries are stringified by NumPy.

        Returns
        -------
        numpy.ndarray
            Array with a unicode dtype (0-d when *data* is a scalar).
        """
        return np.array(data, dtype=_TEXT_TYPE)
else:
    # ``collections.Iterable`` was removed in Python 3.10; prefer the
    # ``collections.abc`` location and fall back for Python 2.
    try:
        from collections.abc import Iterable as _Iterable
    except ImportError:
        from collections import Iterable as _Iterable

    def shim_array(data):
        """Return *data* as a NumPy array of unicode strings.

        Variant for NumPy older than 1.8: a lone string or non-iterable is
        wrapped in a list, and every element is converted with ``str``
        before the array is built.

        Parameters
        ----------
        data : object
            A scalar or an iterable of values.

        Returns
        -------
        numpy.ndarray
            One-dimensional array with a unicode dtype.
        """
        if (isinstance(data, six.string_types) or
                not isinstance(data, _Iterable)):
            data = [data]
        try:
            data = [str(d) for d in data]
        except UnicodeEncodeError:
            # this yields gibberish but unicode text doesn't
            # render under numpy1.6 anyway
            data = [d.encode('utf-8', 'ignore').decode('utf-8')
                    for d in data]
        return np.array(data, dtype=_TEXT_TYPE)
class StrCategoryConverter(units.ConversionInterface):
    """Unit converter that encodes string categories as floats.

    The label-to-location mapping is read from ``axis.unit_data``, whose
    ``seq`` (labels) and ``locs`` (locations) are treated as parallel
    sequences.
    """

    @staticmethod
    def convert(value, unit, axis):
        """Encode *value* as floats using the ``axis.unit_data`` mapping.

        Parameters
        ----------
        value : string or iterable
            A single category label, or a collection of labels.
        unit : object
            Not used by this converter.
        axis : object
            Must expose ``unit_data`` with ``seq`` and ``locs`` attributes.

        Returns
        -------
        The mapped location when *value* is a single string; otherwise a
        float array with the same shape as the stringified input.

        Raises
        ------
        KeyError
            If a single string *value* is not a known label.
        ValueError
            If an array entry is neither a known label nor parseable as a
            float.
        """
        vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))

        if isinstance(value, six.string_types):
            return vmap[value]

        labels = shim_array(value)
        # Work on a 1-d view so scalar (0-d) input is handled uniformly.
        flat = labels.ravel()

        # Locations are written into a separate float array. Writing them
        # back into the unicode array would stringify each one, truncate it
        # to the label width (e.g. 10 -> '1' for one-character labels), and
        # let a written location be mistaken for a label handled later.
        encoded = np.empty(flat.shape, dtype=float)
        unmapped = np.ones(flat.shape, dtype=bool)
        for label, loc in vmap.items():
            hits = flat == label
            encoded[hits] = loc
            unmapped &= ~hits

        if unmapped.any():
            # Entries without a mapping keep the previous fall-through:
            # numeric-looking strings are parsed, anything else raises
            # ValueError.
            encoded[unmapped] = flat[unmapped].astype('float')

        return encoded.reshape(labels.shape)

    @staticmethod
    def axisinfo(unit, axis):
        """Return the ``AxisInfo`` built from ``axis.unit_data``.

        The major locator is created from ``unit_data.locs`` and the major
        formatter from ``unit_data.seq``. *unit* is not used.
        """
        majloc = StrCategoryLocator(axis.unit_data.locs)
        majfmt = StrCategoryFormatter(axis.unit_data.seq)
        return units.AxisInfo(majloc=majloc, majfmt=majfmt)

    @staticmethod
    def default_units(data, axis):
        """Create or extend ``axis.unit_data`` from *data*; return None.

        A new ``UnitData`` is attached when the axis has none yet;
        otherwise the existing one is updated with *data*.
        """
        # the conversion call stack is:
        # default_units->axis_info->convert
        if axis.unit_data is None:
            axis.unit_data = UnitData(data)
        else:
            axis.unit_data.update(data)
        return None
class StrCategoryLocator(ticker.FixedLocator):
    """``FixedLocator`` subclass that keeps the given locations as-is."""

    def __init__(self, locs):
        """Store *locs* on the instance and set ``nbins`` to None.

        The base-class initializer is not called; only these two
        attributes are assigned, and *locs* is kept by reference.
        """
        self.locs, self.nbins = locs, None
class StrCategoryFormatter(ticker.FixedFormatter):
    """``FixedFormatter`` subclass that keeps the given label sequence."""

    def __init__(self, seq):
        """Store *seq* on the instance and use an empty offset string.

        The base-class initializer is not called; only ``seq`` and
        ``offset_string`` are assigned, and *seq* is kept by reference.
        """
        self.seq, self.offset_string = seq, ''
class UnitData(object):
    """Mapping between unique categorical values and numeric locations.

    ``seq`` holds the labels and ``locs`` the location of each label, in
    parallel. Both lists are only ever appended to in place.
    """

    # debatable makes sense to special code missing values
    spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}

    def __init__(self, data):
        """Create mapping between unique categorical values
        and numerical identifier

        Parameters
        ----------
        data: iterable
            sequence of values
        """
        self.seq, self.locs = [], []
        self._set_seq_locs(data, 0)

    def update(self, new_data):
        """Add labels from *new_data* that are not mapped yet.

        Parameters
        ----------
        new_data: iterable
            sequence of values
        """
        if self.locs:
            # so as not to conflict with spdict
            value = max(max(self.locs) + 1, 0)
        else:
            # Nothing mapped yet (e.g. built from empty data): max() of an
            # empty list would raise, so start at 0 as __init__ does.
            value = 0
        self._set_seq_locs(new_data, value)

    def _set_seq_locs(self, data, value):
        """Append unseen labels from *data*, numbering them from *value*.

        Labels found in ``spdict`` receive their fixed special location
        and do not consume a number.
        """
        strdata = shim_array(data)
        new_s = [d for d in np.unique(strdata) if d not in self.seq]
        for ns in new_s:
            self.seq.append(ns)
            if ns in UnitData.spdict:
                self.locs.append(UnitData.spdict[ns])
            else:
                self.locs.append(value)
                value += 1
# Connect the converter to matplotlib: every string-like type is given its
# own StrCategoryConverter instance in the units registry.
for _str_type in (str, np.str_, six.text_type, bytes, np.bytes_):
    units.registry[_str_type] = StrCategoryConverter()
del _str_type