Skip to content

Conversation

@t-vi
Copy link
Collaborator

@t-vi t-vi commented Jul 31, 2018

Implemented via a wrapper, thank you Richard for the suggestion!

Fixes: #9929

implemented via a wrapper

Fixes: pytorch#9929
variants: function

- func: einsum(std::string equation, TensorList tensors) -> Tensor
- func: _einsum(std::string equation, TensorList tensors) -> Tensor

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Thanks to fmassa for the suggestion!
if the method to fix it is not as pretty, it is my own fault, though.
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SsnL has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

"""
if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
# the old interface of passing the operands as one list argument
operands = operands[0]

This comment was marked as off-topic.

This comment was marked as off-topic.

return P, L, U


def einsum(equation, *operands):

This comment was marked as off-topic.

@t-vi
Copy link
Collaborator Author

t-vi commented Jul 31, 2018 via email

@zou3519
Copy link
Contributor

zou3519 commented Jul 31, 2018

Oh I see. Thanks for the clarification @t-vi !

@t-vi
Copy link
Collaborator Author

t-vi commented Aug 9, 2018

I think the CI failure isn't this patch. Should we deprecate having a list as second argument?

@zou3519
Copy link
Contributor

zou3519 commented Aug 14, 2018

Let's keep it like this for now and think about if the deprecation should happen later.

Could you rebase this pr @t-vi?

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

soumith has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[pytorch] torch.einsum: minor API discrepancy with NumPy and TensorFlow

6 participants