
Conversation

@anjali411 (Contributor) commented Aug 19, 2020

torch.conj is a very commonly used operator for complex tensors, but it is mathematically a no-op for real tensors. Switching to TensorFlow-style gradients for complex tensors (as discussed in #41857) would involve adding torch.conj() to the backward definitions for a lot of operators. In order to preserve autograd performance for real tensors and maintain numpy compatibility for torch.conj, this PR updates torch.conj() so that it behaves the same for complex tensors but returns the self tensor for tensors of non-complex dtypes. The documentation states that the returned tensor for a real input shouldn't be mutated. We could perhaps return an immutable tensor for this case in the future, when that functionality is available (@zdevito @ezyang).
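A quick sketch of the intended behavior (not part of the PR; the complex values and the torch.equal check are illustrative assumptions):

import torch

r = torch.randn(3)                  # real dtype
assert torch.conj(r) is r           # with this PR: a no-op, self is returned

c = torch.tensor([1 + 2j, 3 - 4j])  # complex dtype
assert torch.equal(torch.conj(c), torch.tensor([1 - 2j, 3 + 4j]))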

Stack from ghstack:

Differential Revision: D23460493

@anjali411 anjali411 requested a review from zou3519 August 19, 2020 16:10
@anjali411 anjali411 requested review from albanD and ezyang August 19, 2020 16:10
anjali411 added a commit that referenced this pull request Aug 19, 2020
ghstack-source-id: e00a233
Pull Request resolved: #43270
@dr-ci bot commented Aug 19, 2020

💊 CI failures summary and remediations

As of commit 33c3375 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



Tensor& conj_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, conj_stub); }
Collaborator:

What about this one?
What happens if you try to use autograd with it? Does it still raise a nice error?

Contributor Author:

I don't follow. It's only called through at::_conj, which has a definition in derivatives.yaml. How else can you use it with autograd?

Collaborator:

No, this one is not called from _conj.
And the user can call it directly by doing a.conj(out=b).
And this should raise an error if either a or b requires grad.
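A minimal sketch of that scenario, assuming the out= variant as it exists in this PR (the exact error type and message are assumptions):

import torch

a = torch.randn(3, requires_grad=True)
b = torch.empty(3)
try:
    torch.conj(a, out=b)        # out= variants don't support autograd
except RuntimeError as err:     # assumed error type
    print(err)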

Computes the element-wise conjugate of the given :attr:`input` tensor.
Computes the element-wise conjugate of the given :attr:`input` tensor. If :attr:`self` tensor
has a non-complex dtype, this function returns the :attr:`self` tensor.
Collaborator:

Nit: should be input here, not self, right?

Comment on lines +194 to +199
Tensor conj(const Tensor& self) {
  if (!self.is_complex()) {
    return self;
  }
  return at::_conj(self);
}
Contributor:

This seems BC-breaking: previously if a user passed in a real tensor to conj, they'd get back a copy of said tensor. What was the rationale behind this change?

FWIW, in numpy, conj returns a copy of the original tensor if it wasn't complex:

>>> x = np.array([0.1, 0.1])
>>> y = np.conj(x)
>>> y.flags
...
OWNDATA : True

Contributor Author:

As discussed in #41857, switching to TensorFlow-style gradients would involve adding a bunch of conj calls to the gradient definitions. In order to preserve backward performance for real tensors, making torch.conj a no-op for real tensors seems like a good option.
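To make the cost concrete, here is a hypothetical TF-style backward formula in Python pseudocode (mul_backward is illustrative, not the actual derivatives.yaml entry):

import torch

def mul_backward(grad, x, y):
    # TF-style gradients for z = x * y conjugate the other operand.
    # If conj() copies for real dtypes, every real-valued backward pays
    # for an allocation that is mathematically a no-op.
    return grad * y.conj(), grad * x.conj()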

zou3519 (Contributor) commented Aug 19, 2020:

I see, so the decision in that thread was to go for TF-style gradients. However, TF-style gradients require conj calls everywhere in the gradient formulas. Because of that, we're trying to optimize calls to conj to avoid regressions.

The downside of calling the operator "fast_conj" is that writers of gradient formulas might use "conj" instead of "fast_conj". If we go down that route, it would be nice to document somewhere (derivatives.yaml?) that users should use fast_conj and not conj. On the other hand, I think it is probably OK for us to break BC, diverge from numpy, and have our "conj" not make copies when possible, but I am not sure which of these two solutions would be better.

Exploring the design space some more, I noticed that there were some comments about having a conj flag on tensors that reinterprets the tensor as its conjugate (#41857 (comment)). Was there a decision on that? That seems like it would eliminate the need to have a fast_conj.

Collaborator:

I understand why we want to have a function with this behavior, but have a few questions/suggestions:

  • I would prefer our conj have the same behavior as NumPy's
  • torch.conj is, conceptually, a universal function, and all universal functions create new outputs
    • for example, torch.abs(uint8) produces a new tensor, even though it could mathematically fulfill its contract by returning self
    • analogously, would we expect torch.ceil(int64) to return self or a copy of self?
  • do we expect this function to be used by anyone but developers? Is making a non-user-facing function self_or_conj an option to address the scenarios we'd like to target? If we're uncertain about those scenarios, could we start with this?

I appreciate @zou3519's concern about gradient writers mistakenly using conj instead of self_or_conj, and I agree that's a real possibility, but if we're consistent about using the latter and do our job in code review I think it'll be OK. I'm also less concerned about introducing a fixable performance issue vs. creating an inconsistent and semantically confusing UX for end users.

Contributor:

Although I am leaning towards "faithfully replicate numpy behavior" (@anjali411, I think I mentioned that I was having trouble deciding between these two choices in our call), I want to point out some reasons why fast conj by default might make sense.

do we expect this function to be used by anyone but developers? Is making a non-user-facing function self_or_conj an option to address the scenarios we'd like to target? If we're uncertain about those scenarios, could we start with this?

fast_conj absolutely would be used by non-developers. Any time you write code that you want to generalize over both real and complex dtypes, you will write conjugations to handle the complex case, and if they are not no-ops, you will unnecessarily slow down the execution of your code in the real case. Some more support from boedekker here: #41857 (comment)

Exploring the design space some more, I noticed that there were some comments about having a conj flag on tensors that reinterprets the tensor as its conjugate (#41857 (comment)). Was there a decision on that?

For the BC consideration here, it doesn't help. If we have a fancy version of conjugate that returns a view rather than a new tensor, that would still be incompatible with the NumPy behavior. So if you want to keep numpy behavior, you're still going to have to come up with a new name like view_conj or whatever.

Collaborator:

I would lean towards @ezyang's opinion here personally. In particular, "fast_conj absolutely would be used by non-developers", and not having it would mean that people will start adding a lot of if t.is_complex() checks in their code.

@mruberry are there functions for which we explicitly plan to have a different API compared to numpy? And maybe provide a numpy-compliant version in a different namespace? Could this function be one of them?
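For illustration, the kind of dtype branch users would end up writing if conj always copied for real inputs (conj_if_complex is a hypothetical helper, not a torch API):

import torch

def conj_if_complex(t: torch.Tensor) -> torch.Tensor:
    # Conjugate only when the dtype actually needs it.
    return t.conj() if t.is_complex() else t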

Contributor:

I work a lot with numpy, and I must say I never checked whether numpy.conj returns a copy or holds onto the original data.
Sometimes I got the impression that it is a no-op even for complex numbers, when I measured the execution time (the execution time is close to nothing compared to any other function).

Nevertheless, assuming that conj returns a copy of the source tensor, even for real tensors, is a bad idea.

If you introduce fast_conj, I would never use conj.

In the context of torch, I would prefer that conj never create a copy.
While the computation of the copy may be suboptimal, the increased memory consumption is critical.
In numpy you can compute the conjugate on demand and it will be deleted a few moments later.
In torch, it may be kept for the backward graph. If you call conj multiple times, you store several copies.

So maybe it would also be worth considering giving the conjugate operation an argument to enable something like torch.utils.checkpoint.checkpoint, or having the conj flag that was mentioned earlier.

Contributor:

cc @zdevito this may be one of the first cases where we have existentially important laziness on tensors

Collaborator:

@mruberry are there functions for which we explicitly plan to have a different API compared to numpy? And maybe provide a numpy-compliant version in a different namespace?

Yes, but we should strive to limit these discrepancies. In particular we want a compelling reason for the difference (the NumPy behavior would be inconsistent with PyTorch, or it would require an extremely painful deprecation, ...).

Could this function be one of them?

Sure. I'm actually less worried about the NumPy compatibility of this function than the PyTorch inconsistency of it. It seems like we're trying to combine the ideas of torch.conj() the unary ufunc and a .conj attribute with a corresponding view function, like .real, .imag, and .T.
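For reference, a quick demonstration of the view-function class being contrasted here, using .T (standard transpose-view semantics):

import torch

x = torch.arange(6).reshape(2, 3)
xt = x.T            # .T returns a view sharing x's storage
xt[0, 0] = 100      # mutating the view writes through to x
print(x[0, 0])      # tensor(100)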

@anjali411 added the module: complex (Related to complex number support in PyTorch) and module: bc-breaking (Related to a BC-breaking change) labels, then removed module: bc-breaking Aug 19, 2020
anjali411 added a commit that referenced this pull request Aug 19, 2020
ghstack-source-id: 960922c
Pull Request resolved: #43270
@anjali411 anjali411 requested a review from mruberry August 19, 2020 17:36
@anjali411 anjali411 changed the title Make torch.conj a composite function Add a composite function torch.fast_conj() Aug 19, 2020
@gchanan (Contributor) commented Aug 20, 2020

Here's some fun that I think shows we shouldn't follow NumPy exactly:

>>> np.__version__
'1.17.2'

>>> x=np.random.randn(2,3)

>>> x.ctypes.data == x.conj().ctypes.data
True

>>> x.ctypes.data == np.conj(x).ctypes.data
False

Perhaps x.conj() returning self is just a bug, though. We've had those bugs before.

The same doesn't apply to complex values:

>>> y = np.random.randn(2,3) + 1j*np.random.randn(2,3)

>>> y.ctypes.data == y.conj().ctypes.data
False

>>> y.ctypes.data == np.conj(y).ctypes.data
False

def test_conj_self(self, device, dtype):
    t = torch.ones(5, 5, device=device)
    s = t.fast_conj()
    self.assertTrue(s is t)
Contributor:

Do we actually want this, or do we want s to be a view of t?

Contributor Author:

We could, but I think we should just return the same tensor as the input, with the same storage, to make it a complete no-op.

@gchanan (Contributor) commented Aug 20, 2020

I guess my questions are:

  1. what is the "right" numpy conj behavior? I don't expect numpy meant to distinguish the functional and method variant. @rgommers?
  2. are these the semantics of some function we actually want long-term? I would expect fast_conj or view_conj or whatever to do something fast/view for complex as well. I'm unsure if you are planning this, considering this, or are set on the current semantics.
  3. As an API issue, IMO you should start the name with the most "discoverable" information -- think code completion. People will type torch.conj... when looking for conjugate functionality, not start with fast.... So whatever we choose for the name, it should start with conj

@mruberry (Collaborator):

  1. what is the "right" numpy conj behavior? I don't expect numpy meant to distinguish the functional and method variant. @rgommers?

I filed numpy/numpy#17124 about the discrepancy.

@rgommers (Collaborator):

what is the "right" numpy conj behavior? I don't expect numpy meant to distinguish the functional and method variant. @rgommers?

I commented on the NumPy issue Mike opened. For the decision to be made here, I think it's more important that PyTorch has a consistent design here rather than that it follow NumPy exactly. The copy-vs-view behavior in NumPy can be inconsistent for no good reason, and is very hard to change at this point.

I'd prefer if functions and methods always had the same behaviour.

@ezyang (Contributor) commented Aug 21, 2020

Based on this new information (thanks @gchanan for looking; this is really useful), I've amended my position: I think we should do a serious investigation of doing conjugate views by default. I kind of suspect we still won't do it for conj because it is too surprising (it is just too different from how all of our other unary operations work), but I think we should try to understand it better.

Conjugate views work in the following way (a rough sketch follows the list):

  • Every tensor is associated with a boolean flag that says whether or not it is conjugated
  • x.conj_view() shares memory with x, but simply flips the boolean flag
  • Conjugation views necessitate the creation of negative views as well, so that x.imag is always a view (if x is conjugated, then x.imag must be negative)
  • A backend fallback, by default, materializes a conjugate/negative view into an actual tensor when it is used as a non-mutating input to an operation that doesn't support fused conjugation/negative
  • On a per operator basis, we may override the backend fallback to call specialized fused operations. For conjugation, typical fused operations include dot product and matrix multiply; for negative, an obvious fused operation is transforming add into sub, and vice versa. This fallback should be done AFTER autograd, so that autograd saves conjugate views for backward.
  • If you mutate a conjugate/negative view, you instead conjugate/negate the values you would have written in, and write them into the source tensor.
  • The machinery for conjugate/negative views would also generalize to zero tensors, which @albanD has argued are useful for autograd
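
A rough Python sketch of the conjugate-view mechanics above (ConjView and its methods are hypothetical, not PyTorch API):

import torch

class ConjView:
    """Hypothetical conjugate view: shares memory with `base` and carries
    a boolean conjugation flag instead of eagerly conjugating."""
    def __init__(self, base: torch.Tensor, is_conj: bool = False):
        self.base = base          # shared storage; never copied here
        self.is_conj = is_conj    # the per-tensor flag described above

    def conj_view(self) -> "ConjView":
        # Flipping the flag is the whole operation; no data is touched.
        return ConjView(self.base, not self.is_conj)

    def materialize(self) -> torch.Tensor:
        # The backend fallback: produce an actual tensor for operations
        # with no fused-conjugation support.
        return self.base.conj() if self.is_conj else self.base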

An alternative to conjugate views is peephole lazy tensors, which @zdevito has advocated for. Peephole lazy tensors work like this:

  • When you write x.conj(), we don't immediately do the conjugation operation. Instead, we simply produce a tensor which says, "This is x, but conj." The distinction between this and the conjugate flag, is that this tensor is semantically a new copy of the tensor. So for example, if we then attempted to mutate the result, we would simply directly materialize the conjugated result and then mutate that.
  • If the result is used by another operation which doesn't support fusion, we materialize the conjugated tensor; all subsequent uses of the result make use of this materialized tensor

A variation of peephole lazy tensors is rematerialization (#42056): instead of saving the materialized tensor, we never save it, and instead always redo the conjugation at each use site (or perhaps there is some policy that specifies whether or not we save). A sketch of the lazy variant follows.
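Under the same caveats, a sketch of the peephole-lazy variant (LazyConj is hypothetical; the cached branch is where rematerialization would differ):

import torch

class LazyConj:
    """Hypothetical peephole-lazy result: semantically a new tensor, but
    only computed when a consumer actually needs the values."""
    def __init__(self, base: torch.Tensor):
        self.base = base
        self._cached = None       # filled in on first non-fused use

    def force(self) -> torch.Tensor:
        # Rematerialization (#42056) would skip this cache and redo the
        # conjugation at every use site instead.
        if self._cached is None:
            self._cached = self.base.conj()
        return self._cached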

@anjali411 anjali411 changed the title Add a composite function torch.fast_conj() Make torch.conj() a composite function and return self for real tensors Aug 28, 2020
anjali411 added a commit that referenced this pull request Aug 28, 2020
ghstack-source-id: bb48076
Pull Request resolved: #43270
Computes the element-wise conjugate of the given :attr:`input` tensor.
Computes the element-wise conjugate of the given :attr:`input` tensor. If the input has a non-complex dtype,
the output tensor shouldn't be mutated.
Contributor Author:

As per @ezyang's suggestion.

Contributor:

Actually, since we're exploring conjugate view, let me suggest we make an even stronger claim: we should tell users to flat out NOT mutate the output of this function, ever. If you avoid doing so, you will write code that is forwards compatible with any possible semantics of conj (it always returns a new copy, it returns a new copy sometimes, it always returns a view).

I despair about actually making sure people don't do this with just a doc update. We really need some way to mark views as non-mutable. Maybe this wouldn't even be that hard to do...
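For comparison, NumPy already has the kind of non-mutability marking being wished for here; torch has no equivalent, so this is only an analogy:

import numpy as np

x = np.array([1.0, 2.0])
y = x.view()                  # a view sharing x's memory
y.flags.writeable = False     # NumPy's existing read-only marking
# y[0] = 3.0                  # would raise: assignment destination is read-only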

@anjali411 anjali411 requested review from ezyang and removed request for ezyang August 28, 2020 16:41
anjali411 added a commit that referenced this pull request Aug 28, 2020
ghstack-source-id: 9cbd190
Pull Request resolved: #43270
anjali411 added a commit that referenced this pull request Aug 28, 2020
ghstack-source-id: 281cc5f
Pull Request resolved: #43270

- func: _conj(Tensor self) -> Tensor
  use_c10_dispatcher: full
  variants: function
Contributor:

I hope this doesn't get traced; that would be _convolution all over again.

Collaborator:

Maybe we want to rename it to conj_complex() as it is just a version of conj() that only accepts complex inputs?

Contributor Author:

@ezyang Sorry, I am not aware of the issue caused by _convolution. What was it?
@albanD Yeah, that makes sense. Will do!

Contributor:

The alias information for conj above is now not correct; you need to give it an alias annotation similar to what contiguous has.

Contributor:

@anjali411 Basically, _convolution, which has a bunch of extra arguments that people typically don't care about, is what is being traced to serialized models, meaning that now it is very difficult to add new parameters to _convolution because they are FC-breaking for all models.

conj is considerably less trafficked so the risks of getting it wrong here are lower, but it still seems to me that we should prefer tracing conj and not _conj. You should test which one we're getting in the trace.
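A sketch of the suggested check (whether the tracer accepted complex inputs at this point, and which symbol appears in the graph, are exactly the open questions):

import torch

def f(x):
    return torch.conj(x)

# Trace and inspect which symbol was recorded: conj or _conj.
traced = torch.jit.trace(f, torch.tensor([1 + 2j]))
print(traced.graph)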

Collaborator:

Luckily I think the intent here is to keep _conj's signature equivalent to conj, since it's just an autograd hack to only apply autograd to complex inputs to conj. As you point out, @ezyang, _convolution suffers because it's secretly pulling in arguments the user doesn't expect.

An interesting question, though, is what gets traced when ops are composite: the composite op or the original symbol. I don't know offhand.

Collaborator:

Given that the code that handles tracing and autograd was the same until very recently, I expect that only the inner function is traced.

@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes()))
def test_conj_self(self, device, dtype):
    t = torch.ones(5, 5, device=device)
    s = t.fast_conj()
Contributor:

This needs to be updated.

@ezyang ezyang added the module: bc-breaking Related to a BC-breaking change label Aug 31, 2020
@ezyang (Contributor) commented Aug 31, 2020

In the interest of moving things along, I'd prefer to accept this patch. However, there are a bunch of ways this could go wrong. I'd prefer it if someone else could also ACK.

I'll also take a look at how hard it would be to mark tensors as immutable.

  use_c10_dispatcher: full
  variants: function, method

- func: conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
Collaborator:

Future work: an in-place conj_?

conj(input, *, out=None) -> Tensor
Computes the element-wise conjugate of the given :attr:`input` tensor.
Computes the element-wise conjugate of the given :attr:`input` tensor. If the input has a non-complex dtype,
Collaborator:

Would: "Computes the element-wise complex conjugate of :attr`input`." be a clearer fist sentence? "Conjugate" seems overloaded.

I wonder if the second sentence shouldn't be a note and/or a warning? Maybe something like:

NOTE: if :attr`input` is a non-complex dtype this function just returns :attr:input.

WARNING: In the future torch.conj may return a non-writeable view of :attr:input. It's recommended that programs not modify the tensor returned by torch.conj to be compatible with this change.

Contributor Author:

I think the behavior for non-complex tensors should go in the description, but I added a warning as per your suggestion.

@mruberry (Collaborator) left a comment:

This approach makes sense. Although this function may change in the future to return an immutable view, the complex tensor UX is still in "beta", and @ezyang's suggestion to clarify the documentation seems like a reasonable mitigation. A few cleanup requirements:

  • test needs update (per @ezyang's note)
  • doc needs clarification

As for NumPy compatibility: sometimes we do need to be different from NumPy for reasons of performance, our focus on neural networks, or our support of autograd, and I think we have a consensus that this approach is the best way (for now) to implement performant complex autograd.

From the perspective of how PyTorch's functions work, it's important that users know what to expect when they call a function. That is, ideally they should be able to infer what the function will do without reading its documentation. This means that functions should be named reasonably (e.g. torch.linalg.outer computes an outer product, and torch.ger is indecipherable for most users), and the system should be so consistent that users don't need to memorize a lot of exceptions or strange rules (for example, functions and methods in PyTorch perform the same computations).

torch.conj is a weird duck because there are good arguments for it being a unary ufunc like it is in NumPy, a view function, or a function like contiguous, cpu, cuda, and to. The first class of functions has always created a new tensor. The second usually returns a view (ahem, reshape), and the third class, which is hard to name, has returned self or a new tensor. From our discussion, however, it seems like torch.conj will eventually be in one of the latter two groups. Maybe the third class should be extended (and maybe torch.abs(uint8) should return self) or maybe torch.conj will eventually return an (immutable) view. Either way, this current approach seems like our best approximation of that future behavior.
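For concreteness, the contiguous-style "self or a new tensor" behavior referenced above (standard behavior, shown as a sketch):

import torch

x = torch.randn(2, 3)
assert x.contiguous() is x          # already contiguous: returns self
xt = x.T
assert xt.contiguous() is not xt    # non-contiguous: returns a new tensor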

@ezyang (Contributor) commented Sep 1, 2020

Thanks Mike!

anjali411 added a commit that referenced this pull request Sep 1, 2020
ghstack-source-id: d127021
Pull Request resolved: #43270
@codecov bot commented Sep 2, 2020

Codecov Report

❗ No coverage uploaded for pull request base (gh/anjali411/53/base@bacee6a). Click here to learn what that means.
The diff coverage is n/a.

Impacted file tree graph

@@                   Coverage Diff                   @@
##             gh/anjali411/53/base   #43270   +/-   ##
=======================================================
  Coverage                        ?   69.32%           
=======================================================
  Files                           ?      379           
  Lines                           ?    47106           
  Branches                        ?        0           
=======================================================
  Hits                            ?    32654           
  Misses                          ?    14452           
  Partials                        ?        0           


@facebook-github-bot (Contributor):

@anjali411 merged this pull request in 129f406.


Labels: Merged, module: bc-breaking (Related to a BC-breaking change), module: complex (Related to complex number support in PyTorch)

10 participants