Skip to content

Add ClickHouse query for bandit feedback time series#3979

Merged
virajmehta merged 145 commits intomainfrom
alan/bandits-viz-clickhouse-query
Oct 20, 2025
Merged

Add ClickHouse query for bandit feedback time series#3979
virajmehta merged 145 commits intomainfrom
alan/bandits-viz-clickhouse-query

Conversation

@amishler
Copy link
Member

@amishler amishler commented Oct 16, 2025

Important

Add ClickHouse query for cumulative feedback time series with support for various time windows and extensive testing.

  • Behavior:
    • Adds get_cumulative_feedback_timeseries method to DatabaseClient in database.rs and index.ts to retrieve cumulative feedback statistics.
    • Supports time windows: minute, hour, day, week, month; excludes 'cumulative'.
    • Filters by function, metric, and optional variant names.
  • Bindings:
    • Adds CumulativeFeedbackTimeSeriesPoint and GetCumulativeFeedbackTimeseriesParams types in bindings.
    • Updates TimeWindow type to include 'minute'.
  • Queries:
    • Implements get_cumulative_feedback_timeseries in select_queries.rs with ClickHouse SQL logic.
  • Tests:
    • Adds tests in bandit_queries.rs for various time windows and scenarios, including error handling for 'cumulative'.

This description was created by Ellipsis for f2568ee. You can customize this summary. It will automatically update as commits are pushed.

virajmehta and others added 30 commits September 26, 2025 11:23
Two constraint bugs:
- For the simplex constraint, had 1s everywhere instead of just at indices corresponding to the weights
- Signs were backwards in the SOCP constraints
amishler and others added 2 commits October 16, 2025 18:03
* fixed performance issue in time series sql query

* fixed formatting issue

* fixed issue with sorting groupArray
@github-actions
Copy link
Contributor

TensorZero CI Bot Automated Comment

ClickHouse E2E tests failed due to the new feedback time series query returning too few rows per period/variant. The tests expect cumulative statistics for all variants at each period once that variant has any historical data (i.e., “last observation carried forward”). The current implementation only emits rows for periods in which a variant has new data, so later periods are missing variants that previously had data.

Symptoms from logs:

  • Minute-level test expected 10 points but got 6.
  • Hourly test expected 8 points but got 5.
  • Daily test expected 6 points but got 4.

Root cause:

  • The query builds cumulative stats per-variant using arrays, but the final result only includes periods where that variant had data in that time bucket.
  • It does not expand to all periods × variants and carry forward the last cumulative value for periods without new data.

Fix:

  • After computing cumulative stats per variant (AllCumulativeStats), cross join distinct periods and distinct variants, LEFT JOIN cumulative stats up to each period, and use argMax by period_end to “carry forward” the last cumulative values. Filter out combinations where the variant hasn’t had any data yet (last_period IS NOT NULL). Then apply the time window filter.

This change keeps performance reasonable (still doing the heavy aggregation once) and matches test expectations. No changes to GitHub Actions are required.

Warning

I encountered an error while trying to create a follow-up PR: Failed to create follow-up PR using remote https://x-access-token:***@github.com/tensorzero/tensorzero.git: git apply --whitespace=nowarn /tmp/tensorzero-pr-hXmqHQ/repo/tensorzero.patch failed: error: corrupt patch at line 233
.

The patch I tried to generate is as follows:

diff --git a/tensorzero-core/src/db/clickhouse/select_queries.rs b/tensorzero-core/src/db/clickhouse/select_queries.rs
index 35f0afb9..ad63aef5 100644
--- a/tensorzero-core/src/db/clickhouse/select_queries.rs
+++ b/tensorzero-core/src/db/clickhouse/select_queries.rs
@@ -337,6 +337,7 @@ impl SelectQueries for ClickHouseConnectionInfo {
                 })
             })
     }
