ENH: Support the dot family for user-dtypes #31574
dot family for new-style user DTypes → dot family for user-dtypes
| n = a.shape[axis_a] | ||
| m = math.prod(a.shape[ax] for ax in notin_a) | ||
| p = math.prod(b.shape[ax] for ax in notin_b) | ||
| at = ascontiguousarray(a.transpose(notin_a + [axis_a]).reshape(m, n)) |
There was a problem hiding this comment.
np.dot must produce out= as a C-style array, so I think calling ascontiguousarray might be OK?
Ideally matmul/BLAS should handle any layout (for performance-sensitive cases we can revert to letting the dtype's BLAS handle this).
If we revert, this will require a new release of numpy-quaddtype (PR numpy/numpy-quaddtype#108 has already landed).
There was a problem hiding this comment.
By default, any ufunc produces a C-contiguous array, so I don't think you'd need to do anything.
Test failures seem independent of this PR.
mhvk
left a comment
There was a problem hiding this comment.
Some comments in-line. Functionally, the one about complex conjugation for correlate is the most important, but the others should help simplify the implementation. In particular, please check carefully whether all those ascontiguousarray calls are in fact necessary (by having tests with non-contiguous arrays).
| res = np.zeros(res_shape, dtype=res_dtype) | ||
| elif a.ndim == 1 and b.ndim == 1: | ||
| # 1-D x 1-D: inner product | ||
| res = np.matmul(ascontiguousarray(a), ascontiguousarray(b)) |
There was a problem hiding this comment.
Here, you could just do return np.matmul(a, b, out=out) - by default, output is contiguous.
| p = math.prod(b.shape[ax] for ax in notin_b) | ||
| at = ascontiguousarray(a.transpose(notin_a + [axis_a]).reshape(m, n)) | ||
| bt = ascontiguousarray(b.transpose([axis_b] + notin_b).reshape(n, p)) | ||
| res = np.matmul(at, bt) |
There was a problem hiding this comment.
Are you sure the detailed handling of the axes is necessary? Remember that with np.matmul, like with all gufuncs, you have the option of explicitly passing in the axes that are contracted; see https://numpy.org/doc/stable/reference/ufuncs.html#optional-keyword-arguments
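As a small illustration of the axes keyword mentioned above, here is a sketch that contracts over the first axis of both operands without any explicit transpose or reshape. The array shapes are made up for the example and are not taken from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 3))  # contraction axis is axis 0 (length 4)
b = rng.standard_normal((4, 5))  # contraction axis is axis 0 (length 4)

# matmul's core signature is (n,k),(k,m)->(n,m). Each tuple lists which axes
# of one operand play those core roles, in the order: a, b, output.
# For `a`, axis 1 acts as n and axis 0 acts as k, i.e. `a` is used transposed.
res = np.matmul(a, b, axes=[(1, 0), (0, 1), (0, 1)])

# Reference result using an explicit transpose.
expected = a.T @ b
```

This only covers the stacked (broadcasting) behaviour of matmul; it does not by itself reproduce the N-D result shape of np.dot.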
There was a problem hiding this comment.
Unfortunately, for ndim > 2 dot has a funny behavior that the result shape is a.shape + b.shape minus the reduction axes.
I.e. it doesn't actually do a proper broadcasting/stacked operation, it is the same as tensordot(..., axes=[-1, -2]).
However, tensordot actually does this precise dance already. So we could consider being a bit annoying and just rejecting it for this path (we should deprecate it on the other path I guess) with a note of "please use tensordot instead".
There is another good argument for this, the code has:
#if defined(HAVE_CBLAS)
if (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2 &&
so we don't even use BLAS for this case, which is ridiculous, and unless the arrays are small, users are better off using the clearer tensordot() anyway.
(If we do that this might be simple enough to just move it to C.)
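To make the shape difference described above concrete, the following sketch compares np.dot, np.tensordot and np.matmul on 3-D inputs. The shapes are arbitrary example values.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))
b = rng.standard_normal((2, 4, 6))

# N-D dot contracts the last axis of `a` with the second-to-last axis of `b`
# and concatenates all remaining axes: a.shape[:-1] + b.shape[:-2] + b.shape[-1:].
dot_res = np.dot(a, b)

# tensordot with the same pair of axes gives the same values and shape.
tensordot_res = np.tensordot(a, b, axes=(-1, -2))

# matmul instead treats the leading axis as a stack and broadcasts over it.
matmul_res = np.matmul(a, b)

dot_shape = dot_res.shape
matmul_shape = matmul_res.shape
```

The matmul result corresponds to the "diagonal" of the dot result over the two leading stack axes, which is why the two functions cannot simply be swapped for ndim > 2.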
There was a problem hiding this comment.
The fallback's _dot_contract is the tensordot-style reshape => matmul, which is faster than the legacy N-D dot, so rejecting it isn't a perf win; it's purely about not carrying the deprecation-worthy N-D dot semantics. I'm fine either way.
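For readers following along, here is a rough sketch of the tensordot-style reshape-then-matmul idea, pieced together from the diff fragments quoted in this thread. The function name dot_contract_sketch and the way the contraction axes are chosen are assumptions for illustration; this is not the PR's actual _dot_contract, and it only handles operands that are at least 1-D.

```python
import math
import numpy as np

def dot_contract_sketch(a, b):
    """Hypothetical N-D dot via transpose + reshape + a single 2-D matmul."""
    a = np.asarray(a)
    b = np.asarray(b)
    # np.dot contracts the last axis of `a` with the second-to-last axis of
    # `b` (or the only axis when `b` is 1-D).
    axis_a = a.ndim - 1
    axis_b = max(b.ndim - 2, 0)
    notin_a = [ax for ax in range(a.ndim) if ax != axis_a]
    notin_b = [ax for ax in range(b.ndim) if ax != axis_b]

    n = a.shape[axis_a]
    if b.shape[axis_b] != n:
        raise ValueError("shapes are not aligned for dot")
    m = math.prod(a.shape[ax] for ax in notin_a)
    p = math.prod(b.shape[ax] for ax in notin_b)

    # Collapse the free axes so the contraction is one (m, n) @ (n, p) product.
    at = a.transpose(notin_a + [axis_a]).reshape(m, n)
    bt = b.transpose([axis_b] + notin_b).reshape(n, p)
    res = np.matmul(at, bt)

    # Restore the free axes of `a` followed by the free axes of `b`.
    out_shape = [a.shape[ax] for ax in notin_a] + [b.shape[ax] for ax in notin_b]
    return res.reshape(out_shape)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
y = rng.standard_normal((5, 4, 6))
sketch_res = dot_contract_sketch(x, y)
```

Note that the sketch deliberately omits the ascontiguousarray calls under discussion; reshape copies when it has to.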
There was a problem hiding this comment.
Not sure I am following precisely, but I am leaning to just reject it with a "use tensordot" message. Let's ping @mhvk, though.
(That would mean we should aim for a deprecation as a tendency, I guess.)
| ``_pyarray_correlate`` already handles empty checks, complex conjugation | ||
| and the ``len(a) < len(v)`` swap/reverse, so ``a`` is the longer operand. | ||
| """ | ||
| a = ascontiguousarray(np.asarray(a).ravel()) |
There was a problem hiding this comment.
Did you check all these ascontiguousarray calls are in fact necessary? I rather doubt it, since matmul produces a contiguous array.
| pad = zeros(n2 - 1, dtype=res_dtype) | ||
| apad = concatenate([pad, a, pad]) | ||
| win = apad[arange(n1 + n2 - 1)[:, None] + arange(n2)] | ||
| full = np.matmul(win, v) |
There was a problem hiding this comment.
I think you need vecdot here to ensure complex numbers are properly conjugated (do add a test! also for dot, where complex conjugation is not done [incorrectly, but the fallback needs to mimic that mistake...])
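A quick way to see the conjugation difference raised above is to compare the three operations on small complex vectors. The values below are arbitrary. np.vecdot requires NumPy 2.0 or later, so the sketch falls back to np.vdot, which computes the same thing for 1-D inputs; that fallback is an addition for portability and not something from the PR.

```python
import numpy as np

a = np.array([1 + 2j, 3 - 1j, 0.5j])
v = np.array([2 - 1j, 1j, 1 + 1j])

# np.vecdot conjugates its first argument; np.vdot does the same for 1-D
# inputs and is used here only if np.vecdot is unavailable (NumPy < 2.0).
conj_inner = getattr(np, "vecdot", np.vdot)

plain = np.dot(a, v)                        # sum(a * v), no conjugation
corr = np.correlate(a, v, mode="valid")[0]  # sum(a * conj(v))
conj = conj_inner(v, a)                     # sum(conj(v) * a)
```

With equal-length inputs, the "valid" correlation has a single element, which makes it directly comparable to the inner products.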
@mhvk the element-wise ones do not; they preserve the order. So the (0-D x n-D) case (just a multiply) will not go to C-contiguous order. BLAS ops always will. I kept this earlier because the quaddtype PR hadn't merged yet, and in case any dtype did not conform to this.
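The layout distinction described above can be checked with builtin float64 arrays. This is only a sketch with standard dtypes; whether a given user dtype behaves the same way is exactly the open question in this thread.

```python
import numpy as np

# Two Fortran-ordered (column-major) inputs.
f = np.asfortranarray(np.arange(12.0).reshape(3, 4))
g = np.asfortranarray(np.arange(20.0).reshape(4, 5))

# Element-wise ufuncs default to order="K", so the scalar-times-array case
# keeps the Fortran layout of its array operand.
scaled = np.multiply(np.float64(2.0), f)

# matmul allocates its 2-D output in C order regardless of the input layout.
prod = np.matmul(f, g)
```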
| ref = np.dot(a, b) | ||
| got = _dot_fallback(a, b) | ||
| assert_array_equal(got, ref) | ||
| assert got.shape == np.shape(ref) |
There was a problem hiding this comment.
You can get the shape check by setting strict=True in assert_array_equal.
Also, I'm actually not 100% sure we should care about contiguity, but if we do, we should test it to ensure we don't regress, by adding something like assert got.strides == ref.strides (same in other tests).
|
In quaddtype I will try to fast-forward the work there so the support merges cleanly. Should we perhaps change quaddtype's test dependency to fetch and build directly from the repo? @seberg
|
Also I tried, |
I already have one or two tests with quaddtype, but I am not sure it is installed yet. The problem is that even importing quaddtype can break some other tests. So you need a bit of tedious skipping based on the quaddtype version (without actually importing it). But, I think that's fine... we can factor that out into a little private helper. (I think you should be able to find that hack quickly if you search.)
Done, added in this PR since it didn't disturb the context much.
|
This is ready for 2nd round of review. |
| except PackageNotFoundError: | ||
| pytest.skip("numpy_quaddtype is not installed") | ||
| if _pep440.Version(installed) < _pep440.Version(MIN_VERSION): | ||
| pytest.skip(f"numpy_quaddtype >= {MIN_VERSION} is required") |
There was a problem hiding this comment.
Did you mean to set MIN_VERSION to 1.1, since I think we are at 1.0 (and we need to reject it)?
I guess we can make it customizable if needed. The current min-version is vital because even importing is wrong, but after 1.1 is out that isn't the case anymore.
There was a problem hiding this comment.
Yeah that also works, just weird habit of having even jumps for sub-versions :)
| n1, n2 = a.size, v.size | ||
| pad = zeros(n2 - 1, dtype=np.result_type(a.dtype, v.dtype)) | ||
| apad = concatenate([pad, a, pad]) | ||
| win = apad[arange(n1 + n2 - 1)[:, None] + arange(n2)] |
There was a problem hiding this comment.
This can be a sliding_window_view I think. I think matmul will be OK with core dimensions that have self overlap (if not, we could pass axes to make it a non-core dimension). Although that is of course not BLAS compatible (for our code, I think it might even go back into einsum).
Otherwise, you would bloat memory usage like crazy here. (I suspect we can do this, but I am slightly unsure, and it may make sense to split the PR into two, since I think dot is a bit clearer.)
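The memory point above can be illustrated with float64 arrays, where sliding_window_view does work. The sketch builds the same windows both ways and checks that the view shares memory with the padded input while the fancy-indexed version is a full copy. It says nothing about user dtypes, for which sliding_window_view is reported later in this thread not to work.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

a = np.array([1.0, 2.0, 3.0])
v = np.array([0.0, 1.0, 0.5])
n1, n2 = a.size, v.size

pad = np.zeros(n2 - 1, dtype=np.result_type(a.dtype, v.dtype))
apad = np.concatenate([pad, a, pad])

# Fancy indexing materializes every window: (n1 + n2 - 1) * n2 elements.
win_copy = apad[np.arange(n1 + n2 - 1)[:, None] + np.arange(n2)]

# The strided view exposes the same windows without copying any data.
win_view = sliding_window_view(apad, n2)

# Each row dotted with v gives one lag of the "full" correlation (real case).
full = np.matmul(win_view, v)
```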
There was a problem hiding this comment.
I am also concerned about the memory bloat here, and sliding_window_view doesn't support user-dtypes.
You're right to split this. I will open a separate PR for correlate with the current implementation, and we can make further changes there.
There was a problem hiding this comment.
sliding_window_view user-dtypes don't support.
Ah, that is true and annoying :/. Could possibly not go via __array_interface__ there to make it work, but I am not sure that is better than trying a fix for __array_interface__ (heck even adding a dtype field).
PR summary
This PR adds the dot and correlate fallback methods for user-dtypes. It uses the multiply and matmul operations to delegate the operations as per the input shape.

closes #30793
AI Disclosure
None