Skip to content

Commit 6bbfede

Browse files
author
zhilingc
committed
Use count() instead of returning all rows
1 parent e27e203 commit 6bbfede

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

tests/e2e/bq-batch-retrieval.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
import math
2+
import os
23
import random
34
import time
5+
import uuid
46
from datetime import datetime
57
from datetime import timedelta
68
from urllib.parse import urlparse
79

8-
import os
9-
import uuid
1010
import numpy as np
1111
import pandas as pd
1212
import pytest
1313
import pytz
14+
from feast.client import Client
1415
from feast.core.CoreService_pb2 import ListStoresRequest
1516
from feast.core.IngestionJob_pb2 import IngestionJobStatus
16-
from feast.client import Client
17-
from feast.core.FeatureSet_pb2 import FeatureSetStatus
1817
from feast.entity import Entity
1918
from feast.feature import Feature
2019
from feast.feature_set import FeatureSet
2120
from feast.type_map import ValueType
2221
from google.cloud import storage, bigquery
23-
from google.cloud.bigquery import TableReference
2422
from google.protobuf.duration_pb2 import Duration
2523
from pandavro import to_avro
2624

@@ -66,6 +64,7 @@ def client(core_url, serving_url, allow_dirty):
6664

6765
return client
6866

67+
6968
@pytest.mark.first
7069
@pytest.mark.direct_runner
7170
@pytest.mark.dataflow_runner
@@ -425,21 +424,25 @@ def test_batch_no_max_age(client):
425424

426425
@pytest.fixture(scope="module", autouse=True)
def infra_teardown(pytestconfig, core_url, serving_url):
    """Module-scoped teardown fixture.

    After the test module finishes, and only when pytest was invoked with the
    ``dataflow_runner`` marker, write the external IDs of all still-RUNNING
    ingestion jobs to ``ingesting_jobs.txt`` in the current working directory
    (presumably consumed by a later cleanup step — confirm against CI config).
    Yields the active ``-m`` marker string to dependent fixtures/tests.
    """
    client = Client(core_url=core_url, serving_url=serving_url)
    client.set_project(PROJECT_NAME)

    marker = pytestconfig.getoption("-m")
    yield marker

    if marker != "dataflow_runner":
        # Direct-runner runs leave no long-lived Dataflow jobs behind.
        print("Cleaning up not required")
        return

    # Re-fetch each job by id to obtain its external (Dataflow) id; keep
    # only jobs that are still running.
    running_external_ids = [
        client.list_ingest_jobs(job.id)[0].external_id
        for job in client.list_ingest_jobs()
        if job.status == IngestionJobStatus.RUNNING
    ]

    cwd = os.getcwd()
    with open(f"{cwd}/ingesting_jobs.txt", "w+") as output:
        for external_id in running_external_ids:
            output.write("%s\n" % external_id)
443446

444447

445448
@pytest.fixture(scope="module")
@@ -521,8 +524,11 @@ def test_update_featureset_update_featureset_and_ingest_second_subset(
521524
time.sleep(15) # wait for rows to get written to bq
522525
rows_ingested = get_rows_ingested(client, update_fs, ingestion_id)
523526
if rows_ingested == len(subset_df):
527+
print(f"Number of rows successfully ingested: {rows_ingested}. Continuing.")
524528
break
525-
print(f"Number of rows successfully ingested: {rows_ingested}. Retrying ingestion.")
529+
print(
530+
f"Number of rows successfully ingested: {rows_ingested}. Retrying ingestion."
531+
)
526532
time.sleep(30)
527533

528534
feature_retrieval_job = client.get_batch_features(
@@ -576,7 +582,9 @@ def test_update_featureset_retrieve_valid_fields(client, update_featureset_dataf
576582
== update_featureset_dataframe["update_feature1"].to_list()
577583
)
578584
# we have to convert to float because the column contains np.NaN
579-
assert [math.isnan(i) for i in output["update_feature3"].to_list()[:5]] == [True] * 5
585+
assert [math.isnan(i) for i in output["update_feature3"].to_list()[:5]] == [
586+
True
587+
] * 5
580588
assert output["update_feature3"].to_list()[5:] == [
581589
float(i) for i in update_featureset_dataframe["update_feature3"].to_list()[5:]
582590
]
@@ -586,17 +594,21 @@ def test_update_featureset_retrieve_valid_fields(client, update_featureset_dataf
586594
)
587595

588596

589-
def get_rows_ingested(
    client: Client, feature_set: FeatureSet, ingestion_id: str
) -> int:
    """Return the number of rows written for one ingestion run.

    Looks up the "historical" store registered in Feast core to find the
    BigQuery project/dataset backing *feature_set*, then issues a single
    ``COUNT(*)`` query filtered on *ingestion_id*.
    """
    stores_response = client._core_service_stub.ListStores(
        ListStoresRequest(filter=ListStoresRequest.Filter(name="historical"))
    )
    bq_config = stores_response.store[0].bigquery_config
    project = bq_config.project_id
    dataset = bq_config.dataset_id
    table = f"{PROJECT_NAME}_{feature_set.name}"

    # NOTE(review): ingestion_id is interpolated directly into the SQL text;
    # acceptable for a test helper with internally generated ids, but do not
    # reuse this pattern with untrusted input.
    bq_client = bigquery.Client(project=project)
    count_rows = bq_client.query(
        f'SELECT COUNT(*) as count FROM `{project}.{dataset}.{table}` WHERE ingestion_id = "{ingestion_id}"'
    ).result()

    # COUNT(*) yields exactly one row; return its "count" column.
    for count_row in count_rows:
        return count_row["count"]

0 commit comments

Comments
 (0)