
Conversation

@antoinebaker (Contributor) commented Oct 15, 2025

Fixes #32459.

What does this implement/fix? Explain your changes.

The RidgeCV fit needs to estimate $\text{diag}(G^{-1})$ and $c = G^{-1} y$, where $G = XX^T + \alpha I$. They are computed by the solve and decompose functions chosen by the gcv_mode option.

This PR uses the full svd of the design matrix $X=U S V^T$:

  • singular values s1, ..., sK with K = min(n_samples, n_features)
  • U (n_samples, n_samples) matrix
  • s = [s1, ..., sK, 0, ..., 0] (n_samples,) vector
  • $G^{-1} = U \text{ diag}\left( \frac{1}{ s^2 + \alpha} \right)U^T$

The implementation in main uses the reduced svd $X=USV^T$:

  • singular values s1, ..., sK with K = min(n_samples, n_features)
  • U (n_samples, K) matrix
  • s = [s1, ..., sK] (K,) vector
  • $G^{-1} = U \text{ diag}\left( \frac{1}{ s^2 + \alpha} - \frac{1}{\alpha} \right) U^T + \frac{1}{\alpha} I$
    This last formula is numerically unstable in the $\alpha \to 0$ limit (we are dividing by alpha); see the scalar sketch just below.
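
A tiny scalar caricature of that cancellation (illustrative only, with an even smaller alpha than in the benchmarks below so the effect is unmistakable): along a singular direction, the eigenvalue of $G^{-1}$ should be $\frac{1}{s^2+\alpha}$, but the reduced-svd route effectively computes it as $\left(\frac{1}{s^2+\alpha} - \frac{1}{\alpha}\right) + \frac{1}{\alpha}$.

s2 = 1.0        # a squared singular value of X
alpha = 1e-20   # deliberately tiny to make the cancellation obvious

stable = 1 / (s2 + alpha)                              # full-svd route: ~1.0
unstable = (1 / (s2 + alpha) - 1 / alpha) + 1 / alpha  # reduced-svd route: the huge 1/alpha terms swallow the 1.0

print(stable, unstable)  # 1.0 0.0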

Any other comments

The full svd will be costly (time and memory) when n_samples >> n_features. Maybe we should try a trick similar to main (reduced svd), but without the division by alpha. Not sure how to do it though :)

@github-actions (bot) commented Oct 15, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 66671b7.

@ogrisel (Member) commented Oct 17, 2025

I triggered the CUDA build out of curiosity.

However, I agree that I am worried about the performance impact of using full SVD on large datasets. Can you benchmark main vs PR branch for various data shapes?

@antoinebaker (Contributor, Author)

Here are a few benchmarks, with fit time and peak memory:

Snippet
from sklearn.linear_model import RidgeCV
from sklearn.datasets import make_regression
from time import time
import tracemalloc

def benchmark(n_samples, n_features, alphas, gcv_mode):
    X, y = make_regression(n_samples=n_samples, n_features=n_features)
    tracemalloc.start()
    reg = RidgeCV(alphas=alphas, gcv_mode=gcv_mode)
    t0 = time()
    reg.fit(X, y)
    elapsed = time() - t0
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    peak /= 1024**2 # MB
    print(f"{elapsed=:.2f}s {peak=:.1f}MB {gcv_mode=} {n_samples=}")
benchmark(n_samples=1000, n_features=100, alphas=[1e-15], gcv_mode="svd")
benchmark(n_samples=10_000, n_features=100, alphas=[1e-15], gcv_mode="svd")
# main
elapsed=0.02s peak=3.2MB gcv_mode='svd' n_samples=1000
elapsed=0.05s peak=31.1MB gcv_mode='svd' n_samples=10000
# PR using full svd
elapsed=0.05s peak=23.8MB gcv_mode='svd' n_samples=1000
elapsed=7.29s peak=2297.0MB gcv_mode='svd' n_samples=10000

@antoinebaker (Contributor, Author) commented Oct 20, 2025

I think I'll try using the reduced svd of $X$ when n_samples < n_features and the reduced svd of $X^T$ when n_samples > n_features, a bit like the "eigen" mode where we either use the Gram $XX^T$ or covariance $X^TX$ matrix.

EDIT: actually that's a dead end :( Using the reduced svd and the Woodbury identity yields the same formula as main for $G^{-1}$.
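
For the record, here is a sketch of why (writing the reduced svd as $X = U S V^T$ with $U$ of shape (n_samples, K)); the Woodbury identity gives

$G^{-1} = (XX^T + \alpha I)^{-1} = \frac{1}{\alpha} I - \frac{1}{\alpha^2} X \left( I_K + \frac{X^T X}{\alpha} \right)^{-1} X^T = \frac{1}{\alpha} I + U \text{ diag}\left( \frac{1}{s^2 + \alpha} - \frac{1}{\alpha} \right) U^T$

which is exactly the formula used in main, with the same division by alpha.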

@ogrisel (Member) commented Oct 24, 2025

The performance overhead is indeed not acceptable.

To better handle the n_samples >> n_features case, I think we should implement a new variant of RidgeCV for k-fold (or any user-defined CV strategy) by internally forming X[train_fold_i].T @ X[train_fold_i] once for each fold and then calling a cheap xp.linalg.solve for each (fold_idx, alpha) combination (similarly to what Ridge(solver="cholesky") does for a single alpha and no CV).

EDIT: link to the relevant part of Ridge: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/_ridge.py#L217-L236
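
A minimal numpy sketch of that idea (hypothetical helper, not part of this PR; the real thing would go through the array API namespace rather than numpy):

import numpy as np
from sklearn.model_selection import KFold

def ridge_cv_gram_kfold(X, y, alphas, n_splits=5):
    # Prototype: form the Gram matrix once per training fold,
    # then solve the regularized normal equations for each alpha.
    n_features = X.shape[1]
    mse = np.zeros(len(alphas))
    for train, test in KFold(n_splits=n_splits).split(X):
        gram = X[train].T @ X[train]   # (n_features, n_features), computed once per fold
        Xty = X[train].T @ y[train]
        for i, alpha in enumerate(alphas):
            coef = np.linalg.solve(gram + alpha * np.eye(n_features), Xty)
            mse[i] += np.mean((y[test] - X[test] @ coef) ** 2)
    return alphas[np.argmin(mse)]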

@ogrisel (Member) commented Oct 24, 2025

@antoinebaker would you be interested in trying to prototype the above approach (possibly in a google colab notebook using the array API) and run some benchmarks to check that it's a worthwhile approach (both from a computational and a numerical precision standpoint)?

@ogrisel (Member) commented Oct 24, 2025

Also, I think for Ridge we should introduce a new solver, very similar to what solver="cholesky" does but without the assume_a="pos". This would allow us to:

  • naturally support the array API, because assume_a is not part of the xp.linalg.solve spec;
  • not have to fall back to the slow solver="svd" in the presence of collinear features and too small a regularization, while still converging to the minimum norm solution in the alpha=0 limit.

The assume_a="pos" option typically only brings less than a 20% perf improvement (if I recall correctly from my dabblings in #29318), at the cost of the two limitations above.
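
To illustrate the difference (a rough sketch, not the actual scikit-learn code): the current "cholesky" path relies on scipy.linalg.solve with assume_a="pos", whereas the proposed variant would use the generic solve that is part of the array API spec.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = rng.standard_normal(50)
alpha = 1e-3

A = X.T @ X + alpha * np.eye(X.shape[1])
b = X.T @ y

coef_pos = linalg.solve(A, b, assume_a="pos")  # SciPy-only path, assumes A is positive definite
coef_gen = np.linalg.solve(A, b)               # generic solve, available through the array API namespace

print(np.allclose(coef_pos, coef_gen))         # True on this well-conditioned problem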

@antoinebaker (Contributor, Author) commented Oct 24, 2025

I think the idea in #32506 (comment) sounds good and worth pursuing when the number of folds is small (say cv=5 folds). I'll try, in a dedicated PR (or a colab as suggested), to improve RidgeCV with a given cv.

However, for this PR and issue #32459 specifically, i.e. for the gcv_mode option, I think we still need to figure out a solution for the n_samples >> n_features case. With the gcv_mode option there are n_samples "folds", each trained on n_samples - 1 samples (leave-one-out cross-validation). It's only manageable because the implementation uses a clever trick to compute all the leave-one-out errors, for any alpha, once $\text{diag}(G^{-1})$ and $c = G^{-1} y$ are known.
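
For reference, the trick is the leave-one-out identity (this is the quantity computed as squared_errors = (c / G_inverse_diag) ** 2 in _ridge.py, quoted further down in this thread): $e_i^{\text{loo}} = \frac{c_i}{[G^{-1}]_{ii}} = \frac{[G^{-1} y]_i}{[G^{-1}]_{ii}}$, so a single decomposition gives all n_samples errors for every alpha at once.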

Unfortunately, I still don't see how to make gcv_mode="svd" work in the n_samples >> n_features and small alpha case:

  • the reduced svd is numerically unstable
  • the full svd is numerically stable but very expensive (time/memory)

@ogrisel (Member) commented Oct 27, 2025

> The full svd will be costly (time and memory) when n_samples >> n_features. Maybe we should try a trick similar to main (reduced svd), but without the division by alpha. Not sure how to do it though :)

Wouldn't it be possible to first estimate H = diag(alpha * G ** -1) and then compute c = H * y / alpha? Maybe that will be more stable to compute when alpha is very small?

@antoinebaker (Contributor, Author)

> Wouldn't it be possible to first estimate H = diag(alpha * G ** -1) and then compute c = H * y / alpha? Maybe that will be more stable to compute when alpha is very small?

Alas, computing alpha*c and alpha*diag(G ** -1) makes the instability worse :( 24 tests now fail with crazy coefficients.

@antoinebaker (Contributor, Author)

For the n_samples < n_features case (wide data), there is, I think, a satisfying solution.
As K = min(n_samples, n_features) = n_samples, the reduced svd (full_matrices=False) already gives the full (n_samples, n_samples) U matrix, so we can use the stable $G^{-1} = U \text{ diag}\left( \frac{1}{ s^2 + \alpha} \right)U^T$ formula. This is done in b8b89ff.
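
In code, the stable path boils down to something like this (a condensed numpy sketch, not the exact helper used in the PR):

import numpy as np

def gram_inverse_diag_and_c(X, y, alpha):
    # Assumes n_samples <= n_features, so the reduced SVD already yields a square U.
    U, s, _ = np.linalg.svd(X, full_matrices=False)  # U is (n_samples, n_samples)
    w = 1.0 / (s**2 + alpha)                         # eigenvalues of G^{-1} = (X X^T + alpha I)^{-1}
    G_inverse_diag = (U**2) @ w                      # diag(U diag(w) U^T)
    c = U @ (w * (U.T @ y))                          # G^{-1} y
    return G_inverse_diag, c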

For this case the PR introduces no overhead compared to main:

benchmark(n_samples=1000, n_features=10_000, alphas=[1e-15], gcv_mode="svd")
elapsed=2.14s peak=236.8MB gcv_mode='svd' n_samples=1000 n_features=10000 # PR
elapsed=1.93s peak=236.8MB gcv_mode='svd' n_samples=1000 n_features=10000 # main

and fixes the numerical instability observed in main. I don't have a good minimal test for this case. We expect Ridge(alpha -> 0) to recover the minimum norm least squares solution in the noiseless underdetermined case. Plotting the fitted coef against the minimum norm least squares solution gives, in this PR:
[image: fitted coef vs minimum norm least squares coef, this PR]
while in main it gives:
[image: fitted coef vs minimum norm least squares coef, main]
with a runtime warning:

.../sklearn/linear_model/_ridge.py:2253: RuntimeWarning: divide by zero encountered in divide
  squared_errors = (c / G_inverse_diag) ** 2

@antoinebaker (Contributor, Author) commented Oct 28, 2025

> I don't have a good minimal test for this case.

As one can see, we do not recover exactly the minimum norm least squares solution (as found by scipy.linalg.lstsq), but at least it's strongly correlated. For comparison, that's also the case for gcv_mode="eigen":
[image: fitted coef vs minimum norm least squares coef, gcv_mode="eigen"]

So maybe "allclose" recovery is too much to ask, and checking that the two are correlated is good enough?
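
A sketch of what such a correlation-based test could look like (hypothetical; shapes and the alpha value are picked arbitrarily for illustration):

import numpy as np
from scipy import linalg
from sklearn.linear_model import RidgeCV

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 100))   # noiseless underdetermined problem
y = X @ rng.standard_normal(100)

coef_lstsq, *_ = linalg.lstsq(X, y)  # minimum norm least squares solution
reg = RidgeCV(alphas=[1e-15], gcv_mode="svd").fit(X, y)

# A test could assert that this correlation exceeds some threshold
# instead of requiring allclose recovery.
print(np.corrcoef(reg.coef_, coef_lstsq)[0, 1])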

@antoinebaker (Contributor, Author)

> For the n_samples < n_features case (wide data), there is, I think, a satisfying solution.

However, we still have the performance overhead issue for the n_samples > n_features case.

@ogrisel (Member) commented Oct 29, 2025

> However, we still have the performance overhead issue for the n_samples > n_features case.

I think we can be pragmatic and fix the numerical stability problem only when n_samples < n_features, which is likely the most useful case (and in particular happens when using the default gcv_mode="auto").

Maybe we could also expand the docstring of gcv_mode to warn that choosing gcv_mode="svd" with n_samples > n_features can lead to numerically unstable results.

@ogrisel (Member) commented Oct 29, 2025

BTW @antoinebaker, did you set an explicit value for the cond parameter of scipy.linalg.lstsq in your experiments (#32506 (comment)), as we had to do in #30040?

@antoinebaker (Contributor, Author) commented Oct 29, 2025

> BTW @antoinebaker, did you set an explicit value for the cond parameter of scipy.linalg.lstsq in your experiments (#32506 (comment)), as we had to do in #30040?

Ah good point, I just tried, but it doesn't affect the solution. However, good news! We recover exactly the minimum norm least squares solution for fit_intercept=False. I need to think about fit_intercept=True, I don't recall what we expect in this case.

@antoinebaker (Contributor, Author)

Interestingly, the full svd induces less overhead than the "eigen" option, which is very slow in the n_samples >> n_features regime:

# eigen, numerically stable
benchmark(n_samples=1000, n_features=100, alphas=[1e-15], gcv_mode="eigen")
benchmark(n_samples=10_000, n_features=100, alphas=[1e-15], gcv_mode="eigen")
elapsed=0.10s peak=23.8MB gcv_mode='eigen' n_samples=1000 n_features=100
elapsed=118.13s peak=2297.0MB gcv_mode='eigen' n_samples=10000 n_features=100
# this PR, full svd, numerically stable
benchmark(n_samples=1000, n_features=100, alphas=[1e-15], gcv_mode="svd")
benchmark(n_samples=10_000, n_features=100, alphas=[1e-15], gcv_mode="svd")
elapsed=0.03s peak=23.8MB gcv_mode='svd' n_samples=1000 n_features=100
elapsed=7.18s peak=2297.0MB gcv_mode='svd' n_samples=10000 n_features=100
# main, reduced svd, numerically unstable
benchmark(n_samples=1000, n_features=100, alphas=[1e-15], gcv_mode="svd")
benchmark(n_samples=10_000, n_features=100, alphas=[1e-15], gcv_mode="svd")
elapsed=0.02s peak=3.2MB gcv_mode='svd' n_samples=1000 n_features=100
elapsed=0.05s peak=31.1MB gcv_mode='svd' n_samples=10000 n_features=100

@ogrisel (Member) commented Oct 29, 2025

> Interestingly, the full svd induces less overhead than the "eigen" option, which is very slow in the n_samples >> n_features regime.

I am so confused now. My previous comment (#32506 (comment)) was wrong: the default gcv_mode="auto" means that "svd" is used when n_samples > n_features. So we still have a tradeoff between performance and numerical stability in the n_samples >> n_features regime.

@ogrisel (Member) commented Oct 29, 2025

Was the numerical stability problem of the candidate change described in #32506 (comment) observed in the n_samples >> n_features regime?

@antoinebaker (Contributor, Author)

> Was the numerical stability problem of the candidate change described in #32506 (comment) observed in the n_samples >> n_features regime?

Yep, it breaks test_ridge_gcv_noiseless, which has n_samples > n_features, and other tests too.

Linked issue (#32459): BUG RidgeCV with gcv_mode="svd" is unstable when alpha is very small