Skip to content

Commit 5b8f720

Browse files
committed
feat (oauth): add custom oauth manager
1 parent 3f6834c commit 5b8f720

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
tls_client_cert_file: str = None,
3434
oauth_persistence=None,
3535
credentials_provider=None,
36+
custom_oauth_manager=None,
3637
):
3738
self.hostname = hostname
3839
self.username = username
@@ -46,11 +47,13 @@ def __init__(
4647
self.tls_client_cert_file = tls_client_cert_file
4748
self.oauth_persistence = oauth_persistence
4849
self.credentials_provider = credentials_provider
50+
self.custom_oauth_manager=custom_oauth_manager
4951

5052

5153
def get_auth_provider(cfg: ClientContext):
5254
if cfg.credentials_provider:
5355
return ExternalAuthProvider(cfg.credentials_provider)
56+
5457
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
5558
assert cfg.oauth_redirect_port_range is not None
5659
assert cfg.oauth_client_id is not None
@@ -62,6 +65,7 @@ def get_auth_provider(cfg: ClientContext):
6265
cfg.oauth_redirect_port_range,
6366
cfg.oauth_client_id,
6467
cfg.oauth_scopes,
68+
cfg.custom_oauth_manager,
6569
)
6670
elif cfg.access_token is not None:
6771
return AccessTokenAuthProvider(cfg.access_token)
@@ -112,5 +116,8 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
112116
else redirect_port_range,
113117
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
114118
credentials_provider=kwargs.get("credentials_provider"),
119+
# customization start
120+
custom_oauth_manager=kwargs.get("custom_oauth_manager"),
121+
# customization end
115122
)
116123
return get_auth_provider(cfg)

src/databricks/sql/auth/authenticators.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
redirect_port_range: List[int],
7070
client_id: str,
7171
scopes: List[str],
72+
custom_oauth_manager,
7273
):
7374
try:
7475
cloud_type = infer_cloud_from_host(hostname)
@@ -84,7 +85,10 @@ def __init__(
8485
# Convert to the corresponding scopes in the corresponding IdP
8586
cloud_scopes = idp_endpoint.get_scopes_mapping(scopes)
8687

87-
self.oauth_manager = OAuthManager(
88+
if not custom_oauth_manager:
89+
custom_oauth_manager = OAuthManager
90+
91+
self.oauth_manager = custom_oauth_manager(
8892
port_range=redirect_port_range,
8993
client_id=client_id,
9094
idp_endpoint=idp_endpoint,

0 commit comments

Comments
 (0)