Skip to content

Conversation

@li-roy
Copy link
Contributor

@li-roy li-roy commented Feb 22, 2018

  • Implemented MarginRankingLoss as native function.
  • As per losses per-batch-element #264. When reduce is False, MarginRankingLoss outputs a loss per sample in minibatch. When reduce is True (default), the current behavior is kept.

Test Plan
test/run_test.sh
Added unit tests for MarginRankingLoss.

@li-roy li-roy changed the title add reduce=True arg to MarginRankingLoss Implement MarginRankingLoss as native function and add reduce=True arg to it Feb 22, 2018
@ssnl
Copy link
Collaborator

ssnl commented Feb 22, 2018

why doesn't this also have a backward function?

@li-roy
Copy link
Contributor Author

li-roy commented Feb 22, 2018

I used #5080 as a reference for this PR, and a backwards wasn't implemented there as well. Should I also add a backwards?

@ssnl
Copy link
Collaborator

ssnl commented Feb 22, 2018

The fact that it is moved from python to cpp might be enough to justify ignoring backward. You should either implement it or do a speed benchmark as #5080 .

@ezyang
Copy link
Contributor

ezyang commented Feb 23, 2018

17:48:04 ======================================================================
17:48:04 ERROR: test_MarginRankingLoss_reduce (__main__.TestNN)
17:48:04 ----------------------------------------------------------------------
17:48:04 Traceback (most recent call last):
17:48:04   File "test_nn.py", line 4576, in <lambda>
17:48:04     setattr(TestNN, test_name, lambda self, test=test: test(self))
17:48:04   File "/var/lib/jenkins/workspace/test/common_nn.py", line 1041, in __call__
17:48:04     deepcopy(target), module)
17:48:04   File "test_nn.py", line 4735, in <lambda>
17:48:04     marginrankingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
17:48:04 NameError: global name 'marginrankingloss_reference' is not defined
17:48:04 
17:48:04 ======================================================================
17:48:04 ERROR: test_MarginRankingLoss_reduce_no_size_average (__main__.TestNN)
17:48:04 ----------------------------------------------------------------------
17:48:04 Traceback (most recent call last):
17:48:04   File "test_nn.py", line 4576, in <lambda>
17:48:04     setattr(TestNN, test_name, lambda self, test=test: test(self))
17:48:04   File "/var/lib/jenkins/workspace/test/common_nn.py", line 1041, in __call__
17:48:04     deepcopy(target), module)
17:48:04   File "test_nn.py", line 4735, in <lambda>
17:48:04     marginrankingloss_reference(i[0], i[1], t, size_average=get_size_average(m)),
17:48:04 NameError: global name 'marginrankingloss_reference' is not defined
17:48:04 

@li-roy
Copy link
Contributor Author

li-roy commented Feb 23, 2018

@ezyang Yeah I have an update, was going to run benchmarks first.

@li-roy
Copy link
Contributor Author

li-roy commented Feb 27, 2018

benchmarks run with inputs of size (1000), 5000 times:

forward (old) [0.8390149101614952, 0.8477208158001304, 0.859797858633101]
backward (old) [2.4765465622767806, 2.535214721225202, 2.4746940098702908]
double backward (old) [0.7999639995396137, 0.7724319389089942, 0.7383095575496554]

forward (new) [0.7037791842594743, 0.7032447448000312, 0.6909415749832988]
backward (new) [0.8209385611116886, 0.8369621662423015, 0.8271905779838562]
double backward (new) [0.8673017444089055, 0.8709612749516964, 0.8717606142163277]

@li-roy li-roy force-pushed the marginranking branch 3 times, most recently from f43af50 to 81e173c Compare February 28, 2018 01:38
@li-roy
Copy link
Contributor Author

li-roy commented Feb 28, 2018

@pytorchbot retest this please

@ssnl
Copy link
Collaborator

ssnl commented Feb 28, 2018

Interesting that double backward slows down, but this is fine given the single backward improvements.

@li-roy
Copy link
Contributor Author

li-roy commented Mar 1, 2018

@pytorchbot retest this please

@li-roy li-roy force-pushed the marginranking branch 2 times, most recently from 9a46f6f to ed36fcf Compare March 9, 2018 00:10
@li-roy
Copy link
Contributor Author

li-roy commented Mar 9, 2018

@pytorchbot retest this please

1 similar comment
@li-roy
Copy link
Contributor Author

li-roy commented Mar 9, 2018

@pytorchbot retest this please

Copy link
Collaborator

@ssnl ssnl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good in general. But please fix the conflict :)

This comment was marked as off-topic.

Copy link
Collaborator

@ssnl ssnl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants