Skip to content

Commit 234fb2d

Browse files
committed
Issue 21424: Apply the nlargest() optimizations to nsmallest() as well.
1 parent 3a17e21 commit 234fb2d

File tree

4 files changed

+137
-117
lines changed

4 files changed

+137
-117
lines changed

Lib/heapq.py

Lines changed: 115 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@
127127
__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge',
128128
'nlargest', 'nsmallest', 'heappushpop']
129129

130-
from itertools import islice, count, tee, chain
130+
from itertools import islice, count
131131

132132
def heappush(heap, item):
133133
"""Push item onto heap, maintaining the heap invariant."""
@@ -179,37 +179,19 @@ def heapify(x):
179179
for i in reversed(range(n//2)):
180180
_siftup(x, i)
181181

182-
def _heappushpop_max(heap, item):
183-
"""Maxheap version of a heappush followed by a heappop."""
184-
if heap and item < heap[0]:
185-
item, heap[0] = heap[0], item
186-
_siftup_max(heap, 0)
187-
return item
182+
def _heapreplace_max(heap, item):
183+
"""Maxheap version of a heappop followed by a heappush."""
184+
returnitem = heap[0] # raises appropriate IndexError if heap is empty
185+
heap[0] = item
186+
_siftup_max(heap, 0)
187+
return returnitem
188188

189189
def _heapify_max(x):
190190
"""Transform list into a maxheap, in-place, in O(len(x)) time."""
191191
n = len(x)
192192
for i in reversed(range(n//2)):
193193
_siftup_max(x, i)
194194

195-
def nsmallest(n, iterable):
196-
"""Find the n smallest elements in a dataset.
197-
198-
Equivalent to: sorted(iterable)[:n]
199-
"""
200-
if n <= 0:
201-
return []
202-
it = iter(iterable)
203-
result = list(islice(it, n))
204-
if not result:
205-
return result
206-
_heapify_max(result)
207-
_heappushpop = _heappushpop_max
208-
for elem in it:
209-
_heappushpop(result, elem)
210-
result.sort()
211-
return result
212-
213195
# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos
214196
# is the index of a leaf with a possibly out-of-order value. Restore the
215197
# heap invariant.
@@ -327,6 +309,10 @@ def _siftup_max(heap, pos):
327309
from _heapq import *
328310
except ImportError:
329311
pass
312+
try:
313+
from _heapq import _heapreplace_max
314+
except ImportError:
315+
pass
330316

331317
def merge(*iterables):
332318
'''Merge multiple sorted inputs into a single sorted output.
@@ -367,22 +353,86 @@ def merge(*iterables):
367353
yield v
368354
yield from next.__self__
369355

370-
# Extend the implementations of nsmallest and nlargest to use a key= argument
371-
_nsmallest = nsmallest
356+
357+
# Algorithm notes for nlargest() and nsmallest()
358+
# ==============================================
359+
#
360+
# Makes just a single pass over the data while keeping the k most extreme values
361+
# in a heap. Memory consumption is limited to keeping k values in a list.
362+
#
363+
# Measured performance for random inputs:
364+
#
365+
# number of comparisons
366+
# n inputs k-extreme values (average of 5 trials) % more than min()
367+
# ------------- ---------------- - ------------------- -----------------
368+
# 1,000 100 3,317 133.2%
369+
# 10,000 100 14,046 40.5%
370+
# 100,000 100 105,749 5.7%
371+
# 1,000,000 100 1,007,751 0.8%
372+
# 10,000,000 100 10,009,401 0.1%
373+
#
374+
# Theoretical number of comparisons for k smallest of n random inputs:
375+
#
376+
# Step Comparisons Action
377+
# ---- -------------------------- ---------------------------
378+
# 1 1.66 * k heapify the first k-inputs
379+
# 2 n - k compare remaining elements to top of heap
380+
# 3 k * (1 + lg2(k)) * ln(n/k) replace the topmost value on the heap
381+
# 4 k * lg2(k) - (k/2) final sort of the k most extreme values
382+
# Combining and simplifying for a rough estimate gives:
383+
# comparisons = n + k * (1 + log(n/k)) * (1 + log(k, 2))
384+
#
385+
# Computing the number of comparisons for step 3:
386+
# -----------------------------------------------
387+
# * For the i-th new value from the iterable, the probability of being in the
388+
# k most extreme values is k/i. For example, the probability of the 101st
389+
# value seen being in the 100 most extreme values is 100/101.
390+
# * If the value is a new extreme value, the cost of inserting it into the
391+
# heap is 1 + log(k, 2).
392+
# * The probabilty times the cost gives:
393+
# (k/i) * (1 + log(k, 2))
394+
# * Summing across the remaining n-k elements gives:
395+
# sum((k/i) * (1 + log(k, 2)) for xrange(k+1, n+1))
396+
# * This reduces to:
397+
# (H(n) - H(k)) * k * (1 + log(k, 2))
398+
# * Where H(n) is the n-th harmonic number estimated by:
399+
# gamma = 0.5772156649
400+
# H(n) = log(n, e) + gamma + 1.0 / (2.0 * n)
401+
# http://en.wikipedia.org/wiki/Harmonic_series_(mathematics)#Rate_of_divergence
402+
# * Substituting the H(n) formula:
403+
# comparisons = k * (1 + log(k, 2)) * (log(n/k, e) + (1/n - 1/k) / 2)
404+
#
405+
# Worst-case for step 3:
406+
# ----------------------
407+
# In the worst case, the input data is reversed sorted so that every new element
408+
# must be inserted in the heap:
409+
#
410+
# comparisons = 1.66 * k + log(k, 2) * (n - k)
411+
#
412+
# Alternative Algorithms
413+
# ----------------------
414+
# Other algorithms were not used because they:
415+
# 1) Took much more auxiliary memory,
416+
# 2) Made multiple passes over the data.
417+
# 3) Made more comparisons in common cases (small k, large n, semi-random input).
418+
# See the more detailed comparison of approach at:
419+
# http://code.activestate.com/recipes/577573-compare-algorithms-for-heapqsmallest
420+
372421
def nsmallest(n, iterable, key=None):
373422
"""Find the n smallest elements in a dataset.
374423
375424
Equivalent to: sorted(iterable, key=key)[:n]
376425
"""
426+
377427
# Short-cut for n==1 is to use min() when len(iterable)>0
378428
if n == 1:
379429
it = iter(iterable)
380-
head = list(islice(it, 1))
381-
if not head:
382-
return []
430+
sentinel = object()
383431
if key is None:
384-
return [min(chain(head, it))]
385-
return [min(chain(head, it), key=key)]
432+
result = min(it, default=sentinel)
433+
else:
434+
result = min(it, default=sentinel, key=key)
435+
return [] if result is sentinel else [result]
386436

387437
# When n>=size, it's faster to use sorted()
388438
try:
@@ -395,15 +445,39 @@ def nsmallest(n, iterable, key=None):
395445

396446
# When key is none, use simpler decoration
397447
if key is None:
398-
it = zip(iterable, count()) # decorate
399-
result = _nsmallest(n, it)
400-
return [r[0] for r in result] # undecorate
448+
it = iter(iterable)
449+
result = list(islice(zip(it, count()), n))
450+
if not result:
451+
return result
452+
_heapify_max(result)
453+
order = n
454+
top = result[0][0]
455+
_heapreplace = _heapreplace_max
456+
for elem in it:
457+
if elem < top:
458+
_heapreplace(result, (elem, order))
459+
top = result[0][0]
460+
order += 1
461+
result.sort()
462+
return [r[0] for r in result]
401463

402464
# General case, slowest method
403-
in1, in2 = tee(iterable)
404-
it = zip(map(key, in1), count(), in2) # decorate
405-
result = _nsmallest(n, it)
406-
return [r[2] for r in result] # undecorate
465+
it = iter(iterable)
466+
result = [(key(elem), i, elem) for i, elem in zip(range(n), it)]
467+
if not result:
468+
return result
469+
_heapify_max(result)
470+
order = n
471+
top = result[0][0]
472+
_heapreplace = _heapreplace_max
473+
for elem in it:
474+
k = key(elem)
475+
if k < top:
476+
_heapreplace(result, (k, order, elem))
477+
top = result[0][0]
478+
order += 1
479+
result.sort()
480+
return [r[2] for r in result]
407481

408482
def nlargest(n, iterable, key=None):
409483
"""Find the n largest elements in a dataset.
@@ -442,9 +516,9 @@ def nlargest(n, iterable, key=None):
442516
_heapreplace = heapreplace
443517
for elem in it:
444518
if top < elem:
445-
order -= 1
446519
_heapreplace(result, (elem, order))
447520
top = result[0][0]
521+
order -= 1
448522
result.sort(reverse=True)
449523
return [r[0] for r in result]
450524

@@ -460,9 +534,9 @@ def nlargest(n, iterable, key=None):
460534
for elem in it:
461535
k = key(elem)
462536
if top < k:
463-
order -= 1
464537
_heapreplace(result, (k, order, elem))
465538
top = result[0][0]
539+
order -= 1
466540
result.sort(reverse=True)
467541
return [r[2] for r in result]
468542

Lib/test/test_heapq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
1414
# _heapq is imported, so check them there
1515
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
16-
'heapreplace', '_nsmallest']
16+
'heapreplace', '_heapreplace_max']
1717

1818
class TestModules(TestCase):
1919
def test_py_functions(self):

Misc/NEWS

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ Library
8484
- Issue #21156: importlib.abc.InspectLoader.source_to_code() is now a
8585
staticmethod.
8686

87-
- Issue #21424: Simplified and optimized heaqp.nlargest() to make fewer
88-
tuple comparisons.
87+
- Issue #21424: Simplified and optimized heaqp.nlargest() and nmsmallest()
88+
to make fewer tuple comparisons.
8989

9090
- Issue #21396: Fix TextIOWrapper(..., write_through=True) to not force a
9191
flush() on the underlying binary stream. Patch by akira.

Modules/_heapqmodule.c

Lines changed: 19 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -354,88 +354,34 @@ _siftupmax(PyListObject *heap, Py_ssize_t pos)
354354
}
355355

356356
static PyObject *
357-
nsmallest(PyObject *self, PyObject *args)
357+
_heapreplace_max(PyObject *self, PyObject *args)
358358
{
359-
PyObject *heap=NULL, *elem, *iterable, *los, *it, *oldelem;
360-
Py_ssize_t i, n;
361-
int cmp;
359+
PyObject *heap, *item, *returnitem;
362360

363-
if (!PyArg_ParseTuple(args, "nO:nsmallest", &n, &iterable))
361+
if (!PyArg_UnpackTuple(args, "_heapreplace_max", 2, 2, &heap, &item))
364362
return NULL;
365363

366-
it = PyObject_GetIter(iterable);
367-
if (it == NULL)
364+
if (!PyList_Check(heap)) {
365+
PyErr_SetString(PyExc_TypeError, "heap argument must be a list");
368366
return NULL;
369-
370-
heap = PyList_New(0);
371-
if (heap == NULL)
372-
goto fail;
373-
374-
for (i=0 ; i<n ; i++ ){
375-
elem = PyIter_Next(it);
376-
if (elem == NULL) {
377-
if (PyErr_Occurred())
378-
goto fail;
379-
else
380-
goto sortit;
381-
}
382-
if (PyList_Append(heap, elem) == -1) {
383-
Py_DECREF(elem);
384-
goto fail;
385-
}
386-
Py_DECREF(elem);
387367
}
388-
n = PyList_GET_SIZE(heap);
389-
if (n == 0)
390-
goto sortit;
391-
392-
for (i=n/2-1 ; i>=0 ; i--)
393-
if(_siftupmax((PyListObject *)heap, i) == -1)
394-
goto fail;
395-
396-
los = PyList_GET_ITEM(heap, 0);
397-
while (1) {
398-
elem = PyIter_Next(it);
399-
if (elem == NULL) {
400-
if (PyErr_Occurred())
401-
goto fail;
402-
else
403-
goto sortit;
404-
}
405-
cmp = PyObject_RichCompareBool(elem, los, Py_LT);
406-
if (cmp == -1) {
407-
Py_DECREF(elem);
408-
goto fail;
409-
}
410-
if (cmp == 0) {
411-
Py_DECREF(elem);
412-
continue;
413-
}
414368

415-
oldelem = PyList_GET_ITEM(heap, 0);
416-
PyList_SET_ITEM(heap, 0, elem);
417-
Py_DECREF(oldelem);
418-
if (_siftupmax((PyListObject *)heap, 0) == -1)
419-
goto fail;
420-
los = PyList_GET_ITEM(heap, 0);
369+
if (PyList_GET_SIZE(heap) < 1) {
370+
PyErr_SetString(PyExc_IndexError, "index out of range");
371+
return NULL;
421372
}
422373

423-
sortit:
424-
if (PyList_Sort(heap) == -1)
425-
goto fail;
426-
Py_DECREF(it);
427-
return heap;
428-
429-
fail:
430-
Py_DECREF(it);
431-
Py_XDECREF(heap);
432-
return NULL;
374+
returnitem = PyList_GET_ITEM(heap, 0);
375+
Py_INCREF(item);
376+
PyList_SET_ITEM(heap, 0, item);
377+
if (_siftupmax((PyListObject *)heap, 0) == -1) {
378+
Py_DECREF(returnitem);
379+
return NULL;
380+
}
381+
return returnitem;
433382
}
434383

435-
PyDoc_STRVAR(nsmallest_doc,
436-
"Find the n smallest elements in a dataset.\n\
437-
\n\
438-
Equivalent to: sorted(iterable)[:n]\n");
384+
PyDoc_STRVAR(heapreplace_max_doc, "Maxheap variant of heapreplace");
439385

440386
static PyMethodDef heapq_methods[] = {
441387
{"heappush", (PyCFunction)heappush,
@@ -448,8 +394,8 @@ static PyMethodDef heapq_methods[] = {
448394
METH_VARARGS, heapreplace_doc},
449395
{"heapify", (PyCFunction)heapify,
450396
METH_O, heapify_doc},
451-
{"nsmallest", (PyCFunction)nsmallest,
452-
METH_VARARGS, nsmallest_doc},
397+
{"_heapreplace_max",(PyCFunction)_heapreplace_max,
398+
METH_VARARGS, heapreplace_max_doc},
453399
{NULL, NULL} /* sentinel */
454400
};
455401

0 commit comments

Comments
 (0)