-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy patharrow.py
More file actions
91 lines (70 loc) · 2.81 KB
/
arrow.py
File metadata and controls
91 lines (70 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
A module with utility functions and classes to support authorizing the Arrow Flight servers.
"""
import asyncio
import functools
import logging
from typing import cast
import pyarrow.flight as fl
from pyarrow.flight import ServerCallContext
from feast.permissions.auth.auth_manager import (
get_auth_manager,
)
from feast.permissions.security_manager import get_security_manager
from feast.permissions.user import User
# Module-level logger used by the middleware classes and helper functions below.
logger = logging.getLogger(__name__)
class AuthorizationMiddlewareFactory(fl.ServerMiddlewareFactory):
    """
    Arrow Flight server middleware factory that reads the access token from the incoming
    request headers and hands it to a new `AuthorizationMiddleware` for each call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def start_call(self, info, headers):
        """
        Create the per-call authorization middleware.

        Args:
            info: Call information supplied by Arrow Flight (not used here).
            headers: The request headers, given to the configured token extractor.

        Returns:
            An `AuthorizationMiddleware` holding the extracted access token.
        """
        extractor = get_auth_manager().token_extractor
        token = extractor.extract_access_token(headers=headers)
        return AuthorizationMiddleware(access_token=token)
class AuthorizationMiddleware(fl.ServerMiddleware):
    """
    A server middleware holding the access token extracted from the request headers and
    offering a method to turn it into user credentials.
    """

    def __init__(self, access_token: str, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Token as returned by the configured token extractor; parsed on demand by `extract_user`.
        self.access_token = access_token

    def call_completed(self, exception):
        """
        Log the exception, if any, that the RPC finished with.

        Args:
            exception: The exception the call ended with, or a falsy value when it succeeded.
        """
        if exception:
            # `logger.exception()` takes its traceback from `sys.exc_info()`, which is only
            # populated inside an `except` block; this hook is not one, so pass the received
            # exception explicitly to get its traceback into the log record.
            logger.error(
                "%s encountered an exception: %s",
                AuthorizationMiddleware.__name__,
                exception,
                exc_info=exception,
            )

    async def extract_user(self) -> User:
        """
        Use the configured `TokenParser` to extract the user credentials.

        Returns:
            The `User` produced by the auth manager's token parser from `self.access_token`.
        """
        return await get_auth_manager().token_parser.user_details_from_access_token(
            self.access_token
        )
def inject_user_details(context: ServerCallContext):
    """
    Function to use in Arrow Flight endpoints (e.g. `do_get`, `do_put` and so on) to access the token extracted from the header,
    extract the user details out of it and propagate them to the current security manager, if any.

    Logs a warning and returns without doing anything when no middleware is registered
    under the "auth" key.

    Args:
        context: The endpoint context.
    """
    # Look the middleware up once and reuse it below.
    auth_middleware = context.get_middleware("auth")
    if auth_middleware is None:
        logger.warning("No `auth` middleware.")
        return

    sm = get_security_manager()
    if sm is not None:
        # NOTE(review): assumes the middleware registered as "auth" is an
        # AuthorizationMiddleware (i.e. built by AuthorizationMiddlewareFactory).
        auth_middleware = cast(AuthorizationMiddleware, auth_middleware)
        # `extract_user` is a coroutine and this function is synchronous, so it is driven to
        # completion here. `asyncio.run` raises RuntimeError if an event loop is already
        # running in the calling thread.
        current_user = asyncio.run(auth_middleware.extract_user())
        logger.debug("User extracted: %s", current_user)
        sm.set_current_user(current_user)
def inject_user_details_decorator(func):
    """
    Decorate an Arrow Flight endpoint method so user details are injected before it runs.

    The wrapped method receives the call context as its first argument after `self`. That
    context is passed to `inject_user_details` before `func` is invoked with all original
    arguments.

    Args:
        func: The endpoint method to wrap.

    Returns:
        The wrapping method, carrying `func`'s metadata via `functools.wraps`.
    """

    def _endpoint_with_user(self, context, *args, **kwargs):
        inject_user_details(context)
        return func(self, context, *args, **kwargs)

    return functools.wraps(func)(_endpoint_with_user)