Hippocampus's Garden

Under the sea, in the hippocampus's garden...


    Fast Way to Get Top-K Elements from NumPy Array

    May 14, 2022


    Retrieving the top-k elements of an array, in order, is a common problem. However, NumPy has no built-in function for it. A naive solution is to sort the whole array and then take the first k elements, like this:

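    import numpy as np
    
    def topk_by_sort(input, k, axis=None, ascending=True):
        input = np.asarray(input)
        # Negate a copy instead of modifying the caller's array in place
        # (assumes a signed integer or floating-point dtype)
        keys = input if ascending else -input
        ind = np.argsort(keys, axis=axis)            # full sort: O(n log n)
        ind = np.take(ind, np.arange(k), axis=axis)  # keep the first k indices
        val = np.take_along_axis(input, ind, axis=axis)  # gather original values
        return ind, val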

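    Note that this function is generalized so that it can handle multi-dimensional arrays (through the axis argument; with the default axis=None the array is flattened first, as in np.argsort()) and descending order (ascending=False returns the k largest elements, largest first).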
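To make the axis handling concrete, here is a minimal sketch of the same three steps (argsort, take the first k, take_along_axis) applied by hand to a small 2-D array, selecting the two largest values in each row. The array a is made up for illustration; the result is what topk_by_sort(a, 2, axis=1, ascending=False) is meant to return.

```python
import numpy as np

a = np.array([[3, 1, 4, 1, 5],
              [9, 2, 6, 5, 3]])
k = 2

# Descending order: argsort the negated values (a new array; `a` is untouched)
ind = np.argsort(-a, axis=1)
ind = np.take(ind, np.arange(k), axis=1)   # first k columns of each row's order
val = np.take_along_axis(a, ind, axis=1)   # gather the matching values per row

print(ind)  # [[4 2]
            #  [0 2]]
print(val)  # [[5 4]
            #  [9 6]]
```

Because take_along_axis pairs each row of ind with the same row of a, the indices and values stay aligned for any axis.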

    When the array length is $n$, the full sort takes $O(n \log n)$ time, which is more work than necessary. You can reduce it to $O(n)$ time (assuming $n \gg k$) by first selecting the $k$ smallest (or, for descending order, largest) elements without sorting them, and then sorting only those $k$ elements, for a total of $O(n + k \log k)$. This can be implemented with np.argpartition() followed by np.argsort(), like this:

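    def topk_by_partition(input, k, axis=None, ascending=True):
        input = np.asarray(input)
        # Negate a copy instead of modifying the caller's array in place
        # (assumes a signed integer or floating-point dtype)
        keys = input if ascending else -input
        # Partition around index k - 1: the first k positions now hold the
        # k smallest keys in arbitrary order (kth=k - 1 also works when k == n)
        ind = np.argpartition(keys, k - 1, axis=axis)
        ind = np.take(ind, np.arange(k), axis=axis)      # k non-sorted indices
        keys = np.take_along_axis(keys, ind, axis=axis)  # k non-sorted keys
    
        # Sort within the k selected elements only: O(k log k)
        ind_part = np.argsort(keys, axis=axis)
        ind = np.take_along_axis(ind, ind_part, axis=axis)
        val = np.take_along_axis(input, ind, axis=axis)  # gather original values
        return ind, val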
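The only subtle step is the partition: np.argpartition guarantees which elements land in the first k slots, not their order. The sketch below checks this on a 1-D example (the random permutation and k = 5 are arbitrary choices for illustration) and confirms that sorting only those k candidates reproduces the first k entries of a full np.argsort().

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.permutation(1000)  # distinct values 0..999, so there are no ties
k = 5

# Step 1: partition. The first k positions hold the k smallest values,
# but their order is unspecified.
part = np.argpartition(x, k - 1)[:k]
assert sorted(x[part].tolist()) == list(range(k))

# Step 2: sort only those k candidates.
top = part[np.argsort(x[part])]

# Same indices as the first k entries of a full argsort.
assert np.array_equal(top, np.argsort(x)[:k])
print(x[top])  # [0 1 2 3 4]
```

With tied values at the k-th position, the two approaches may legitimately pick different (equally valid) indices, which is why this check uses distinct values.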

    I measured the execution time of the two functions with $k = 10$ and varying $n$. As shown in the figures below, the argpartition approach is significantly faster when $n > 10^4$.

    [Figures: execution time of the two functions for k = 10 and varying n]
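For readers who want to run a comparison of this kind themselves, here is a hedged sketch using the standard timeit module. It is not the script behind the figures: topk_full_sort and topk_partition are hypothetical 1-D simplifications of the two functions above, the sizes are arbitrary, and absolute timings will depend on the machine and NumPy version.

```python
import timeit

import numpy as np


def topk_full_sort(x, k):
    # Hypothetical 1-D helper: O(n log n) full sort, keep the first k indices
    return np.argsort(x)[:k]


def topk_partition(x, k):
    # Hypothetical 1-D helper: O(n) selection, then O(k log k) sort of k candidates
    part = np.argpartition(x, k - 1)[:k]
    return part[np.argsort(x[part])]


rng = np.random.default_rng(0)
k = 10
timings = {}
for n in [10**3, 10**4, 10**5, 10**6]:
    x = rng.random(n)
    # Sanity check: both approaches return the same ordered indices
    assert np.array_equal(topk_full_sort(x, k), topk_partition(x, k))
    # Best of 3 runs, each calling the function 5 times
    t_sort = min(timeit.repeat(lambda: topk_full_sort(x, k), number=5, repeat=3))
    t_part = min(timeit.repeat(lambda: topk_partition(x, k), number=5, repeat=3))
    timings[n] = (t_sort, t_part)
    print(f"n={n:>8}  sort={t_sort:.5f}s  partition={t_part:.5f}s")
```

Taking the minimum over repeats is a common way to reduce noise from other processes; it reports the best case rather than the average.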




    Written by Shion Honda. If you like this, please share!


    Hippocampus's Garden © 2022, Shion Honda. Built with Gatsby