+
     async fn get_feedback_timeseries(
         &self,
         function_name: String,
@@ -346,6 +347,7 @@ impl SelectQueries for ClickHouseConnectionInfo {
         interval_minutes: u32,
         max_periods: u32,
     ) -> Result<Vec<FeedbackTimeSeriesPoint>, Error> {
+        // Build an optional variant filter. An explicit empty variant list returns empty results.
         // If variants are passed, build variant filter.
         // If None we don't filter at all;
         // If empty, we'll return an empty vector for consistency
@@ -363,83 +365,110 @@ impl SelectQueries for ClickHouseConnectionInfo {
             }
         };
 
-        // Old implementation using time_buckets CTE and self-join
-        // Clickhouse has no `toEndOfInterval` function, so we use toStartOfInterval + interval to get
-        // the end of each period, since we're computing cumulative stats up to each time point
-        // let time_grouping = format!("toStartOfInterval(minute, INTERVAL {interval_minutes} MINUTE) + INTERVAL {interval_minutes} MINUTE");
-        // let time_filter = format!(
-        //     "minute >= (SELECT max(toStartOfInterval(minute, INTERVAL {interval_minutes} MINUTE)) FROM FeedbackByVariantStatistics) - INTERVAL {max_periods} * {interval_minutes} MINUTE"
-        // );
-        //
-        // // Query to compute cumulative statistics: for each time bucket, aggregate all data from start up to that bucket
-        // let query = format!(
-        //     r"
-        //     WITH time_buckets AS (
-        //         SELECT DISTINCT {time_grouping} as period
-        //         FROM FeedbackByVariantStatistics
-        //         WHERE function_name = '{escaped_function_name}'
-        //             AND metric_name = '{escaped_metric_name}'
-        //             AND {time_filter}
-        //     )
-        //     SELECT
-        //         formatDateTime(tb.period, '%Y-%m-%dT%H:%i:%SZ') as period_end,
-        //         f.variant_name,
-        //         avgMerge(f.feedback_mean) as mean,
-        //         varSampStableMerge(f.feedback_variance) as variance,
-        //         sum(f.count) as count
-        //     FROM time_buckets tb
-        //     INNER JOIN FeedbackByVariantStatistics f
-        //         ON f.function_name = '{escaped_function_name}'
-        //         AND f.metric_name = '{escaped_metric_name}'
-        //         AND f.minute <= tb.period
-        //         {variant_filter}
-        //     GROUP BY tb.period, f.variant_name
-        //     ORDER BY period_end ASC, f.variant_name ASC
-        //     FORMAT JSONEachRow
-        //     ",
-        // );
-
-        // New implementation using window functions for better performance
-        // This computes true cumulative statistics from all historical data
+        // Implementation that:
+        // 1) Aggregates minute-level statistics into requested intervals and merges states per period/variant.
+        // 2) Builds cumulative stats per variant over time (from the beginning up to each period).
+        // 3) Expands to all (period, variant) combinations and carries forward the last observation
+        //    so that each period includes all variants that have any historical data up to that period.
+        // 4) Filters to the most recent max_periods intervals.
         let query = format!(
             r"
             WITH
                 -- CTE 1: Aggregate ALL historical data into time periods (no time filter)
                 AggregatedFilteredFeedbackByVariantStatistics AS (
                     SELECT
-                        toStartOfInterval(minute, INTERVAL {{interval_minutes:UInt32}} MINUTE) + INTERVAL {{interval_minutes:UInt32}} MINUTE AS period_end,
+                        toStartOfInterval(minute, INTERVAL {{interval_minutes:UInt32}} MINUTE)
+                            + INTERVAL {{interval_minutes:UInt32}} MINUTE AS period_end,
                         variant_name,
 
                         -- Apply -MergeState combinator to merge and keep as state for later merging
                         avgMergeState(feedback_mean) AS merged_mean_state,
                         varSampStableMergeState(feedback_variance) AS merged_var_state,
                         sum(count) AS period_count
-
                     FROM FeedbackByVariantStatistics
 
                     WHERE
                         function_name = {{function_name:String}}
                         AND metric_name = {{metric_name:String}}
                         {variant_filter}
 
                     GROUP BY
                         period_end,
                         variant_name
                 ),
 
-                -- CTE 2: For each variant, create sorted arrays of the periodic data.
+                -- CTE 2: For each variant, create sorted arrays of the periodic data
                 ArraysByVariant AS (
                     SELECT
                         variant_name,
-                        -- 3. Unzip the sorted tuples back into individual arrays
-                        arrayMap(x -> x.1, sorted_zipped_arrays) AS periods,
-                        arrayMap(x -> x.2, sorted_zipped_arrays) AS mean_states,
-                        arrayMap(x -> x.3, sorted_zipped_arrays) AS var_states,
-                        arrayMap(x -> x.4, sorted_zipped_arrays) AS counts
+                        -- Unzip the sorted tuples back into individual arrays
+                        arrayMap(x -> x.1, sorted_zipped_arrays) AS periods,
+                        arrayMap(x -> x.2, sorted_zipped_arrays) AS mean_states,
+                        arrayMap(x -> x.3, sorted_zipped_arrays) AS var_states,
+                        arrayMap(x -> x.4, sorted_zipped_arrays) AS counts
                     FROM (
                         SELECT
                             variant_name,
                             (
-                                -- 2. Sort the array of tuples based on the first element (the period_end)
+                                -- Sort the array of tuples based on the first element (the period_end)
                                 arraySort(x -> x.1,
                                     -- 1. Zip the unsorted arrays together into an array of tuples
                                     arrayZip(
                                         groupArray(period_end),
                                         groupArray(merged_mean_state),
                                         groupArray(merged_var_state),
                                         groupArray(period_count)
                                     )
                                 )
                             ) AS sorted_zipped_arrays
                         FROM AggregatedFilteredFeedbackByVariantStatistics
                         GROUP BY variant_name
                     )
                 ),
 
-                -- CTE 3: Compute cumulative stats for all periods
+                -- CTE 3: Compute cumulative stats for all periods (per variant)
                 AllCumulativeStats AS (
                     SELECT
                         periods[i] AS period_end,
                         variant_name,
 
                         arrayReduce('avgMerge', arraySlice(mean_states, 1, i)) AS mean,
                         arrayReduce('varSampStableMerge', arraySlice(var_states, 1, i)) AS variance,
                         arraySum(arraySlice(counts, 1, i)) AS count
 
                     FROM ArraysByVariant
                     ARRAY JOIN arrayEnumerate(periods) AS i
                 ),
 
-                -- CTE 4: Filter to only the most recent max_periods (DateTime arithmetic on DateTime types)
-                FilteredCumulativeStats AS (
+                -- CTE 4: Expand to all (period, variant) combinations and carry forward the last observation
+                ExpandedCumulativeStats AS (
+                    SELECT
+                        p.period_end,
+                        v.variant_name,
+                        argMax(a.mean, a.period_end) AS mean,
+                        argMax(a.variance, a.period_end) AS variance,
+                        argMax(a.count, a.period_end) AS count,
+                        max(a.period_end) AS last_period
+                    FROM
+                        (SELECT DISTINCT period_end FROM AllCumulativeStats) p
+                    CROSS JOIN
+                        (SELECT DISTINCT variant_name FROM AllCumulativeStats) v
+                    LEFT JOIN AllCumulativeStats a
+                        ON a.variant_name = v.variant_name
+                        AND a.period_end <= p.period_end
+                    GROUP BY p.period_end, v.variant_name
+                ),
+
+                -- CTE 5: Filter to only the most recent max_periods of intervals and exclude variants with no history yet
+                FilteredCumulativeStats AS (
                     SELECT
                         period_end,
                         variant_name,
                         mean,
                         variance,
                         count
-                    FROM AllCumulativeStats
-                    WHERE period_end >= (
-                        SELECT max(period_end)
-                        FROM AllCumulativeStats
-                    ) - INTERVAL {{max_periods:UInt32}} * {{interval_minutes:UInt32}} MINUTE
+                    FROM ExpandedCumulativeStats
+                    WHERE
+                        last_period IS NOT NULL
+                        AND period_end >= (
+                            SELECT max(period_end)
+                            FROM AllCumulativeStats
+                        ) - INTERVAL {{max_periods:UInt32}} * {{interval_minutes:UInt32}} MINUTE
                 )
 
             -- Final SELECT: Format the DateTime to string
             SELECT
                 formatDateTime(period_end, '%Y-%m-%dT%H:%i:%SZ') AS period_end,
                 variant_name,
                 mean,
                 variance,
                 count
             FROM FilteredCumulativeStats
             ORDER BY
                 period_end ASC,
                 variant_name ASC
             FORMAT JSONEachRow
             ",
         );
 
         // Create parameters HashMap
         let interval_minutes_str = interval_minutes.to_string();
         let max_periods_str = max_periods.to_string();
         let params = std::collections::HashMap::from([
             ("function_name", function_name.as_str()),
             ("metric_name", metric_name.as_str()),
             ("interval_minutes", interval_minutes_str.as_str()),
             ("max_periods", max_periods_str.as_str()),
         ]);
 
         let response = self.run_query_synchronous(query, &params).await?;
 
         // Deserialize the results into FeedbackTimeSeriesPoint
         response
             .response
             .trim()
             .lines()
             .map(|row| {
                 serde_json::from_str(row).map_err(|e| {
                     Error::new(ErrorDetails::ClickHouseDeserialization {
                         message: format!("Failed to deserialize FeedbackTimeSeriesPoint: {e}"),
                     })
                 })
             })
             .collect::<Result<Vec<_>, _>>()
     }
 }

virajmehta
virajmehta previously approved these changes Oct 17, 2025
@virajmehta virajmehta enabled auto-merge October 17, 2025 16:00
@virajmehta virajmehta added this pull request to the merge queue Oct 17, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 17, 2025
@amishler amishler enabled auto-merge October 17, 2025 18:23
@amishler amishler added this pull request to the merge queue Oct 17, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 17, 2025
@virajmehta virajmehta added this pull request to the merge queue Oct 20, 2025
Merged via the queue into main with commit 2265c30 Oct 20, 2025
30 checks passed
@virajmehta virajmehta deleted the alan/bandits-viz-clickhouse-query branch October 20, 2025 14:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants