ENH: Support the dot family for user-dtypes#31574

Open
SwayamInSync wants to merge 5 commits into numpy:main from SwayamInSync:enh/dot-family-user-dtypes-30793

@SwayamInSync

Member

PR summary

This PR adds fallback implementations of dot and correlate for user-defined dtypes. The fallbacks delegate to multiply or matmul, depending on the input shapes.

closes #30793

AI Disclosure

None

@SwayamInSync changed the title from "ENH: support the dot family for new-style user DTypes" to "ENH: Support the dot family for user-dtypes" on Jun 6, 2026
Comment thread numpy/_core/numeric.py Outdated
n = a.shape[axis_a]
m = math.prod(a.shape[ax] for ax in notin_a)
p = math.prod(b.shape[ax] for ax in notin_b)
at = ascontiguousarray(a.transpose(notin_a + [axis_a]).reshape(m, n))

Member Author

np.dot requires out= to be a C-contiguous array, so I think calling ascontiguousarray here might be OK?
Ideally matmul/BLAS should handle any layout (for performance, we can revert this later and let the dtype's BLAS path handle it).

If we revert, that will require a new release of numpy-quaddtype (PR numpy/numpy-quaddtype#108 has already landed).

Contributor

By default, any ufunc produces a C-contiguous array, so I don't think you'd need to do anything.

@SwayamInSync

Member Author

The test failures seem unrelated to this PR.

@mhvk mhvk left a comment

Contributor

Some comments in-line. Functionally, the one about complex conjugation for correlate is the most important, but the others should help simplify the implementation. In particular, please check carefully whether all those ascontiguousarray calls are in fact necessary (by having tests with non-contiguous arrays).

Comment thread numpy/_core/numeric.py Outdated
    res = np.zeros(res_shape, dtype=res_dtype)
elif a.ndim == 1 and b.ndim == 1:
    # 1-D x 1-D: inner product
    res = np.matmul(ascontiguousarray(a), ascontiguousarray(b))

Contributor

Here, you could just do return np.matmul(a, b, out=out) - by default, output is contiguous.

Comment thread numpy/_core/numeric.py
p = math.prod(b.shape[ax] for ax in notin_b)
at = ascontiguousarray(a.transpose(notin_a + [axis_a]).reshape(m, n))
bt = ascontiguousarray(b.transpose([axis_b] + notin_b).reshape(n, p))
res = np.matmul(at, bt)

Contributor

Are you sure the detailed handling of the axes is necessary? Remember that with np.matmul, like with all gufuncs, you have the option of explicitly passing in the axes that are contracted; see https://numpy.org/doc/stable/reference/ufuncs.html#optional-keyword-arguments
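
As a minimal sketch of the axes keyword mentioned above (the array shapes and the choice of stacking axis are illustrative assumptions, not taken from the PR), each operand gets one tuple naming its core axes:

```python
import numpy as np

# Hypothetical layout: matrices live in axes (0, 1), the stacking axis is last.
a = np.arange(3 * 4 * 7).reshape(3, 4, 7)
b = np.arange(4 * 5 * 7).reshape(4, 5, 7)

# One tuple of core axes per operand: first input, second input, output.
out = np.matmul(a, b, axes=[(0, 1), (0, 1), (0, 1)])

# Reference: contract j separately for every index n of the stacking axis.
ref = np.einsum("ijn,jkn->ikn", a, b)

assert out.shape == (3, 5, 7)
assert np.array_equal(out, ref)
```

The remaining (non-core) axes are still broadcast against each other, which is the part that matters for the N-D dot discussion in this thread.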

Member

Unfortunately, for ndim > 2 dot has a funny behavior that the result shape is a.shape + b.shape minus the reduction axes.
I.e. it doesn't actually do a proper broadcasting/stacked operation, it is the same as tensordot(..., axes=[-1, -2]).

However, tensordot actually does this precise dance already. So we could consider being a bit annoying and just rejecting it for this path (we should deprecate it on the other path I guess) with a note of "please use tensordot instead".

There is another good argument for this, the code has:

#if defined(HAVE_CBLAS)
    if (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2 &&

so we don't even use BLAS for this case, which is ridiculous. Unless the arrays are small, users are better off using the clearer tensordot() anyway.

(If we do that this might be simple enough to just move it to C.)
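
For reference, a small sketch of the equivalence described above, using plain integer arrays with arbitrarily chosen shapes rather than a user dtype:

```python
import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b = np.arange(6 * 4 * 5).reshape(6, 4, 5)

# N-D dot contracts the last axis of a with the second-to-last axis of b
# and concatenates the remaining axes instead of broadcasting them.
via_dot = np.dot(a, b)
via_tensordot = np.tensordot(a, b, axes=([-1], [-2]))

assert via_dot.shape == (2, 3, 6, 5)
assert np.array_equal(via_dot, via_tensordot)
```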

Member Author

The fallback's _dot_contract is the tensordot-style reshape followed by matmul, which is faster than the legacy N-D dot. So rejecting it isn't a performance win; it's purely about not carrying over the deprecation-worthy N-D dot semantics. I'm fine either way.

Member

Not sure I am following precisely, but I am leaning toward just rejecting it with a "use tensordot" message. Let's ping @mhvk, though.
(That would mean we should tend toward a deprecation, I guess.)

Comment thread numpy/_core/numeric.py Outdated
``_pyarray_correlate`` already handles empty checks, complex conjugation
and the ``len(a) < len(v)`` swap/reverse, so ``a`` is the longer operand.
"""
a = ascontiguousarray(np.asarray(a).ravel())

Contributor

Did you check all these ascontiguousarray calls are in fact necessary? I rather doubt it, since matmul produces a contiguous array.

Comment thread numpy/_core/numeric.py Outdated
pad = zeros(n2 - 1, dtype=res_dtype)
apad = concatenate([pad, a, pad])
win = apad[arange(n1 + n2 - 1)[:, None] + arange(n2)]
full = np.matmul(win, v)

Contributor

I think you need vecdot here to ensure complex numbers are properly conjugated (do add a test! Also for dot, where complex conjugation is not done [incorrectly, but the fallback needs to mimic that mistake...])
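
A small sketch of the conjugation difference, using built-in complex arrays with arbitrarily chosen values. np.vdot is used here as a long-standing stand-in for a conjugating inner product; it conjugates its first argument, and the same is assumed for np.vecdot in the NumPy versions that provide it.

```python
import numpy as np

a = np.array([1 + 1j, 2, 3j])
v = np.array([1j, 1])

# np.correlate conjugates its second argument:
# c_k = sum_n a[n + k] * conj(v[n])
ref = np.correlate(a, v, mode="valid")

# np.vdot conjugates its first argument, so it reproduces each output element.
with_conj = np.array([np.vdot(v, a[k:k + v.size]) for k in range(ref.size)])

# np.dot does not conjugate, so it disagrees for complex input.
without_conj = np.array([np.dot(v, a[k:k + v.size]) for k in range(ref.size)])

assert np.allclose(ref, with_conj)
assert not np.allclose(ref, without_conj)
```

A test along these lines, run against the fallback rather than the built-in path, would catch a missing conjugation.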

@SwayamInSync

Member Author

> By default, any ufunc produces a C-contiguous array, so I don't think you'd need to do anything.

@mhvk the element-wise ufuncs do not; they preserve the input order. So the (0-D x n-D) case, which is just a multiply, will not end up in C-contiguous order.

For BLAS ops, yes, the output always will be. I kept the calls earlier because the quaddtype PR hadn't merged yet, and in case any dtype did not conform to this.
But I agree, that would be a bug on the dtype side. I will keep the call for the case discussed above and remove it as a general call.
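
The layout difference can be sketched with built-in float64 arrays (an assumption for illustration; a user dtype's own loops may behave differently, which is the concern raised above):

```python
import numpy as np

f = np.arange(12.0).reshape(3, 4).T   # shape (4, 3), Fortran-contiguous
assert f.flags.f_contiguous and not f.flags.c_contiguous

# Element-wise ufunc: the output follows the memory layout of the input.
scaled = np.multiply(2.0, f)

# matmul: the 2-D result is freshly allocated in C order.
product = np.matmul(f, f.T)

assert scaled.flags.f_contiguous and not scaled.flags.c_contiguous
assert product.flags.c_contiguous
```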

Comment thread numpy/_core/tests/test_multiarray.py Outdated
ref = np.dot(a, b)
got = _dot_fallback(a, b)
assert_array_equal(got, ref)
assert got.shape == np.shape(ref)

@mhvk mhvk Jun 6, 2026

Contributor

You can get the shape check by setting strict=True in assert_array_equal.

Also, I'm actually not 100% sure we should care about contiguity, but if we do, we should test it to ensure we don't regress, by adding something like assert got.strides == ref.strides (same in other tests).

@SwayamInSync

SwayamInSync commented Jun 8, 2026

Member Author

In quaddtype, vecdot and non-C-layout input support are not merged yet (the PRs are open).
So those tests will fail here.

I will try to fast-forward the work there to get the support merged, and maybe we should make the quaddtype test dependency fetch and build directly from the repo? @seberg

@SwayamInSync

Member Author

Also, I tried it: axes= only selects the contracted core dims, but matmul still broadcasts the remaining dims, whereas np.dot takes their outer product (a-batch then b-batch). Concretely:
dot((2,3,4),(6,4,5)) => (2,3,6,5), but matmul(..., axes=...) raises ValueError trying to broadcast (2,3) against (6,), and dot((2,3),(4,3,5)) => (2,4,5) while matmul gives (4,2,5) (wrong, no error).
So the transpose/reshape into a single 2-D matmul is needed to express the outer-product contraction.
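
A rough sketch of that reshape-into-2-D approach, written against plain integer arrays. The helper name dot_via_2d_matmul is hypothetical and is not the PR's _dot_contract; it assumes b.ndim >= 2 and a non-empty contracted axis.

```python
import numpy as np

def dot_via_2d_matmul(a, b):
    """Sketch of N-D dot as one 2-D matmul (assumes b.ndim >= 2)."""
    n = a.shape[-1]
    at = a.reshape(-1, n)                       # (m, n): a's other axes flattened
    bt = np.moveaxis(b, -2, 0).reshape(n, -1)   # (n, p): contracted axis first
    out_shape = a.shape[:-1] + b.shape[:-2] + b.shape[-1:]
    return np.matmul(at, bt).reshape(out_shape)

a = np.arange(2 * 3).reshape(2, 3)
b = np.arange(4 * 3 * 5).reshape(4, 3, 5)

assert np.dot(a, b).shape == (2, 4, 5)
assert np.matmul(a, b).shape == (4, 2, 5)       # broadcasts a over b's stack
assert np.array_equal(dot_via_2d_matmul(a, b), np.dot(a, b))

# Stack dimensions that cannot be broadcast make matmul fail outright.
a3 = np.arange(2 * 3 * 4).reshape(2, 3, 4)
b3 = np.arange(6 * 4 * 5).reshape(6, 4, 5)
try:
    np.matmul(a3, b3)
except ValueError:
    matmul_rejected = True
else:
    matmul_rejected = False

assert matmul_rejected
assert np.array_equal(dot_via_2d_matmul(a3, b3), np.dot(a3, b3))
```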

@seberg

seberg commented Jun 8, 2026

Member

> I will try to fast-forward the work there to get the support merged, and maybe we should make the quaddtype test dependency fetch and build directly from the repo?

I already have one or two tests with quaddtype, but I am not sure it is installed yet. The problem is that even importing quaddtype can break some other tests. So you need a bit of tedious skipping based on the quaddtype version (without actually importing it).

But, I think that's fine... we can factor that out into a little private helper. (I think you should be able to find that hack if you search, quickly.)

@SwayamInSync

Member Author

> But, I think that's fine... we can factor that out into a little private helper. (I think you should be able to find that hack if you search, quickly.)

Done. I added it in this PR since it didn't disturb the context much.

@SwayamInSync

Member Author

This is ready for a 2nd round of review.
Happy to discuss design choices or any cases I missed.

except PackageNotFoundError:
    pytest.skip("numpy_quaddtype is not installed")
if _pep440.Version(installed) < _pep440.Version(MIN_VERSION):
    pytest.skip(f"numpy_quaddtype >= {MIN_VERSION} is required")

Member

Did you mean to set MIN_VERSION to 1.1? I think we are at 1.0 (and we need to reject it).
I guess we can make it customizable if needed. The current min-version is vital because even importing is wrong, but once 1.1 is out that is no longer the case.

Member Author

Yeah, that also works; it's just a weird habit of mine to use even jumps for sub-versions :)
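
The version gate discussed here can be sketched with the standard library alone. The helper names release_tuple and installed_at_least are hypothetical, and the simplified parsing (which ignores pre-release suffixes) is an assumption for illustration; the PR itself uses NumPy's private _pep440 module.

```python
import re
from importlib.metadata import PackageNotFoundError, version

def release_tuple(version_string):
    """Leading numeric release segment, e.g. '1.1.0rc1' -> (1, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group().split("."))

def installed_at_least(dist_name, min_version):
    """Check an installed distribution's version without importing it."""
    try:
        installed = release_tuple(version(dist_name))
    except PackageNotFoundError:
        return False
    wanted = release_tuple(min_version)
    # Pad with zeros so that '1.1' and '1.1.0' compare as equal.
    width = max(len(installed), len(wanted))
    installed += (0,) * (width - len(installed))
    wanted += (0,) * (width - len(wanted))
    return installed >= wanted
```

In a test module, a False result would translate into a pytest.skip call, so the package is never imported when it is missing or too old.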

Comment thread numpy/_core/numeric.py Outdated
n1, n2 = a.size, v.size
pad = zeros(n2 - 1, dtype=np.result_type(a.dtype, v.dtype))
apad = concatenate([pad, a, pad])
win = apad[arange(n1 + n2 - 1)[:, None] + arange(n2)]

Member

This can be a sliding_window_view, I think. I think matmul will be OK with core dimensions that have self-overlap (if not, we could pass axes to make it a non-core dimension). Although that is of course not BLAS compatible (for our code, I think it might even go back into einsum).

Otherwise, you would bloat memory usage like crazy here. (I suspect we can do this, but I am slightly unsure, and it may make sense to split the PR into two, since I think dot is a bit clearer.)
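
To illustrate the memory point, here is a sketch comparing the fancy-indexed window with a strided view. It uses float64 arrays with made-up values, since the thread notes that sliding_window_view does not currently work for user dtypes.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

a = np.array([1.0, 2.0, 3.0, 4.0])
v = np.array([0.0, 1.0, 0.5])
n1, n2 = a.size, v.size

pad = np.zeros(n2 - 1)
apad = np.concatenate([pad, a, pad])

# Fancy indexing materialises an (n1 + n2 - 1, n2) copy of the windows.
win_copy = apad[np.arange(n1 + n2 - 1)[:, None] + np.arange(n2)]

# The strided view exposes the same windows without copying apad.
win_view = sliding_window_view(apad, n2)

assert np.array_equal(win_copy, win_view)
assert np.shares_memory(win_view, apad)
assert not np.shares_memory(win_copy, apad)

# For real input, either window matrix reproduces the 'full' correlation.
assert np.allclose(np.matmul(win_view, v), np.correlate(a, v, mode="full"))
```

The copy grows as roughly n1 * n2 elements, while the view only needs the padded 1-D array.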

Member Author

I am also concerned about the memory bloat here, and sliding_window_view doesn't support user dtypes.

Splitting makes sense. I will open another PR for correlate, keep dot in the current one, and we can make further changes there.

Member

> sliding_window_view doesn't support user dtypes.

Ah, that is true and annoying :/. Could possibly not go via __array_interface__ there to make it work, but I am not sure that is better than trying a fix for __array_interface__ (heck even adding a dtype field).

Linked issue: ENH: Make functions in dot family and related to support user-defined data types