Remove unnecessary dtype conversion from pairwise_distances_argmin_* #32511
Thanks for your offer to add my tests. I'd appreciate it! This is the test that I wrote. It confirms that
@pushkar-hue I added you as a collaborator to my fork; I believe that may be the easiest way for you to push the tests there.
…for incorrect dtype conversion
@IgnacioJPickering I have just pushed the commit for the tests. You can take a look and let me know if there's anything I need to change. Thanks again for the collaboration!
Thanks for the fix @IgnacioJPickering and @pushkar-hue. Could you please add a changelog entry for this?
See instructions in https://github.com/scikit-learn/scikit-learn/blob/main/doc/whats_new/upcoming_changes/README.md for details.
Hi @ogrisel, the changelog has been added. The PR should be ready for final review now. Thanks!
@pushkar-hue can you have a look at the codecov report and see if there is an easy way to cover those lines by a small extension to the tests (e.g. testing with Python lists)?
I really apologize for the confusion. @IgnacioJPickering had already done that for this very reason, but when I added the changelog and tried to clean up my commit history, the test commit was somehow removed. I have re-added those tests, which should address the codecov warning. Again, I apologize for my mistake.
@ogrisel no problem, I've reinstated the correct tests.
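The kind of regression test discussed above can be sketched with only the standard library. In this sketch `toy_argmin` is a hypothetical stand-in for `pairwise_distances_argmin` (it is not the scikit-learn implementation) and `ConversionWarning` stands in for `DataConversionWarning`. The point is the pattern: record warnings around the call, and check that boolean input, given here as plain Python lists, produces none.

```python
import warnings


class ConversionWarning(UserWarning):
    """Hypothetical stand-in for sklearn's DataConversionWarning."""


def toy_argmin(X, Y):
    """Hypothetical stand-in: index of the closest row of Y for each row of X.

    Distance is the number of mismatching positions. A warning is emitted
    only if some value is not already a bool.
    """
    if any(not isinstance(v, bool) for row in (*X, *Y) for v in row):
        warnings.warn("Data was converted to boolean", ConversionWarning)
    X = [[bool(v) for v in row] for row in X]
    Y = [[bool(v) for v in row] for row in Y]
    result = []
    for x in X:
        dists = [sum(a != b for a, b in zip(x, y)) for y in Y]
        result.append(dists.index(min(dists)))
    return result


def warns_on(X, Y):
    """Return True if calling toy_argmin(X, Y) emits a ConversionWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        toy_argmin(X, Y)
    return any(issubclass(w.category, ConversionWarning) for w in caught)


bool_X = [[True, False], [False, False]]
bool_Y = [[False, False], [True, False]]
int_X = [[1, 0], [0, 0]]
```

A real test would call the scikit-learn function directly and assert on `DataConversionWarning`; the stand-ins only keep the sketch self-contained.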
        Y is not None and not xp.isdtype(Y.dtype, "bool")
    ):
        msg = f"Data was converted to boolean for metric {metric}"
        warnings.warn(msg, DataConversionWarning)
Why not let pairwise_distances give this warning? The concern here is also that we issue this warning before we actually do the conversion.
You're right. I looked into it, and it looks like @IgnacioJPickering just moved the existing logic inside the helper function.
The old code in pairwise_distances had a manual block that issued this exact warning:
    dtype = bool if metric in PAIRWISE_BOOLEAN_FUNCTIONS else "infer_float"
    if dtype is bool and (X.dtype != bool or (Y is not None and Y.dtype != bool)):
        msg = "Data was converted to boolean for metric %s" % metric
        warnings.warn(msg, DataConversionWarning)
    X, Y = check_pairwise_arrays(
        X, Y, dtype=dtype, ensure_all_finite=ensure_all_finite
    )
    # precompute data-derived metric params
    params = _precompute_metric_params(X, Y, metric=metric, **kwds)
    kwds.update(**params)
    if effective_n_jobs(n_jobs) == 1 and X is Y:
        return distance.squareform(distance.pdist(X, metric=metric, **kwds))
    func = partial(distance.cdist, metric=metric, **kwds)
It is now moved into the helper functions and used as such:
    X, Y, dtype = _find_dtype_for_check_pairwise_arrays(X, Y, metric)
    X, Y = check_pairwise_arrays(
        X, Y, dtype=dtype, ensure_all_finite=ensure_all_finite
    )
    # precompute data-derived metric params
    params = _precompute_metric_params(X, Y, metric=metric, **kwds)
    kwds.update(**params)
    if effective_n_jobs(n_jobs) == 1 and X is Y:
        return distance.squareform(distance.pdist(X, metric=metric, **kwds))
    func = partial(distance.cdist, metric=metric, **kwds)
    return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
We are actually issuing the warning after the conversion, as it was before the refactor.
@lucyleeow I understand the concern: it's not great that the warning is about what a different function does. (I believe what @pushkar-hue means is that this is the way it was done before in the code too, but I see it's not optimal.)
The reason we can't delegate the warning to pairwise_distances is because that function may not get called if ArgKMin is usable for the metric.
Do you think it would be better to write a wrapper _check_pairwise_arrays_for_metric(...)? This would find the dtype, raise the warning and also call check_pairwise_arrays which is what ultimately does the conversion.
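The wrapper proposed above could look roughly like the following sketch. Everything in it is a hypothetical stand-in: arrays are plain nested lists, `BOOLEAN_METRICS` is an illustrative subset assumed to mirror PAIRWISE_BOOLEAN_FUNCTIONS, `ConversionWarning` replaces DataConversionWarning, and `_convert` plays the role of check_pairwise_arrays. It only shows the shape of the idea: the dtype decision, the warning and the conversion live in one function.

```python
import warnings

# Illustrative subset (assumption); the real list is defined in scikit-learn.
BOOLEAN_METRICS = {"jaccard", "dice", "rogerstanimoto"}


class ConversionWarning(UserWarning):
    """Hypothetical stand-in for sklearn's DataConversionWarning."""


def _is_bool(array):
    """True if every entry of a nested list is already a Python bool."""
    return all(isinstance(v, bool) for row in array for v in row)


def _convert(array, target):
    """Toy replacement for the validation step: cast every entry."""
    return [[target(v) for v in row] for row in array]


def check_arrays_for_metric(X, Y, metric):
    """Pick the target dtype, warn if needed, then convert, all in one place."""
    target = bool if metric in BOOLEAN_METRICS else float
    Y = X if Y is None else Y
    if target is bool and not (_is_bool(X) and _is_bool(Y)):
        warnings.warn(
            f"Data was converted to boolean for metric {metric}",
            ConversionWarning,
        )
    return _convert(X, target), _convert(Y, target)
```

With this layout the warning and the conversion it describes cannot drift apart, which is the concern raised in the review.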
@lucyleeow I went ahead and did this, since I thought it was cleaner. I also clarified the comment and made it a bit more precise. Hopefully things are much clearer now.
The reason we can't delegate the warning to pairwise_distances is because that function may not get called if ArgKMin is usable for the metric.
Aren't bool metrics specifically excluded from ArgKMin?
    def valid_metrics(cls) -> List[str]:
        excluded = {
            # PyFunc cannot be supported because it necessitates interacting with
            # the CPython interpreter to call user defined functions.
            "pyfunc",
            "mahalanobis",  # is numerically unstable
            # In order to support discrete distance metrics, we need to have a
            # stable simultaneous sort which preserves the order of the indices
            # because there generally is a lot of occurrences for a given values
            # of distances in this case.
            # TODO: implement a stable simultaneous_sort.
            "hamming",
            *BOOL_METRICS,
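To make the question concrete, here is a minimal sketch of how an exclusion set like the one quoted above keeps boolean metrics away from the fast path. The metric lists are illustrative assumptions, not the lists defined in scikit-learn, and the real `is_usable_for` check may look at more than the metric name.

```python
# Illustrative names only (assumption); the real lists live in scikit-learn.
ALL_METRICS = [
    "euclidean", "manhattan", "chebyshev",
    "hamming", "jaccard", "dice", "mahalanobis", "pyfunc",
]
BOOL_METRICS = ["jaccard", "dice"]


def valid_metrics():
    """Metrics the toy fast path accepts: everything not explicitly excluded."""
    excluded = {"pyfunc", "mahalanobis", "hamming", *BOOL_METRICS}
    return sorted(set(ALL_METRICS) - excluded)


def is_usable_for(metric):
    """Toy dispatch check based on the metric name alone."""
    return metric in valid_metrics()
```

Under these assumptions a boolean metric such as "jaccard" never reaches the fast path, so it always falls through to the generic pairwise_distances route.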
The issue is that check_pairwise_arrays has to be called before the ArgKmin.is_usable_for(...)
Why is this a problem?
I don't understand https://github.com/scikit-learn/scikit-learn/pull/32511/files#r2485450646: check_pairwise_arrays would no longer be called at all for PAIRWISE_BOOLEAN metrics, so error messages about inconsistent shapes would not be raised when passing invalid inputs for boolean metrics.
@ogrisel check_pairwise_arrays would still be called, since for boolean metrics the function would delegate to pairwise_distances_chunked, which itself delegates to pairwise_distances, which calls check_pairwise_arrays. I recognize it is a bit confusing though, but I think the logic checks out.
@lucyleeow It is not a problem, but we must first filter for boolean arrays. Calling check_pairwise_arrays in these functions without first filtering for boolean metrics is the bug that is currently in main.
By default, if check_pairwise_arrays is called with infer_float it converts the arrays to float unconditionally, with no warning, which is not necessary since the arrays are bool to begin with.
Afterwards a second check in pairwise_distances checks the dtypes of the arrays, converts the arrays back to bool, and raises a conversion warning. This means there are 2 casts and 1 warning where there should have been none.
If I filter for PAIRWISE_BOOLEAN first then I can call check_pairwise_arrays only in the case that ArgKmin may be called, and delegate the rest of the checks to pairwise_distances_chunked.
@lucyleeow @ogrisel
From the comments in the PR I believe this, together with a comment specifying why check_pairwise_arrays is being called early in the case that ArgKmin may be called, is preferable, since it seems to me the code is still hard to understand.
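The double cast described above can be traced with a toy model. This is a sketch under stated assumptions, not scikit-learn code: dtypes are plain strings, `toy_check` stands in for check_pairwise_arrays, `toy_pairwise_distances` for pairwise_distances, and "infer_float" is simplified to always mean float64. `BOOLEAN_METRICS` is an illustrative subset.

```python
BOOLEAN_METRICS = {"jaccard", "dice"}  # illustrative subset (assumption)


def toy_check(dtype, target, log):
    """Toy check_pairwise_arrays: record a cast whenever the dtype changes."""
    if dtype != target:
        log.append(f"cast {dtype} -> {target}")
    return target


def toy_pairwise_distances(dtype, metric, log):
    """Toy pairwise_distances: warn on non-bool input for boolean metrics."""
    target = "bool" if metric in BOOLEAN_METRICS else "float64"
    if target == "bool" and dtype != "bool":
        log.append("warning: data was converted to boolean")
    return toy_check(dtype, target, log)


def argmin_as_in_main(dtype, metric):
    """Behaviour described for main: the early check always runs."""
    log = []
    dtype = toy_check(dtype, "float64", log)  # unconditional early check
    toy_pairwise_distances(dtype, metric, log)
    return log


def argmin_with_guard(dtype, metric):
    """Guarded variant: skip the early check for boolean metrics."""
    log = []
    if metric not in BOOLEAN_METRICS:
        dtype = toy_check(dtype, "float64", log)  # early check only where needed
    toy_pairwise_distances(dtype, metric, log)
    return log
```

For boolean input and a boolean metric, `argmin_as_in_main("bool", "jaccard")` logs two casts and one warning, while `argmin_with_guard("bool", "jaccard")` logs nothing.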
Ok, I've modified the code so that the checks are delegated to pairwise_distances_chunked. I believe this should get rid of the confusion.
The TypeError is raised only in the case where sparse arrays are forwarded to pairwise_distances, which is a single place in the code, the same place where the warning is raised and the cast is performed.
This only required an if check around the check_pairwise_arrays call, which is only performed if the metric is not a PAIRWISE_BOOLEAN metric, so we avoid the unnecessary cast that currently happens in main. I added a comment for extra clarity.
In the end I think my initial fix was overly complicated; this gets rid of the issue with minimal modifications.
Co-authored-by: Lucy Liu <jliu176@gmail.com>
…ithub.com:IgnacioJPickering/scikit-learn into fix/ipickering/remove-incorrect-dtype-conversions
One comment but otherwise looks good.
@ogrisel I think this may be worth a second review from you as the code has changed quite a bit.
Reference Issues/PRs
Fixes #32495
What does this implement/fix? Explain your changes.
In pairwise_distances_argmin_*, an initial check with check_pairwise_arrays(...) used the default argument dtype="infer_float". For boolean metrics this triggered an unnecessary type conversion to float64, even when the arrays were originally bool. When the arrays were forwarded to pairwise_distances, another call to check_pairwise_arrays(...) cast the arrays back to bool, and this triggered a warning that there was data conversion.
I've added a new utility function _find_floating_or_bool_dtype_allow_sparse(X, Y, metric, xp), which works in an equivalent way to _find_floating_dtype_allow_sparse but is metric-aware, and returns bool for boolean metrics.
Additionally, I've factored out the warnings into another helper function to reduce duplication.
@pushkar-hue