Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4577,6 +4577,12 @@ def merge_dicts(*dicts):

Returns the median value of all elements in the :attr:`input` tensor.

.. note::
The median is not unique for :attr:`input` tensors with an even number
of elements. In this case the lower of the two medians is returned. To
compute the mean of both medians in :attr:`input`, use :func:`torch.quantile`
with ``q=0.5`` instead.

.. warning::
This function produces deterministic (sub)gradients unlike ``median(dim=0)``
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: it's very surprising to me that this version of median produces deterministic gradients but the other version, which actually computes indices, doesn't. We should probably fix "median_with_indices" to consistently return the FIRST valid median, just like we fixed argmin and argmax.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, as you mention, that the other function returns non-deterministic indices and thus non-deterministic subgradients.
This one evenly distribute the gradient to all the inputs with the value used. So it always is deterministic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason it returns non-deterministic indices is because it uses quickselect to partition which is not stable. One way to solve it is to use 3-way quickselect that partitions input into <, = and >. Then choose the smallest index from the = segment.


Expand Down Expand Up @@ -4604,6 +4610,12 @@ def merge_dicts(*dicts):
Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
the outputs tensor having 1 fewer dimension than :attr:`input`.

.. note::
The median is not unique for :attr:`input` tensors with an even number
of elements in the dimension :attr:`dim`. In this case the lower of the
two medians is returned. To compute the mean of both medians in
:attr:`input`, use :func:`torch.quantile` with ``q=0.5`` instead.

.. warning::
``indices`` does not necessarily contain the first occurrence of each
median value found, unless it is unique.
Expand Down