Skip to content

Conversation

@Kiyosora
Copy link
Contributor

@dr-ci
Copy link

dr-ci bot commented Jul 17, 2020

💊 CI failures summary and remediations

As of commit 9db39c6 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 3 fixed upstream failures:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch (expand for instructions)

Since your merge base is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 40 times.

@Kiyosora Kiyosora force-pushed the signbit_implement branch from eb278ce to bd179e2 Compare July 17, 2020 12:31
@Kiyosora Kiyosora marked this pull request as ready for review July 20, 2020 01:28
@Kiyosora
Copy link
Contributor Author

Hi, @mruberry ! I implemented the torch.signbit() method and it's now ready to be reviewed. It seems that there is an XLA-related error occured in the CI test. Should I add some necessary changes for XLA, or does it better to skip the XLA test by adding annotation?

@Kiyosora Kiyosora force-pushed the signbit_implement branch from bd179e2 to b65627a Compare July 20, 2020 02:25
@Kiyosora Kiyosora changed the title [WIP] Implementing NumPy-like function torch.signbit() Implementing NumPy-like function torch.signbit() Jul 20, 2020
@mruberry mruberry self-requested a review July 20, 2020 08:18
@mruberry mruberry added the module: numpy Related to numpy support, and also numpy compatibility of our operators label Jul 20, 2020
@mruberry
Copy link
Collaborator

Hi, @mruberry ! I implemented the torch.signbit() method and it's now ready to be reviewed. It seems that there is an XLA-related error occured in the CI test. Should I add some necessary changes for XLA, or does it better to skip the XLA test by adding annotation?

Great! Thanks @Kiyosora! I'll take a look soon.

@gchanan gchanan added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 20, 2020
@Kiyosora Kiyosora requested a review from ailzhang July 21, 2020 04:02
Copy link
Contributor

@ailzhang ailzhang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamping for xla fix :D Thanks!

Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, @ailzhang, for helping with the XLA dispatch!

This is a well-written PR, @Kiyosora, thanks for submitting it!

I made a few comments about the test and docs, but also have a question about the design of this PR: does it need a new TensorIterator kernel, or can it be implemented using existing PyTorch functions? I'm curious to hear your thoughts.

@Kiyosora Kiyosora force-pushed the signbit_implement branch from 7b76fb9 to 359d5d1 Compare July 28, 2020 01:34
@Kiyosora
Copy link
Contributor Author

Hi, @mruberry ! Thank you so much for your advice & Sorry for my late reply. 😢
Based on your advice, I improved my code. Please make review does it reach the goal?

@Kiyosora Kiyosora requested a review from mruberry July 28, 2020 05:27
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true that clamp is not implemented for complex tensors, but this should probably say:

"signbit is not implemented for complex tensors."

No need for a "yet," either, since we have no plans to implement it in the future -- it's not even clear how to define a function like signbit on complex values. NumPy, for example, also throws an error when signbit is given a complex input.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also: better add a check that result is bool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed! Thanks for the correction.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to promote inputs or cast to an output. There's only one input and for this first iteration signbit doesn't need to support non-bool outputs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it! thanks for your explanation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Tensor self), no star needed here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\operatorname{signbit} ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formulas has been deleted, Sorry for my carelessness.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This regexes will need to be updated to refer to signbit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed!

@mruberry mruberry self-requested a review July 28, 2020 11:53
Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks very good!

I made a few comments with updates/corrections.

@Kiyosora Kiyosora force-pushed the signbit_implement branch 3 times, most recently from 5e95f7f to 775f095 Compare July 29, 2020 08:54
@Kiyosora Kiyosora requested a review from mruberry July 29, 2020 11:40
@Kiyosora
Copy link
Contributor Author

Hi, @mruberry. I'm sorry that my carelessness caused some avoidable mistakes.
I have fixed them, Would you please help me to check them again, thanks! 🙏

@Kiyosora Kiyosora force-pushed the signbit_implement branch from 775f095 to 9db39c6 Compare July 30, 2020 01:19
Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work, @Kiyosora!

Let me know if you're interested in implementing another function or would like to look at a different kind of issue.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@Kiyosora
Copy link
Contributor Author

Nice work, @Kiyosora!

Let me know if you're interested in implementing another function or would like to look at a different kind of issue.

Thanks for confirm, @mruberry !

I'd' like going to work on the function heaviside() , another PR would be come up soon.

@mruberry
Copy link
Collaborator

Nice work, @Kiyosora!
Let me know if you're interested in implementing another function or would like to look at a different kind of issue.

Thanks for confirm, @mruberry !

I'd' like going to work on the function heaviside() , another PR would be come up soon.

Great! Looking forward to it.

@facebook-github-bot
Copy link
Contributor

@mruberry merged this pull request in 26d5850.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: numpy Related to numpy support, and also numpy compatibility of our operators open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants