Skip to content

Commit 663e0cd

Browse files
committed
Add monkey patching of the automl clients
1 parent b358d22 commit 663e0cd

3 files changed

Lines changed: 124 additions & 5 deletions

File tree

patches/kaggle_gcp.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def init_bigquery():
130130
return bigquery
131131

132132
# If this Kernel has bigquery integration on startup, preload the Kaggle Credentials
133-
# object for magics to work.
133+
# object for magics to work.
134134
if get_integrations().has_bigquery():
135135
from google.cloud.bigquery import magics
136136
magics.context.credentials = KaggleKernelCredentials()
@@ -139,7 +139,7 @@ def monkeypatch_bq(bq_client, *args, **kwargs):
139139
from kaggle_gcp import get_integrations, PublicBigqueryClient, KaggleKernelCredentials
140140
specified_credentials = kwargs.get('credentials')
141141
has_bigquery = get_integrations().has_bigquery()
142-
# Prioritize passed in project id, but if it is missing look for env var.
142+
# Prioritize passed in project id, but if it is missing look for env var.
143143
arg_project = kwargs.get('project')
144144
explicit_project_id = arg_project or os.environ.get(environment_vars.PROJECT)
145145
# This is a hack to get around the bug in google-cloud library.
@@ -200,9 +200,70 @@ def monkeypatch_gcs(self, *args, **kwargs):
200200
storage.Client.__init__ = monkeypatch_gcs
201201
return storage
202202

203+
def init_automl():
204+
is_user_secrets_token_set = "KAGGLE_USER_SECRETS_TOKEN" in os.environ
205+
from google.cloud import automl_v1beta1 as automl
206+
if not is_user_secrets_token_set:
207+
return automl
208+
209+
from kaggle_gcp import get_integrations
210+
if not get_integrations().has_automl():
211+
return automl
212+
213+
from kaggle_secrets import GcpTarget
214+
from kaggle_gcp import KaggleKernelCredentials
215+
kaggle_kernel_credentials = KaggleKernelCredentials(target=GcpTarget.AUTOML)
216+
217+
# The AutoML client library exposes 4 different client classes (AutoMlClient,
218+
# TablesClient, PredictionServiceClient and GcsClient), so patch each of them.
219+
# The same KaggleKernelCredentials are passed to all of them.
220+
221+
automl_client_init = automl.AutoMlClient.__init__
222+
def monkeypatch_automl(self, *args, **kwargs):
223+
specified_credentials = kwargs.get('credentials')
224+
if specified_credentials is None:
225+
Log.info("No credentials specified, using KaggleKernelCredentials.")
226+
kwargs['credentials'] = kaggle_kernel_credentials
227+
# Note: This is only here so that unit tests can check whether
228+
# credentials were set properly.
229+
self._kaggle_credentials = kwargs['credentials']
230+
return automl_client_init(self, *args, **kwargs)
231+
232+
if (not has_been_monkeypatched(automl.AutoMlClient.__init__)):
233+
automl.AutoMlClient.__init__ = monkeypatch_automl
234+
235+
236+
automl_tablesclient_init = automl.TablesClient.__init__
237+
def monkeypatch_tablesclient(self, *args, **kwargs):
238+
specified_credentials = kwargs.get('credentials')
239+
if specified_credentials is None:
240+
Log.info("No credentials specified, using KaggleKernelCredentials.")
241+
kwargs['credentials'] = kaggle_kernel_credentials
242+
self._kaggle_credentials = kwargs['credentials']
243+
return automl_tablesclient_init(self, *args, **kwargs)
244+
245+
if (not has_been_monkeypatched(automl.TablesClient.__init__)):
246+
automl.TablesClient.__init__ = monkeypatch_tablesclient
247+
248+
249+
automl_predictionclient_init = automl.PredictionServiceClient.__init__
250+
def monkeypatch_predictionclient(self, *args, **kwargs):
251+
specified_credentials = kwargs.get('credentials')
252+
if specified_credentials is None:
253+
Log.info("No credentials specified, using KaggleKernelCredentials.")
254+
kwargs['credentials'] = kaggle_kernel_credentials
255+
self._kaggle_credentials = kwargs['credentials']
256+
return automl_predictionclient_init(self, *args, **kwargs)
257+
258+
if (not has_been_monkeypatched(automl.PredictionServiceClient.__init__)):
259+
automl.PredictionServiceClient.__init__ = monkeypatch_predictionclient
260+
261+
return automl
262+
203263
def init():
204264
init_bigquery()
205265
init_gcs()
266+
init_automl()
206267

207268
# We need to initialize the monkeypatching of the client libraries
208269
# here since there is a circular dependency between our import hook version

patches/sitecustomize.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import importlib.machinery
88

99
class GcpModuleFinder(importlib.abc.MetaPathFinder):
10-
_MODULES = ['google.cloud.bigquery', 'google.cloud.storage']
10+
_MODULES = ['google.cloud.bigquery', 'google.cloud.storage', 'google.cloud.automl_v1beta1']
1111
_KAGGLE_GCP_PATH = 'kaggle_gcp.py'
1212
def __init__(self):
1313
pass
@@ -39,7 +39,8 @@ def create_module(self, spec):
3939
import kaggle_gcp
4040
_LOADERS = {
4141
'google.cloud.bigquery': kaggle_gcp.init_bigquery,
42-
'google.cloud.storage': kaggle_gcp.init_gcs
42+
'google.cloud.storage': kaggle_gcp.init_gcs,
43+
'google.cloud.automl_v1beta1': kaggle_gcp.init_automl,
4344
}
4445
monkeypatch_gcp_module = _LOADERS[spec.name]()
4546
return monkeypatch_gcp_module

tests/test_automl.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,66 @@
11
import unittest
22

3-
from google.cloud import automl_v1beta1 as automl
3+
from unittest.mock import Mock
4+
5+
from kaggle_gcp import KaggleKernelCredentials, init_automl
6+
from test.support import EnvironmentVarGuard
7+
from google.cloud import storage, automl_v1beta1 as automl
8+
9+
def _make_credentials():
10+
import google.auth.credentials
11+
return Mock(spec=google.auth.credentials.Credentials)
412

513
class TestAutoMl(unittest.TestCase):
614

715
def test_version(self):
816
self.assertIsNotNone(automl.auto_ml_client._GAPIC_LIBRARY_VERSION)
17+
version_parts = automl.auto_ml_client._GAPIC_LIBRARY_VERSION.split('.')
18+
version = float('.'.join(version_parts[0:2]));
19+
self.assertGreaterEqual(version, 0.5);
20+
21+
def test_user_provided_credentials(self):
22+
credentials = _make_credentials()
23+
env = EnvironmentVarGuard()
24+
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
25+
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
26+
with env:
27+
init_automl()
28+
client = automl.AutoMlClient(credentials=credentials)
29+
self.assertNotIsInstance(client._kaggle_credentials, KaggleKernelCredentials)
30+
self.assertIsNotNone(client._kaggle_credentials)
31+
32+
33+
def test_tables_gcs_client(self):
34+
# The GcsClient can't currently be monkeypatched for default
35+
# credentials because it requires a project which can't be set.
36+
# Verify that creating an automl.GcsClient given an actual
37+
# storage.Client sets the client properly.
38+
gcs_client = storage.Client(project="xyz", credentials=_make_credentials())
39+
tables_gcs_client = automl.GcsClient(client=gcs_client)
40+
self.assertIs(tables_gcs_client.client, gcs_client)
41+
42+
def test_default_credentials_automl_enabled(self):
43+
env = EnvironmentVarGuard()
44+
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
45+
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'AUTOML')
46+
with env:
47+
init_automl()
48+
automl_client = automl.AutoMlClient()
49+
self.assertIsNotNone(automl_client._kaggle_credentials)
50+
self.assertIsInstance(automl_client._kaggle_credentials, KaggleKernelCredentials)
51+
tables_client = automl.TablesClient()
52+
self.assertIsNotNone(automl_client._kaggle_credentials)
53+
self.assertIsInstance(automl_client._kaggle_credentials, KaggleKernelCredentials)
54+
prediction_client = automl.PredictionServiceClient()
55+
self.assertIsNotNone(automl_client._kaggle_credentials)
56+
self.assertIsInstance(automl_client._kaggle_credentials, KaggleKernelCredentials)
957

58+
def test_monkeypatching_idempotent(self):
59+
env = EnvironmentVarGuard()
60+
env.set('KAGGLE_USER_SECRETS_TOKEN', 'foobar')
61+
env.set('KAGGLE_KERNEL_INTEGRATIONS', 'GCS')
62+
with env:
63+
client1 = automl.AutoMlClient.__init__
64+
init_automl()
65+
client2 = automl.AutoMlClient.__init__
66+
self.assertEqual(client1, client2)

0 commit comments

Comments
 (0)