Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class SqlRegistryConfig(RegistryConfig):
""" str: Path to metadata store.
If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """

read_path: Optional[StrictStr] = None
""" str: Read Path to metadata store if different from path.
If registry_type is 'sql', then this is a Read Endpoint for database URL. If not set, path will be used for read and write. """

sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False}
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

Expand All @@ -223,13 +227,20 @@ def __init__(
registry_config, SqlRegistryConfig
), "SqlRegistry needs a valid registry_config"

self.engine: Engine = create_engine(
self.write_engine: Engine = create_engine(
registry_config.path, **registry_config.sqlalchemy_config_kwargs
)
if registry_config.read_path:
self.read_engine: Engine = create_engine(
registry_config.read_path,
**registry_config.sqlalchemy_config_kwargs,
)
else:
self.read_engine = self.write_engine
metadata.create_all(self.write_engine)
self.thread_pool_executor_worker_count = (
registry_config.thread_pool_executor_worker_count
)
metadata.create_all(self.engine)
self.purge_feast_metadata = registry_config.purge_feast_metadata
# Sync feast_metadata to projects table
# when purge_feast_metadata is set to True, Delete data from
Expand All @@ -246,7 +257,7 @@ def __init__(
def _sync_feast_metadata_to_projects_table(self):
feast_metadata_projects: set = []
projects_set: set = []
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value
)
Expand All @@ -255,7 +266,7 @@ def _sync_feast_metadata_to_projects_table(self):
feast_metadata_projects.append(row._mapping["project_id"])

if len(feast_metadata_projects) > 0:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(projects)
rows = conn.execute(stmt).all()
for row in rows:
Expand All @@ -267,7 +278,7 @@ def _sync_feast_metadata_to_projects_table(self):
self.apply_project(Project(name=project_name), commit=True)

if self.purge_feast_metadata:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
for project_name in feast_metadata_projects:
stmt = delete(feast_metadata).where(
feast_metadata.c.project_id == project_name
Expand All @@ -285,7 +296,7 @@ def teardown(self):
validation_references,
permissions,
}:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(t)
conn.execute(stmt)

Expand Down Expand Up @@ -549,7 +560,7 @@ def apply_feature_service(
)

def delete_data_source(self, name: str, project: str, commit: bool = True):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(data_sources).where(
data_sources.c.data_source_name == name,
data_sources.c.project_id == project,
Expand Down Expand Up @@ -607,7 +618,7 @@ def _list_on_demand_feature_views(
)

def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.project_id == project,
)
Expand Down Expand Up @@ -726,7 +737,7 @@ def apply_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, "feature_view_name") == name,
table.c.project_id == project,
Expand Down Expand Up @@ -781,7 +792,7 @@ def get_user_metadata(
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(table).where(getattr(table.c, "feature_view_name") == name)
row = conn.execute(stmt).first()
if row:
Expand Down Expand Up @@ -885,7 +896,7 @@ def _apply_object(
name = name or (obj.name if hasattr(obj, "name") else None)
assert name, f"name needs to be provided for {obj}"

with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
update_datetime = _utc_now()
update_time = int(update_datetime.timestamp())
stmt = select(table).where(
Expand Down Expand Up @@ -961,7 +972,7 @@ def _apply_object(

def _maybe_init_project_metadata(self, project):
# Initialize project metadata if needed
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
update_datetime = _utc_now()
update_time = int(update_datetime.timestamp())
stmt = select(feast_metadata).where(
Expand All @@ -988,7 +999,7 @@ def _delete_object(
id_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
Expand All @@ -1014,7 +1025,7 @@ def _get_object(
proto_field_name: str,
not_found_exception: Optional[Callable],
):
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
Expand All @@ -1036,7 +1047,7 @@ def _list_objects(
proto_field_name: str,
tags: Optional[dict[str, str]] = None,
):
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(table).where(table.c.project_id == project)
rows = conn.execute(stmt).all()
if rows:
Expand All @@ -1051,7 +1062,7 @@ def _list_objects(
return []

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand Down Expand Up @@ -1085,7 +1096,7 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str):
conn.execute(insert_stmt)

def _get_last_updated_metadata(self, project: str):
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
Expand Down Expand Up @@ -1130,7 +1141,7 @@ def apply_permission(
)

def delete_permission(self, name: str, project: str, commit: bool = True):
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
stmt = delete(permissions).where(
permissions.c.permission_name == name,
permissions.c.project_id == project,
Expand All @@ -1143,7 +1154,7 @@ def _list_projects(
self,
tags: Optional[dict[str, str]],
) -> List[Project]:
with self.engine.begin() as conn:
with self.read_engine.begin() as conn:
stmt = select(projects)
rows = conn.execute(stmt).all()
if rows:
Expand Down Expand Up @@ -1188,7 +1199,7 @@ def delete_project(
):
project = self.get_project(name, allow_cache=False)
if project:
with self.engine.begin() as conn:
with self.write_engine.begin() as conn:
for t in {
managed_infra,
saved_datasets,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,43 @@ def minio_registry(minio_server):
yield Registry("project", registry_config, None)


POSTGRES_READONLY_USER = "read_only_user"
POSTGRES_READONLY_PASSWORD = "readonly_password"

logger = logging.getLogger(__name__)


def add_pg_read_only_user(
container_host, container_port, db_name, postgres_user, postgres_password
):
# Connect to PostgreSQL as an admin
import psycopg

conn_string = f"dbname={db_name} user={postgres_user} password={postgres_password} host={container_host} port={container_port}"

with psycopg.connect(conn_string) as conn:
user_exists = conn.execute(
f"SELECT 1 FROM pg_catalog.pg_user WHERE usename = '{POSTGRES_READONLY_USER}'"
).fetchone()
if not user_exists:
conn.execute(
f"CREATE USER {POSTGRES_READONLY_USER} WITH PASSWORD '{POSTGRES_READONLY_PASSWORD}';"
)

conn.execute(
f"REVOKE ALL PRIVILEGES ON DATABASE {db_name} FROM {POSTGRES_READONLY_USER};"
)
conn.execute(
f"GRANT CONNECT ON DATABASE {db_name} TO {POSTGRES_READONLY_USER};"
)
conn.execute(
f"GRANT SELECT ON ALL TABLES IN SCHEMA public TO {POSTGRES_READONLY_USER};"
)
conn.execute(
f"ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO {POSTGRES_READONLY_USER};"
)


@pytest.fixture(scope="function")
def pg_registry(postgres_server):
db_name = "".join(random.choices(string.ascii_lowercase, k=10))
Expand All @@ -130,13 +164,22 @@ def pg_registry(postgres_server):
container_port = postgres_server.get_exposed_port(5432)
container_host = postgres_server.get_container_host_ip()

add_pg_read_only_user(
container_host,
container_port,
db_name,
postgres_server.username,
postgres_server.password,
)

registry_config = SqlRegistryConfig(
registry_type="sql",
cache_ttl_seconds=2,
cache_mode="sync",
# The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
# to understand that we are using psycopg3.
path=f"postgresql+psycopg://{postgres_server.username}:{postgres_server.password}@{container_host}:{container_port}/{db_name}",
read_path=f"postgresql+psycopg://{POSTGRES_READONLY_USER}:{POSTGRES_READONLY_PASSWORD}@{container_host}:{container_port}/{db_name}",
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
thread_pool_executor_worker_count=0,
purge_feast_metadata=False,
Expand Down