-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[Example][Bugfix] Fix arma example #4218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
To trigger regression tests:
|
examples/pytorch/arma/model.py
Outdated
| output += feats | ||
| return output / self.K | ||
| tot_output = torch.cat((tot_output, output), 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer defining tot_output as a list to store outputs and then calculating their mean by torch.stack(...).mean().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I will update accordingly.
* Fix arma example * Update Co-authored-by: Xin Yao <xiny@nvidia.com>
Description
To fix #4202, the crash due to in-place operations for one variable needed for backward propagation.
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change
Changes
At the end of forward function, to compute average output over all stacks, instead of doing
+=(in-place operations) as:output += featuresusing
unsqueeze() + cat()to store all stack outputs and return the mean of them.Additional notes
With profiling, no significant performance penalty observed with this fix (less than 10% slow-down)