forked from srlearn/srlearn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
147 lines (122 loc) · 4.74 KB
/
base.py
File metadata and controls
147 lines (122 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright © 2017, 2018, 2019 Alexander L. Hayes
"""
Base class for Boosted Relational Models
"""
import numbers
import subprocess

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.utils.validation import check_is_fitted

from .background import Background
from .system_manager import FileSystem
from ._meta import DEBUG


class BaseBoostedRelationalModel(BaseEstimator, ClassifierMixin):
    """Base class for deriving boosted relational models

    This class extends :class:`sklearn.base.BaseEstimator` and
    :class:`sklearn.base.ClassifierMixin` while providing several utilities
    for instantiating a model and performing learning/inference with the
    BoostSRL jar files.

    .. note:: This is not a complete treatment of *how to derive estimators*.
        Contributions would be appreciated.

    Parameters
    ----------
    background : :class:`srlearn.background.Background` or None (default: None)
        Background knowledge for the model. ``_check_params()`` requires this
        to be a ``Background`` instance.
    target : str (default: "None")
        Name of the target to learn. The string ``"None"`` is a placeholder
        that ``_check_params()`` rejects, so callers must supply a real value.
    n_estimators : int (default: 10)
        Must be a positive integer (``bool`` is rejected). Presumably the
        number of boosted trees; confirm against the consuming subclass.
    node_size : int (default: 2)
        Stored as-is; not validated by this base class.
    max_tree_depth : int (default: 3)
        Stored as-is; not validated by this base class.

    Attributes
    ----------
    debug
        Value of ``DEBUG`` from ``._meta``, copied at construction time.
    file_system : :class:`srlearn.system_manager.FileSystem`
        Created by ``_check_params()`` once all parameters are valid.

    Examples
    --------

    The actual :class:`srlearn.rdn.BoostedRDN` is derived from this class, so
    this example is similar to the implementation (but the actual
    implementation passes model parameters instead of leaving them with the
    defaults). This example derives a new class ``BoostedRDN``, which inherits
    the default values of the superclass while also setting a
    'special_parameter' which may be unique to this model.

    All that remains is to implement the specific cases of ``fit()``,
    ``predict()``, and ``predict_proba()``.

    >>> from srlearn.base import BaseBoostedRelationalModel
    >>> class BoostedRDN(BaseBoostedRelationalModel):
    ...     def __init__(self, special_parameter=5):
    ...         super().__init__()
    ...         self.special_parameter = special_parameter
    ...
    >>> dn = BoostedRDN(special_parameter=8)
    >>> print(dn)
    BoostedRDN(special_parameter=8)
    >>> print(dn.n_estimators)
    10
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        background=None,
        target="None",
        n_estimators=10,
        node_size=2,
        max_tree_depth=3,
    ):
        """Store the estimator's hyperparameters.

        No validation happens here; it is deferred to ``_check_params()``.
        """
        self.background = background
        self.target = target
        self.n_estimators = n_estimators
        self.node_size = node_size
        self.max_tree_depth = max_tree_depth
        self.debug = DEBUG

    def _check_params(self):
        """Validate parameters, then allocate ``self.file_system``.

        If every check passes, ``self.file_system`` is set to a new
        :class:`srlearn.system_manager.FileSystem`.

        Raises
        ------
        ValueError
            If ``target`` is unset (``"None"``) or not a string, if
            ``background`` is ``None`` or not a ``Background`` instance, or if
            ``n_estimators`` is not a positive integer.
        """
        if self.target == "None":
            raise ValueError("target must be set, cannot be {0}".format(self.target))
        if not isinstance(self.target, str):
            raise ValueError(
                "target must be a string, cannot be {0}".format(self.target)
            )
        if self.background is None:
            raise ValueError(
                "background must be set, cannot be {0}".format(self.background)
            )
        if not isinstance(self.background, Background):
            raise ValueError(
                "background should be a srlearn.Background object, cannot be {0}".format(
                    self.background
                )
            )
        # numbers.Integral also admits NumPy integer scalars (e.g. values taken
        # from a numpy.arange parameter grid). bool is a subclass of int, so it
        # must be excluded explicitly.
        if not isinstance(self.n_estimators, numbers.Integral) or isinstance(
            self.n_estimators, bool
        ):
            raise ValueError(
                "n_estimators must be an integer, cannot be {0}".format(
                    self.n_estimators
                )
            )
        if self.n_estimators <= 0:
            raise ValueError(
                "n_estimators must be greater than 0, cannot be {0}".format(
                    self.n_estimators
                )
            )
        # If all params are valid, allocate a FileSystem:
        self.file_system = FileSystem()

    def _check_initialized(self):
        """Raise if the estimator has not been fitted.

        Delegates to :func:`sklearn.utils.validation.check_is_fitted`, which
        looks for the ``estimators_`` attribute and raises
        ``sklearn.exceptions.NotFittedError`` when it is missing.
        """
        check_is_fitted(self, "estimators_")

    @staticmethod
    def _call_shell_command(shell_command):
        """Run a shell command in a new process and wait for it to finish.

        This is intended for use in calling jar files.

        .. warning:: The command string is executed through the shell
            (``shell=True``). Never build it from untrusted input.

        Parameters
        ----------
        shell_command : str
            A string representing a shell command.

        Returns
        -------
        None

        Raises
        ------
        RuntimeError
            If the command exits with a non-zero status.
        """
        completed = subprocess.run(shell_command, shell=True, check=False)
        if completed.returncode != 0:
            raise RuntimeError(
                "Error when running shell command: {0}".format(shell_command)
            )

    def fit(self, database):
        """Learn a model from ``database``; must be implemented by subclasses.

        Parameters
        ----------
        database
            Training data (presumably a :class:`srlearn.database.Database`;
            confirm in subclasses).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def predict(self, database):
        """Predict labels for ``database``; must be implemented by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def predict_proba(self, database):
        """Predict probabilities for ``database``; must be implemented by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError