Skip to content

Commit 646cbce

Browse files
committed
[PECO-1411] Support OAuth InHouse on GCP
Signed-off-by: Jacky Hu <jacky.hu@databricks.com>
1 parent 3f6834c commit 646cbce

File tree

3 files changed

+28
-20
lines changed

3 files changed

+28
-20
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ class AuthType(Enum):
2020

2121
class ClientContext:
2222
def __init__(
23-
self,
24-
hostname: str,
25-
username: str = None,
26-
password: str = None,
27-
access_token: str = None,
28-
auth_type: str = None,
29-
oauth_scopes: List[str] = None,
30-
oauth_client_id: str = None,
31-
oauth_redirect_port_range: List[int] = None,
32-
use_cert_as_auth: str = None,
33-
tls_client_cert_file: str = None,
34-
oauth_persistence=None,
35-
credentials_provider=None,
23+
self,
24+
hostname: str,
25+
username: str = None,
26+
password: str = None,
27+
access_token: str = None,
28+
auth_type: str = None,
29+
oauth_scopes: List[str] = None,
30+
oauth_client_id: str = None,
31+
oauth_redirect_port_range: List[int] = None,
32+
use_cert_as_auth: str = None,
33+
tls_client_cert_file: str = None,
34+
oauth_persistence=None,
35+
credentials_provider=None,
3636
):
3737
self.hostname = hostname
3838
self.username = username
@@ -88,9 +88,10 @@ def normalize_host_name(hostname: str):
8888

8989

9090
def get_client_id_and_redirect_port(hostname: str):
91+
cloud_type = infer_cloud_from_host(hostname)
9192
return (
9293
(PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE)
93-
if infer_cloud_from_host(hostname) == CloudType.AWS
94+
if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP
9495
else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE)
9596
)
9697

src/databricks/sql/auth/endpoint.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class OAuthScope:
2121
class CloudType(Enum):
2222
AWS = "aws"
2323
AZURE = "azure"
24+
GCP = "gcp"
2425

2526

2627
DATABRICKS_AWS_DOMAINS = [
@@ -34,6 +35,9 @@ class CloudType(Enum):
3435
".databricks.azure.cn",
3536
".databricks.azure.us",
3637
]
38+
DATABRICKS_GCP_DOMAINS = [
39+
".gcp.databricks.com"
40+
]
3741

3842

3943
# Infer cloud type from Databricks SQL instance hostname
@@ -45,6 +49,8 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
4549
return CloudType.AZURE
4650
elif any(e for e in DATABRICKS_AWS_DOMAINS if host.endswith(e)):
4751
return CloudType.AWS
52+
elif any(e for e in DATABRICKS_GCP_DOMAINS if host.endswith(e)):
53+
return CloudType.GCP
4854
else:
4955
return None
5056

@@ -94,7 +100,7 @@ def get_openid_config_url(self, hostname: str):
94100
return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
95101

96102

97-
class AwsOAuthEndpointCollection(OAuthEndpointCollection):
103+
class InHouseOAuthEndpointCollection(OAuthEndpointCollection):
98104
def get_scopes_mapping(self, scopes: List[str]) -> List[str]:
99105
# No scope mapping in AWS
100106
return scopes.copy()
@@ -109,8 +115,8 @@ def get_openid_config_url(self, hostname: str):
109115

110116

111117
def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]:
112-
if cloud == CloudType.AWS:
113-
return AwsOAuthEndpointCollection()
118+
if cloud == CloudType.AWS or cloud == CloudType.GCP:
119+
return InHouseOAuthEndpointCollection()
114120
elif cloud == CloudType.AZURE:
115121
return AzureOAuthEndpointCollection()
116122
else:

tests/unit/test_auth.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
88
from databricks.sql.auth.oauth import OAuthManager
99
from databricks.sql.auth.authenticators import DatabricksOAuthProvider
10-
from databricks.sql.auth.endpoint import CloudType, AwsOAuthEndpointCollection, AzureOAuthEndpointCollection
10+
from databricks.sql.auth.endpoint import CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection
1111
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
1212
from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache
1313

@@ -55,9 +55,10 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh):
5555
mock_get_tokens.return_value = (access_token, refresh_token)
5656
mock_check_and_refresh.return_value = (access_token, refresh_token, False)
5757

58-
params = [(CloudType.AWS, "foo.cloud.databricks.com", AwsOAuthEndpointCollection, "offline_access sql"),
58+
params = [(CloudType.AWS, "foo.cloud.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql"),
5959
(CloudType.AZURE, "foo.1.azuredatabricks.net", AzureOAuthEndpointCollection,
60-
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access")]
60+
f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access"),
61+
(CloudType.GCP, "foo.gcp.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql")]
6162

6263
for cloud_type, host, expected_endpoint_type, expected_scopes in params:
6364
with self.subTest(cloud_type.value):

0 commit comments

Comments
 (0)