Fast Way to Get Top-K Elements from a NumPy Array
May 14, 2022
Retrieving the ordered top-k elements of an array is a common problem. However, NumPy does not provide a dedicated function for it. A naive solution is to perform a full sort and then take the first k elements, like this:
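```python
import numpy as np

def topk_by_sort(arr, k, axis=None, ascending=True):
    arr = np.asarray(arr)
    # Negate a copy for descending order instead of mutating the caller's array
    key = arr if ascending else -arr
    # Full sort along the axis (axis=None sorts the flattened array)
    ind = np.argsort(key, axis=axis)
    # Keep the first k indices along the axis
    ind = np.take(ind, np.arange(k), axis=axis)
    # Gather the matching values from the original (un-negated) array
    val = np.take_along_axis(arr, ind, axis=axis)
    return ind, val
```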
Note that this function is generalized so that it can handle multi-dimensional arrays (through the axis argument) and descending order (through ascending=False).
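The following sketch shows what the multi-dimensional and descending-order handling does. It applies the same steps by hand to a small 2-D array with made-up values and takes the top 2 along axis=1 in descending order. It uses plain slicing instead of np.take, which is equivalent here because the axis is fixed.

```python
import numpy as np

# Made-up scores: two rows, three columns
scores = np.array([[3, 1, 2],
                   [9, 7, 8]])
k = 2

# Negating sorts in descending order; slice the first k columns of each row
ind = np.argsort(-scores, axis=1)[:, :k]
# Pick the values at those column indices, row by row
val = np.take_along_axis(scores, ind, axis=1)

print(ind)  # [[0 2]
            #  [0 2]]
print(val)  # [[3 2]
            #  [9 8]]
```

With axis=1, argsort works row by row, so each row gets its own top-k column indices. np.take_along_axis then uses those indices to pick values from the matching row.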
For an array of length $n$, the full sort takes $O(n \log n)$ time, which is more work than needed. You can reduce it to $O(n + k \log k)$ time (assuming $k \ll n$) by first retrieving the top $k$ elements without ordering them and then sorting only those $k$ elements. This can be implemented with np.argpartition() and np.argsort() like this:
```python
def topk_by_partition(arr, k, axis=None, ascending=True):
    arr = np.asarray(arr)
    # Negate a copy for descending order instead of mutating the caller's array
    key = arr if ascending else -arr
    # Move the k smallest keys to the first k positions (unordered); kth=k-1 also works for k == n
    ind = np.argpartition(key, k - 1, axis=axis)
    ind = np.take(ind, np.arange(k), axis=axis)  # k non-sorted indices
    key_k = np.take_along_axis(key, ind, axis=axis)  # k non-sorted keys
    # Sort within the k elements only
    ind_part = np.argsort(key_k, axis=axis)
    ind = np.take_along_axis(ind, ind_part, axis=axis)
    val = np.take_along_axis(arr, ind, axis=axis)
    return ind, val
```
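np.argpartition only guarantees that the element at position kth ends up where a full sort would put it, with no larger element before it. The first k positions are otherwise in arbitrary order. The 1-D sketch below is a quick sanity check, with n and k chosen arbitrarily. When the array has no ties, partitioning and then sorting the first k should agree exactly with a full sort.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 1_000, 10
# A permutation has no ties, so indices from both methods must match exactly
arr = rng.permutation(n).astype(np.float64)

# Full sort, then slice
full = np.argsort(arr)[:k]

# Partition, then sort only the first k
part = np.argpartition(arr, k - 1)[:k]  # k smallest, in arbitrary order
part = part[np.argsort(arr[part])]      # order those k by value

assert np.array_equal(full, part)
print(arr[part])  # the values 0.0 through 9.0 in ascending order
```

With ties, the two methods can return different indices for equal values because np.argsort's default sort is not stable. The selected values still match.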
I measured the running time of the two functions for a given array length $n$ and different values of $k$. As shown in the figures below, the argpartition approach is significantly faster when $k \ll n$.
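The measurement setup is not described in detail. The following is only a sketch of how such a comparison could be run with the standard timeit module. The helpers by_sort and by_partition are hypothetical, simplified 1-D versions of the two functions. The value of n and the k values are illustrative choices, not necessarily the ones behind the figures.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
n = 100_000  # illustrative array length (assumption)
arr = rng.random(n)

def by_sort(a, k):
    # Hypothetical helper: full O(n log n) sort, then keep the first k indices
    return np.argsort(a)[:k]

def by_partition(a, k):
    # Hypothetical helper: O(n) partition, then sort only the k survivors
    ind = np.argpartition(a, k - 1)[:k]
    return ind[np.argsort(a[ind])]

results = {}
for k in (10, 1_000, 10_000):  # illustrative values of k (assumption)
    # The lambdas run immediately inside this iteration, so they see the current k
    t_sort = timeit.timeit(lambda: by_sort(arr, k), number=5) / 5
    t_part = timeit.timeit(lambda: by_partition(arr, k), number=5) / 5
    results[k] = (t_sort, t_part)
    print(f"k={k:>6}: sort {t_sort * 1e3:.2f} ms, partition {t_part * 1e3:.2f} ms")
```

Absolute timings depend on the machine and NumPy build. Averaging several runs per k, as above, reduces noise from a single measurement.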
Written by Shion Honda.