forked from ukosuagwu/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfixes.py
More file actions
225 lines (163 loc) · 7.04 KB
/
Copy pathfixes.py
File metadata and controls
225 lines (163 loc) · 7.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""Compatibility fixes for older version of python, numpy and scipy
If you add content to this file, please give the version of the package
at which the fix is no longer needed.
"""
# Authors: Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org>
# Gael Varoquaux <gael.varoquaux@normalesup.org>
# Fabian Pedregosa <fpedregosa@acm.org>
# Lars Buitinck
#
# License: BSD 3 clause
from functools import update_wrapper
from importlib import resources
import functools
import sys
import sklearn
import numpy as np
import scipy
import scipy.stats
import threadpoolctl
from .._config import config_context, get_config
from ..externals._packaging.version import parse as parse_version
# Parsed NumPy and SciPy versions; the compatibility switches below compare
# against these to pick the right code path.
np_version, sp_version = (
    parse_version(np.__version__),
    parse_version(scipy.__version__),
)
# Pick a ``lobpcg`` implementation: SciPy's own from 1.4 onwards, otherwise the
# copy shipped in this package's ``externals`` subpackage.
# NOTE: keep this branch order -- the ``# type: ignore`` silences mypy's
# redefinition error, which is reported on the *second* import.
if sp_version >= parse_version("1.4"):
    from scipy.sparse.linalg import lobpcg
else:
    # Backport of lobpcg functionality from scipy 1.4.0, can be removed
    # once support for sp_version < parse_version('1.4') is dropped
    # mypy error: Name 'lobpcg' already defined (possibly by an import)
    from ..externals._lobpcg import lobpcg  # type: ignore # noqa
# The Wolfe line-search helpers live in the private module
# ``scipy.optimize._linesearch`` on recent SciPy; older releases expose them
# as ``scipy.optimize.linesearch``.
try:
    from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2
except ImportError:  # SciPy < 1.8
    from scipy.optimize.linesearch import (  # type: ignore # noqa
        line_search_wolfe1,
        line_search_wolfe2,
    )
def _object_dtype_isnan(X):
    """Return an elementwise mask flagging NaN entries of ``X``.

    NaN is the only float value that compares unequal to itself, so the
    self-inequality test works on object-dtype arrays, which ``np.isnan``
    does not support. ``None`` compares equal to itself and is therefore
    not flagged.
    """
    nan_mask = X != X
    return nan_mask
class loguniform(scipy.stats.reciprocal):
    """A class supporting log-uniform random variables.

    Parameters
    ----------
    low : float
        The minimum value
    high : float
        The maximum value

    Methods
    -------
    rvs(self, size=None, random_state=None)
        Generate log-uniform random variables

    The most useful method for Scikit-learn usage is highlighted here.
    For a full list, see
    `scipy.stats.reciprocal
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.reciprocal.html>`_.
    This list includes all functions of ``scipy.stats`` continuous
    distributions such as ``pdf``.

    Notes
    -----
    This class generates values between ``low`` and ``high``, i.e.

        low <= loguniform(low, high).rvs() <= high

    The logarithm of the generated values is uniformly distributed: if ``x``
    is drawn uniformly from ``[log10(low), log10(high)]``, then ``10**x``
    follows ``loguniform(low, high)``.

    This class is an alias to ``scipy.stats.reciprocal``, which uses the
    reciprocal distribution:
    https://en.wikipedia.org/wiki/Reciprocal_distribution

    Examples
    --------
    >>> from sklearn.utils.fixes import loguniform
    >>> rv = loguniform(1e-3, 1e1)
    >>> rvs = rv.rvs(random_state=42, size=1000)
    >>> rvs.min() # doctest: +SKIP
    0.0010435856341129003
    >>> rvs.max() # doctest: +SKIP
    9.97403052786026
    """
# TODO: remove when the minimum scipy version is >= 1.5
if sp_version >= parse_version("1.5"):
    from scipy.linalg import eigh as _eigh  # noqa
else:

    def _eigh(*args, **kwargs):
        """Call ``scipy.linalg.eigh`` on SciPy < 1.5.

        Accepts the ``subset_by_index`` keyword understood by newer SciPy and
        forwards it under its older name, ``eigvals``. When it is not given,
        ``eigvals=None`` is passed.
        """
        index_range = kwargs.pop("subset_by_index", None)
        return scipy.linalg.eigh(*args, eigvals=index_range, **kwargs)
# remove when https://github.com/joblib/joblib/issues/1071 is fixed
def delayed(function):
    """Decorator used to capture the arguments of a function.

    Calling the decorated function does not run it. Instead, it returns a
    ``(callable, args, kwargs)`` triple, where the callable is ``function``
    wrapped in ``_FuncWrapper``. The wrapper is built at that moment, so it
    snapshots the configuration active when the task is created.
    """

    def _capture_call(*args, **kwargs):
        return _FuncWrapper(function), args, kwargs

    return functools.wraps(function)(_capture_call)
class _FuncWrapper:
    """Call ``function`` under the configuration captured at construction.

    ``get_config()`` is evaluated when the wrapper is created. The resulting
    settings are re-applied with ``config_context`` around every call, so
    the wrapped function sees the configuration that was active when the
    wrapper was built.
    """

    def __init__(self, function):
        self.function = function
        # Snapshot the configuration now, not when the call eventually runs.
        self.config = get_config()
        # Copy metadata (``__name__``, ``__doc__``, ...) from the wrapped
        # callable; done last so it matches the original attribute order.
        update_wrapper(self, function)

    def __call__(self, *args, **kwargs):
        captured = self.config
        with config_context(**captured):
            result = self.function(*args, **kwargs)
        return result
def _percentile(a, q, *, method="linear", **kwargs):
    """Call ``np.percentile`` using the pre-1.22 keyword spelling.

    NumPy 1.22 deprecated the ``interpolation`` keyword in favour of
    ``method``. This shim accepts the new ``method`` keyword and forwards it
    as ``interpolation`` for NumPy < 1.22. Passing ``interpolation`` directly
    as well raises ``TypeError`` for a duplicated keyword.
    """
    legacy_kwargs = {"interpolation": method}
    return np.percentile(a, q, **legacy_kwargs, **kwargs)
# Expose ``percentile`` with the NumPy >= 1.22 keyword (``method=``) on every
# supported NumPy version. Keep this branch order: the ``# type: ignore`` is
# needed on the second definition of the name.
if np_version < parse_version("1.22"):
    percentile = _percentile
else:  # >= 1.22
    from numpy import percentile  # type: ignore # noqa
def _get_threadpool_controller():
    """Return a shared ``threadpoolctl.ThreadpoolController``, if available.

    Compatibility fix for threadpoolctl >= 3.0.0. Since version 3, a single
    controller can be set up once and reused, which avoids looping through
    all loaded shared libraries on every call. The controller is created on
    the first call and cached as an attribute of the ``sklearn`` module.

    Returns None when the installed threadpoolctl has no
    ``ThreadpoolController``.
    """
    if hasattr(threadpoolctl, "ThreadpoolController"):
        if not hasattr(sklearn, "_sklearn_threadpool_controller"):
            sklearn._sklearn_threadpool_controller = (
                threadpoolctl.ThreadpoolController()
            )
        return sklearn._sklearn_threadpool_controller
    return None
def threadpool_limits(limits=None, user_api=None):
    # The public docstring is copied from threadpoolctl right after this def.
    # Use the shared controller when threadpoolctl provides one; otherwise
    # fall back to threadpoolctl's module-level function.
    controller = _get_threadpool_controller()
    if controller is None:
        return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)
    return controller.limit(limits=limits, user_api=user_api)


threadpool_limits.__doc__ = threadpoolctl.threadpool_limits.__doc__
def threadpool_info():
    # The public docstring is copied from threadpoolctl right after this def.
    # Query the shared controller when available; otherwise use
    # threadpoolctl's module-level function.
    controller = _get_threadpool_controller()
    if controller is None:
        return threadpoolctl.threadpool_info()
    return controller.info()


threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__
# TODO: Remove when SciPy 1.9 is the minimum supported version
def _mode(a, axis=0):
    """Compute ``scipy.stats.mode`` along ``axis``, keeping the reduced axis.

    SciPy 1.9 added the ``keepdims`` argument. Passing ``keepdims=True`` there
    keeps the reduced axis with length one, which is what older SciPy
    releases returned by default.
    """
    mode_kwargs = {"axis": axis}
    if sp_version >= parse_version("1.9.0"):
        mode_kwargs["keepdims"] = True
    return scipy.stats.mode(a, **mode_kwargs)
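###############################################################################
# Compatibility wrappers for ``importlib.resources``: use the ``files()`` API
# available since Python 3.9, falling back to the legacy functions otherwise.
# TODO: Remove when Python 3.9 is the minimum supported version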
def _open_text(data_module, data_file_name):
    """Open a text resource bundled inside ``data_module``.

    Parameters
    ----------
    data_module : str or module
        Package holding the resource, as accepted by ``importlib.resources``.
    data_file_name : str
        Name of the resource file inside that package.

    Returns
    -------
    file object
        A text-mode file object decoded as UTF-8. The caller is responsible
        for closing it (e.g. with a ``with`` statement).
    """
    if sys.version_info >= (3, 9):
        # ``Traversable.open("r")`` would otherwise decode with the locale
        # encoding, while the legacy ``resources.open_text`` defaults to
        # UTF-8. Pin UTF-8 so both branches (and all platforms) agree.
        return (
            resources.files(data_module)
            .joinpath(data_file_name)
            .open("r", encoding="utf-8")
        )
    else:
        return resources.open_text(data_module, data_file_name)
def _open_binary(data_module, data_file_name):
    """Open a resource bundled inside ``data_module`` for reading as bytes.

    Uses the ``importlib.resources.files`` API on Python >= 3.9 and the
    legacy ``open_binary`` function on older interpreters. The caller must
    close the returned file object.
    """
    if sys.version_info < (3, 9):
        return resources.open_binary(data_module, data_file_name)
    resource = resources.files(data_module).joinpath(data_file_name)
    return resource.open("rb")
def _read_text(descr_module, descr_file_name):
    """Return the full contents of a text resource, decoded as UTF-8.

    Parameters
    ----------
    descr_module : str or module
        Package holding the resource, as accepted by ``importlib.resources``.
    descr_file_name : str
        Name of the resource file inside that package.

    Returns
    -------
    str
        The decoded file contents.
    """
    if sys.version_info >= (3, 9):
        # ``Traversable.read_text()`` would otherwise use the locale encoding,
        # unlike the legacy ``resources.read_text`` (UTF-8 by default). Pin
        # UTF-8 so non-ASCII resources decode the same on every platform.
        return (
            resources.files(descr_module)
            .joinpath(descr_file_name)
            .read_text(encoding="utf-8")
        )
    else:
        return resources.read_text(descr_module, descr_file_name)
def _path(data_module, data_file_name):
    """Return a context manager yielding a filesystem path to a resource.

    On Python >= 3.9 the resource is located with ``resources.files`` and
    materialised with ``resources.as_file``. Older interpreters use the
    legacy ``resources.path`` function.
    """
    if sys.version_info < (3, 9):
        return resources.path(data_module, data_file_name)
    traversable = resources.files(data_module).joinpath(data_file_name)
    return resources.as_file(traversable)
def _is_resource(data_module, data_file_name):
    """Tell whether ``data_file_name`` is a file inside ``data_module``.

    Uses ``Traversable.is_file`` on Python >= 3.9 and the legacy
    ``resources.is_resource`` function on older interpreters.
    """
    if sys.version_info < (3, 9):
        return resources.is_resource(data_module, data_file_name)
    candidate = resources.files(data_module).joinpath(data_file_name)
    return candidate.is_file()