Skip to content

Conversation

@velikodniy
Copy link
Contributor

@velikodniy velikodniy commented Apr 23, 2018

  • Add an option to control precision/recall in imbalanced datasets
  • Add tests (but new_criterion_tests)

Multiplier 1 + (pos_weight - 1) * target is the only significant difference.

Notes:

  • I didn't implement pos_weight for F.binary_cross_entropy/nn.BCELoss. (It uses implementation from torch._C.) I can add a straightforward and numerically unstable implementation to this function but I'm not sure if it is really needed.
  • I've added pos_weight as last argument to prevent errors in the code that doesn't use names for keyword arguments. But it looks quite ugly.

P.S. I proposed these changes in #5660

…tropy_with_logits (pytorch#5660)

- Add an option to control precision/recall in imbalanced datasets
- Add tests (but new_criterion_tests)
target: Tensor of the same shape as input
weight (Tensor, optional): a manual rescaling weight
if provided it's repeated to match input tensor shape
pos_weight (Tensor, optional): a weight of positive examples.

This comment was marked as off-topic.

weight (Tensor, optional): a manual rescaling weight given to the loss
of each batch element. If given, has to be a Tensor of size
"nbatch".
pos_weight (Tensor, optional): a weight of positive examples.

This comment was marked as off-topic.

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accept assuming the doc issues are fixed

`pos_weight` was moved to the end because it is the last argument in both
`nn.BCEWithLogitsLoss` and `binary_cross_entropy_with_logits`
is ``False``, returns a loss per input/target element instead and ignores
:attr:`size_average`. Default: ``True``
pos_weight (Tensor, optional): a weight of positive examples.
Must be a vector with length equal to the number of classes.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

def __init__(self, weight=None, size_average=True, reduce=True, pos_weight=None):
super(BCEWithLogitsLoss, self).__init__(size_average, reduce)
self.register_buffer('weight', weight)
self.register_buffer('pos_weight', pos_weight)

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

@ssnl
Copy link
Collaborator

ssnl commented May 29, 2018

@pytorchbot retest this please

4 similar comments
@ezyang
Copy link
Contributor

ezyang commented May 30, 2018

@pytorchbot retest this please

@ezyang
Copy link
Contributor

ezyang commented May 30, 2018

@pytorchbot retest this please

@ezyang
Copy link
Contributor

ezyang commented May 30, 2018

@pytorchbot retest this please

@ezyang
Copy link
Contributor

ezyang commented May 30, 2018

@pytorchbot retest this please

@ssnl ssnl dismissed apaszke’s stale review June 26, 2018 16:31

we don't save None's in state_dicts

@ssnl ssnl merged commit 6e28d4d into pytorch:master Jun 26, 2018
petrex pushed a commit to ROCm/pytorch that referenced this pull request Jun 26, 2018
* upstream/master: (42 commits)
  [c10d] No default device for ProcessGroupGloo (pytorch#8888)
  Fix default values for affine= in the docstrings of InstanceNormXd (pytorch#8895)
  Stop making dynamic allocations of PinnedMemoryAllocator. (pytorch#8896)
  [C++ API]  Rework optimization package (pytorch#8815)
  Mention MPICH_MAX_THREAD_SAFETY=multiple. (pytorch#8580)
  Unify isViewable, handle n-dimensional empty tensors. (pytorch#8883)
  Add pos_weight argument to nn.BCEWithLogitsLoss (pytorch#5660) (pytorch#6856)
  [build] Enable clang-specific warnings only when using clang (pytorch#8869)
  Fix cmake cudnn autodetection (pytorch#8891)
  [c10d] Fix link order for building C++ tests (pytorch#8889)
  directly add_subdirectory(nanopb) from torch CMakeLists (pytorch#8870)
  [C++ API] Bag of fixes (pytorch#8843)
  [build] Raise in cmake when seeing NVCC{9/9.1} + GCC6 combo (pytorch#8863)
  Create avg_pool1d in ATen (pytorch#8880)
  throw error when grid_sample is passed unsupported mode (pytorch#8884)
  Allow autograd to work even when the shape of values cannot be determined (pytorch#8641)
  Make at::Tensor::to() const (pytorch#8839)
  [auto] Update onnx to 458c521 - Fix typo (onnx/onnx#1143) onnx/onnx@458c521
  [Caffe2] Fix gradient_check on in-place ops (pytorch#8828)
  Fix as_strided_backward (pytorch#8721)
  ...
@velikodniy velikodniy deleted the feature/bceloss_with_pow_weight branch June 20, 2019 14:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants