@@ -33,6 +33,7 @@ def __init__(
3333 tls_client_cert_file : str = None ,
3434 oauth_persistence = None ,
3535 credentials_provider = None ,
36+ custom_oauth_manager = None ,
3637 ):
3738 self .hostname = hostname
3839 self .username = username
@@ -46,11 +47,13 @@ def __init__(
4647 self .tls_client_cert_file = tls_client_cert_file
4748 self .oauth_persistence = oauth_persistence
4849 self .credentials_provider = credentials_provider
50+ self .custom_oauth_manager = custom_oauth_manager
4951
5052
5153def get_auth_provider (cfg : ClientContext ):
5254 if cfg .credentials_provider :
5355 return ExternalAuthProvider (cfg .credentials_provider )
56+
5457 if cfg .auth_type == AuthType .DATABRICKS_OAUTH .value :
5558 assert cfg .oauth_redirect_port_range is not None
5659 assert cfg .oauth_client_id is not None
@@ -62,6 +65,7 @@ def get_auth_provider(cfg: ClientContext):
6265 cfg .oauth_redirect_port_range ,
6366 cfg .oauth_client_id ,
6467 cfg .oauth_scopes ,
68+ cfg .custom_oauth_manager ,
6569 )
6670 elif cfg .access_token is not None :
6771 return AccessTokenAuthProvider (cfg .access_token )
@@ -112,5 +116,8 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
112116 else redirect_port_range ,
113117 oauth_persistence = kwargs .get ("experimental_oauth_persistence" ),
114118 credentials_provider = kwargs .get ("credentials_provider" ),
119+ # customization start
120+ custom_oauth_manager = kwargs .get ("custom_oauth_manager" ),
121+ # customization end
115122 )
116123 return get_auth_provider (cfg )
0 commit comments