Add ClickHouse query for bandit feedback time series (#3979)
Conversation
…viraj/experimentation-config
Two constraint bugs: - For the simplex constraint, had 1s everywhere instead of just at indices corresponding to the weights - Signs were backwards in the SOCP constraints
* fixed performance issue in time series sql query * fixed formatting issue * fixed issue with sorting groupArray
TensorZero CI Bot Automated Comment: ClickHouse E2E tests failed due to the new feedback time series query returning too few rows per period/variant. The tests expect cumulative statistics for all variants at each period once that variant has any historical data (i.e., "last observation carried forward"). The current implementation only emits rows for periods in which a variant has new data, so later periods are missing variants that previously had data. Symptoms from logs:
Root cause:
Fix:
This change keeps performance reasonable (still doing the heavy aggregation once) and matches test expectations. No changes to GitHub Actions are required. Warning: I encountered an error while trying to create a follow-up PR: Failed to create follow-up PR using remote https://x-access-token:***@github.com/tensorzero/tensorzero.git: git apply --whitespace=nowarn /tmp/tensorzero-pr-hXmqHQ/repo/tensorzero.patch failed: error: corrupt patch at line 233 The patch I tried to generate is as follows: diff --git a/tensorzero-core/src/db/clickhouse/select_queries.rs b/tensorzero-core/src/db/clickhouse/select_queries.rs
index 35f0afb9..ad63aef5 100644
--- a/tensorzero-core/src/db/clickhouse/select_queries.rs
+++ b/tensorzero-core/src/db/clickhouse/select_queries.rs
@@ -337,6 +337,7 @@ impl SelectQueries for ClickHouseConnectionInfo {
})
})
}
+
async fn get_feedback_timeseries(
&self,
function_name: String,
@@ -346,6 +347,7 @@ impl SelectQueries for ClickHouseConnectionInfo {
interval_minutes: u32,
max_periods: u32,
) -> Result<Vec<FeedbackTimeSeriesPoint>, Error> {
+ // Build an optional variant filter. An explicit empty variant list returns empty results.
// If variants are passed, build variant filter.
// If None we don't filter at all;
// If empty, we'll return an empty vector for consistency
@@ -363,83 +365,110 @@ impl SelectQueries for ClickHouseConnectionInfo {
}
};
- // Old implementation using time_buckets CTE and self-join
- // Clickhouse has no `toEndOfInterval` function, so we use toStartOfInterval + interval to get
- // the end of each period, since we're computing cumulative stats up to each time point
- // let time_grouping = format!("toStartOfInterval(minute, INTERVAL {interval_minutes} MINUTE) + INTERVAL {interval_minutes} MINUTE");
- // let time_filter = format!(
- // "minute >= (SELECT max(toStartOfInterval(minute, INTERVAL {interval_minutes} MINUTE)) FROM FeedbackByVariantStatistics) - INTERVAL {max_periods} * {interval_minutes} MINUTE"
- // );
- //
- // // Query to compute cumulative statistics: for each time bucket, aggregate all data from start up to that bucket
- // let query = format!(
- // r"
- // WITH time_buckets AS (
- // SELECT DISTINCT {time_grouping} as period
- // FROM FeedbackByVariantStatistics
- // WHERE function_name = '{escaped_function_name}'
- // AND metric_name = '{escaped_metric_name}'
- // AND {time_filter}
- // )
- // SELECT
- // formatDateTime(tb.period, '%Y-%m-%dT%H:%i:%SZ') as period_end,
- // f.variant_name,
- // avgMerge(f.feedback_mean) as mean,
- // varSampStableMerge(f.feedback_variance) as variance,
- // sum(f.count) as count
- // FROM time_buckets tb
- // INNER JOIN FeedbackByVariantStatistics f
- // ON f.function_name = '{escaped_function_name}'
- // AND f.metric_name = '{escaped_metric_name}'
- // AND f.minute <= tb.period
- // {variant_filter}
- // GROUP BY tb.period, f.variant_name
- // ORDER BY period_end ASC, f.variant_name ASC
- // FORMAT JSONEachRow
- // ",
- // );
-
- // New implementation using window functions for better performance
- // This computes true cumulative statistics from all historical data
+ // Implementation that:
+ // 1) Aggregates minute-level statistics into requested intervals and merges states per period/variant.
+ // 2) Builds cumulative stats per variant over time (from the beginning up to each period).
+ // 3) Expands to all (period, variant) combinations and carries forward the last observation
+ // so that each period includes all variants that have any historical data up to that period.
+ // 4) Filters to the most recent max_periods intervals.
let query = format!(
r"
WITH
-- CTE 1: Aggregate ALL historical data into time periods (no time filter)
AggregatedFilteredFeedbackByVariantStatistics AS (
SELECT
- toStartOfInterval(minute, INTERVAL {{interval_minutes:UInt32}} MINUTE) + INTERVAL {{interval_minutes:UInt32}} MINUTE AS period_end,
+ toStartOfInterval(minute, INTERVAL {{interval_minutes:UInt32}} MINUTE)
+ + INTERVAL {{interval_minutes:UInt32}} MINUTE AS period_end,
variant_name,
-- Apply -MergeState combinator to merge and keep as state for later merging
avgMergeState(feedback_mean) AS merged_mean_state,
varSampStableMergeState(feedback_variance) AS merged_var_state,
sum(count) AS period_count
-
FROM FeedbackByVariantStatistics
WHERE
function_name = {{function_name:String}}
AND metric_name = {{metric_name:String}}
{variant_filter}
GROUP BY
period_end,
variant_name
),
- -- CTE 2: For each variant, create sorted arrays of the periodic data.
+ -- CTE 2: For each variant, create sorted arrays of the periodic data
ArraysByVariant AS (
SELECT
variant_name,
- -- 3. Unzip the sorted tuples back into individual arrays
- arrayMap(x -> x.1, sorted_zipped_arrays) AS periods,
- arrayMap(x -> x.2, sorted_zipped_arrays) AS mean_states,
- arrayMap(x -> x.3, sorted_zipped_arrays) AS var_states,
- arrayMap(x -> x.4, sorted_zipped_arrays) AS counts
+ -- Unzip the sorted tuples back into individual arrays
+ arrayMap(x -> x.1, sorted_zipped_arrays) AS periods,
+ arrayMap(x -> x.2, sorted_zipped_arrays) AS mean_states,
+ arrayMap(x -> x.3, sorted_zipped_arrays) AS var_states,
+ arrayMap(x -> x.4, sorted_zipped_arrays) AS counts
FROM (
SELECT
variant_name,
(
- -- 2. Sort the array of tuples based on the first element (the period_end)
+ -- Sort the array of tuples based on the first element (the period_end)
arraySort(x -> x.1,
-- 1. Zip the unsorted arrays together into an array of tuples
arrayZip(
groupArray(period_end),
groupArray(merged_mean_state),
groupArray(merged_var_state),
groupArray(period_count)
)
)
) AS sorted_zipped_arrays
FROM AggregatedFilteredFeedbackByVariantStatistics
GROUP BY variant_name
)
),
- -- CTE 3: Compute cumulative stats for all periods
+ -- CTE 3: Compute cumulative stats for all periods (per variant)
AllCumulativeStats AS (
SELECT
periods[i] AS period_end,
variant_name,
arrayReduce('avgMerge', arraySlice(mean_states, 1, i)) AS mean,
arrayReduce('varSampStableMerge', arraySlice(var_states, 1, i)) AS variance,
arraySum(arraySlice(counts, 1, i)) AS count
FROM ArraysByVariant
ARRAY JOIN arrayEnumerate(periods) AS i
),
- -- CTE 4: Filter to only the most recent max_periods (DateTime arithmetic on DateTime types)
- FilteredCumulativeStats AS (
+ -- CTE 4: Expand to all (period, variant) combinations and carry forward the last observation
+ ExpandedCumulativeStats AS (
+ SELECT
+ p.period_end,
+ v.variant_name,
+ argMax(a.mean, a.period_end) AS mean,
+ argMax(a.variance, a.period_end) AS variance,
+ argMax(a.count, a.period_end) AS count,
+ max(a.period_end) AS last_period
+ FROM
+ (SELECT DISTINCT period_end FROM AllCumulativeStats) p
+ CROSS JOIN
+ (SELECT DISTINCT variant_name FROM AllCumulativeStats) v
+ LEFT JOIN AllCumulativeStats a
+ ON a.variant_name = v.variant_name
+ AND a.period_end <= p.period_end
+ GROUP BY p.period_end, v.variant_name
+ ),
+
+ -- CTE 5: Filter to only the most recent max_periods of intervals and exclude variants with no history yet
+ FilteredCumulativeStats AS (
SELECT
period_end,
variant_name,
mean,
variance,
count
- FROM AllCumulativeStats
- WHERE period_end >= (
- SELECT max(period_end)
- FROM AllCumulativeStats
- ) - INTERVAL {{max_periods:UInt32}} * {{interval_minutes:UInt32}} MINUTE
+ FROM ExpandedCumulativeStats
+ WHERE
+ last_period IS NOT NULL
+ AND period_end >= (
+ SELECT max(period_end)
+ FROM AllCumulativeStats
+ ) - INTERVAL {{max_periods:UInt32}} * {{interval_minutes:UInt32}} MINUTE
)
-- Final SELECT: Format the DateTime to string
SELECT
formatDateTime(period_end, '%Y-%m-%dT%H:%i:%SZ') AS period_end,
variant_name,
mean,
variance,
count
FROM FilteredCumulativeStats
ORDER BY
period_end ASC,
variant_name ASC
FORMAT JSONEachRow
",
);
// Create parameters HashMap
let interval_minutes_str = interval_minutes.to_string();
let max_periods_str = max_periods.to_string();
let params = std::collections::HashMap::from([
("function_name", function_name.as_str()),
("metric_name", metric_name.as_str()),
("interval_minutes", interval_minutes_str.as_str()),
("max_periods", max_periods_str.as_str()),
]);
        let response = self.run_query_synchronous(query, &params).await?;
// Deserialize the results into FeedbackTimeSeriesPoint
response
.response
.trim()
.lines()
.map(|row| {
serde_json::from_str(row).map_err(|e| {
Error::new(ErrorDetails::ClickHouseDeserialization {
message: format!("Failed to deserialize FeedbackTimeSeriesPoint: {e}"),
})
})
})
.collect::<Result<Vec<_>, _>>()
}
}
Important
Add ClickHouse query for cumulative feedback time series with support for various time windows and extensive testing.
Add a `get_cumulative_feedback_timeseries` method to `DatabaseClient` in `database.rs` and `index.ts` to retrieve cumulative feedback statistics. Add `CumulativeFeedbackTimeSeriesPoint` and `GetCumulativeFeedbackTimeseriesParams` types in bindings. Extend the `TimeWindow` type to include 'minute'. Implement `get_cumulative_feedback_timeseries` in `select_queries.rs` with ClickHouse SQL logic. Add tests in `bandit_queries.rs` for various time windows and scenarios, including error handling for 'cumulative'. This description was created by
for f2568ee. You can customize this summary. It will automatically update as commits are pushed.