Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ test-go: compile-protos-go compile-protos-python install-feast-ci-locally
test-go-integration: compile-protos-go compile-protos-python install-feast-ci-locally
docker compose -f go/integration_tests/valkey/docker-compose.yaml up -d
docker compose -f go/integration_tests/scylladb/docker-compose.yaml up -d
go test -tags=integration ./go/internal/...
go test -p 1 -tags=integration ./go/internal/...
docker compose -f go/integration_tests/valkey/docker-compose.yaml down
docker compose -f go/integration_tests/scylladb/docker-compose.yaml down

Expand Down
8 changes: 4 additions & 4 deletions go/internal/feast/onlineserving/serving.go
Original file line number Diff line number Diff line change
Expand Up @@ -1086,17 +1086,17 @@ func GroupSortedFeatureRefs(
sortOrder = &flipped // non-nil only when sort key order is reversed
}

var filterModel *model.SortKeyFilter
if filter, ok := sortKeyFilterMap[sortKey.FieldName]; ok {
filterModel = model.NewSortKeyFilterFromProto(filter, sortOrder)
filterModel := model.NewSortKeyFilterFromProto(filter, sortOrder)
sortKeyFilterModels = append(sortKeyFilterModels, filterModel)
} else if reverseSortOrder {
filterModel = &model.SortKeyFilter{
filterModel := &model.SortKeyFilter{
SortKeyName: sortKey.FieldName,
Order: model.NewSortOrderFromProto(*sortOrder),
}
sortKeyFilterModels = append(sortKeyFilterModels, filterModel)
}

sortKeyFilterModels = append(sortKeyFilterModels, filterModel)
}

if _, ok := groups[groupKey]; !ok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"time"
)

var onlineStore *CassandraOnlineStore
var ctx = context.Background()

func TestMain(m *testing.M) {
// Initialize the test environment
dir := "../../../integration_tests/scylladb/"
Expand All @@ -27,6 +30,12 @@ func TestMain(m *testing.M) {
os.Exit(1)
}

onlineStore, err = getCassandraOnlineStore()
if err != nil {
fmt.Printf("Failed to create CassandraOnlineStore: %v\n", err)
os.Exit(1)
}

// Run the tests
exitCode := m.Run()

Expand All @@ -40,16 +49,19 @@ func TestMain(m *testing.M) {
os.Exit(exitCode)
}

func getCassandraOnlineStore(t *testing.T) (*CassandraOnlineStore, context.Context) {
ctx := context.Background()
func getCassandraOnlineStore() (*CassandraOnlineStore, error) {
dir := "../../../integration_tests/scylladb/"
config, err := loadRepoConfig(dir)
require.NoError(t, err)
assert.Equal(t, "scylladb", config.OnlineStore["type"])
if err != nil {
fmt.Printf("Failed to load repo config: %v\n", err)
return nil, err
}

onlineStore, err := NewCassandraOnlineStore("feature_integration_repo", config, config.OnlineStore)
require.NoError(t, err)
return onlineStore, ctx
store, err := NewCassandraOnlineStore("feature_integration_repo", config, config.OnlineStore)
if err != nil {
return nil, err
}
return store, nil
}

func loadRepoConfig(basePath string) (*registry.RepoConfig, error) {
Expand All @@ -62,8 +74,6 @@ func loadRepoConfig(basePath string) (*registry.RepoConfig, error) {
}

func TestCassandraOnlineStore_OnlineReadRange_withSingleEntityKey(t *testing.T) {
onlineStore, ctx := getCassandraOnlineStore(t)

entityKeys := []*types.EntityKey{{
JoinKeys: []string{"index_id"},
EntityValues: []*types.Value{
Expand All @@ -79,6 +89,7 @@ func TestCassandraOnlineStore_OnlineReadRange_withSingleEntityKey(t *testing.T)
sortKeyFilters := []*model.SortKeyFilter{{
SortKeyName: "event_timestamp",
RangeStart: int64(1744769099919),
RangeEnd: int64(1744779099919),
}}

groupedRefs := &model.GroupedRangeFeatureRefs{
Expand All @@ -93,12 +104,10 @@ func TestCassandraOnlineStore_OnlineReadRange_withSingleEntityKey(t *testing.T)

data, err := onlineStore.OnlineReadRange(ctx, groupedRefs)
require.NoError(t, err)
verifyResponseData(t, data, 1)
verifyResponseData(t, data, 1, int64(1744769099919), int64(1744779099919))
}

func TestCassandraOnlineStore_OnlineReadRange_withMultipleEntityKeys(t *testing.T) {
onlineStore, ctx := getCassandraOnlineStore(t)

entityKeys := []*types.EntityKey{
{
JoinKeys: []string{"index_id"},
Expand Down Expand Up @@ -142,12 +151,10 @@ func TestCassandraOnlineStore_OnlineReadRange_withMultipleEntityKeys(t *testing.

data, err := onlineStore.OnlineReadRange(ctx, groupedRefs)
require.NoError(t, err)
verifyResponseData(t, data, 3)
verifyResponseData(t, data, 3, int64(1744769099919), int64(1744769099919*10))
}

func TestCassandraOnlineStore_OnlineReadRange_withReverseSortOrder(t *testing.T) {
onlineStore, ctx := getCassandraOnlineStore(t)

entityKeys := []*types.EntityKey{
{
JoinKeys: []string{"index_id"},
Expand Down Expand Up @@ -193,15 +200,59 @@ func TestCassandraOnlineStore_OnlineReadRange_withReverseSortOrder(t *testing.T)

data, err := onlineStore.OnlineReadRange(ctx, groupedRefs)
require.NoError(t, err)
verifyResponseData(t, data, 3)
verifyResponseData(t, data, 3, int64(1744769099919), int64(1744769099919*10))
}

func TestCassandraOnlineStore_OnlineReadRange_withNoSortKeyFilters(t *testing.T) {
entityKeys := []*types.EntityKey{
{
JoinKeys: []string{"index_id"},
EntityValues: []*types.Value{
{Val: &types.Value_Int64Val{Int64Val: 1}},
},
},
{
JoinKeys: []string{"index_id"},
EntityValues: []*types.Value{
{Val: &types.Value_Int64Val{Int64Val: 2}},
},
},
{
JoinKeys: []string{"index_id"},
EntityValues: []*types.Value{
{Val: &types.Value_Int64Val{Int64Val: 3}},
},
},
}
featureViewNames := []string{"all_dtypes_sorted"}
featureNames := []string{"int_val", "long_val", "float_val", "double_val", "byte_val", "string_val", "timestamp_val", "boolean_val",
"null_int_val", "null_long_val", "null_float_val", "null_double_val", "null_byte_val", "null_string_val", "null_timestamp_val", "null_boolean_val",
"null_array_int_val", "null_array_long_val", "null_array_float_val", "null_array_double_val", "null_array_byte_val", "null_array_string_val",
"null_array_boolean_val", "array_int_val", "array_long_val", "array_float_val", "array_double_val", "array_string_val", "array_boolean_val",
"array_byte_val", "array_timestamp_val", "null_array_timestamp_val", "event_timestamp"}
sortKeyFilters := []*model.SortKeyFilter{}

groupedRefs := &model.GroupedRangeFeatureRefs{
EntityKeys: entityKeys,
FeatureViewNames: featureViewNames,
FeatureNames: featureNames,
SortKeyFilters: sortKeyFilters,
Limit: 10,
IsReverseSortOrder: true,
SortKeyNames: map[string]bool{"event_timestamp": true},
}

data, err := onlineStore.OnlineReadRange(ctx, groupedRefs)
require.NoError(t, err)
verifyResponseData(t, data, 3, int64(0), int64(1744769099919*10))
}

func assertValueType(t *testing.T, actualValue interface{}, expectedType string) {
require.IsType(t, &types.Value{}, actualValue, "Expected value to be of type *types.Value")
assert.Equal(t, expectedType, fmt.Sprintf("%T", actualValue.(*types.Value).GetVal()), expectedType)
}

func verifyResponseData(t *testing.T, data [][]RangeFeatureData, numEntityKeys int) {
func verifyResponseData(t *testing.T, data [][]RangeFeatureData, numEntityKeys int, start int64, end int64) {
assert.Equal(t, numEntityKeys, len(data))

for i := 0; i < numEntityKeys; i++ {
Expand Down Expand Up @@ -356,7 +407,8 @@ func verifyResponseData(t *testing.T, data [][]RangeFeatureData, numEntityKeys i
assert.NotNil(t, data[i][32].Values[0])
assert.IsType(t, time.Time{}, data[i][32].Values[0])
for _, timestamp := range data[i][32].Values {
assert.GreaterOrEqual(t, timestamp.(time.Time).UnixMilli(), int64(1744769099919), "Timestamp should be greater than or equal to 1744769099919")
assert.GreaterOrEqual(t, timestamp.(time.Time).UnixMilli(), start, "Timestamp should be greater than or equal to %d", start)
assert.LessOrEqual(t, timestamp.(time.Time).UnixMilli(), end, "Timestamp should be less than or equal to %d", end)
}
}
}
79 changes: 0 additions & 79 deletions go/internal/feast/server/grpc_server_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,85 +107,6 @@ func TestGetOnlineFeaturesValkey(t *testing.T) {
assert.Equal(t, len(response.Results), len(featureNames)+1)
}

func TestGetOnlineFeaturesRange(t *testing.T) {
ctx := context.Background()
dir := "../../../integration_tests/scylladb/"
err := test.SetupInitializedRepo(dir)
defer test.CleanUpInitializedRepo(dir)
require.NoError(t, err, "Failed to setup initialized repo with err: %v", err)

client, closer := getClient(ctx, "", dir, "")
defer closer()

entities := make(map[string]*types.RepeatedValue)

entities["index_id"] = &types.RepeatedValue{
Val: []*types.Value{
{Val: &types.Value_Int64Val{Int64Val: 1}},
{Val: &types.Value_Int64Val{Int64Val: 2}},
{Val: &types.Value_Int64Val{Int64Val: 3}},
},
}

featureNames := []string{"int_val", "long_val", "float_val", "double_val", "byte_val", "string_val", "timestamp_val", "boolean_val",
"null_int_val", "null_long_val", "null_float_val", "null_double_val", "null_byte_val", "null_string_val", "null_timestamp_val", "null_boolean_val",
"null_array_int_val", "null_array_long_val", "null_array_float_val", "null_array_double_val", "null_array_byte_val", "null_array_string_val",
"null_array_boolean_val", "array_int_val", "array_long_val", "array_float_val", "array_double_val", "array_string_val", "array_boolean_val",
"array_byte_val", "array_timestamp_val", "null_array_timestamp_val"}

var featureNamesWithFeatureView []string

for _, featureName := range featureNames {
featureNamesWithFeatureView = append(featureNamesWithFeatureView, "all_dtypes_sorted:"+featureName)
}

request := &serving.GetOnlineFeaturesRangeRequest{
Kind: &serving.GetOnlineFeaturesRangeRequest_Features{
Features: &serving.FeatureList{
Val: featureNamesWithFeatureView,
},
},
Entities: entities,
SortKeyFilters: []*serving.SortKeyFilter{
{
SortKeyName: "event_timestamp",
Query: &serving.SortKeyFilter_Range{
Range: &serving.SortKeyFilter_RangeQuery{
RangeStart: &types.Value{Val: &types.Value_UnixTimestampVal{UnixTimestampVal: 0}},
},
},
},
},
Limit: 10,
}
response, err := client.GetOnlineFeaturesRange(ctx, request)
assert.NoError(t, err)
assert.NotNil(t, response)
assert.Equal(t, 33, len(response.Results))

for i, featureResult := range response.Results {
assert.Equal(t, 3, len(featureResult.Values))
for _, value := range featureResult.Values {
if i == 0 {
// The first result is the entity key which should only have 1 entry
assert.NotNil(t, value)
assert.Equal(t, 1, len(value.Val), "Entity Key should have 1 value, got %d", len(value.Val))
} else {
featureName := featureNames[i-1] // The first entry is the entity key
if strings.Contains(featureName, "null") {
// For null features, we expect the value to contain 1 entry with a nil value
assert.NotNil(t, value)
assert.Equal(t, 1, len(value.Val), "Feature %s should have one values, got %d", featureName, len(value.Val))
assert.Nil(t, value.Val[0].Val, "Feature %s should have a nil value", featureName)
} else {
assert.NotNil(t, value)
assert.Equal(t, 10, len(value.Val), "Feature %s should have 10 values, got %d", featureName, len(value.Val))
}
}
}
}
}

func getValueType(value interface{}, featureName string) *types.Value {
if value == nil {
return &types.Value{}
Expand Down
Loading
Loading