Jak uzyskać indeksy N maksymalnych wartości w tablicy NumPy?

NumPy proponuje sposób na uzyskanie indeksu maksymalnej wartości tablicy poprzez np.argmax.

Chciałbym podobną rzecz, ale zwracając indeksy N maksymalnych wartości.

Na przykład, jeśli mam tablicę, [1, 3, 2, 4, 5], function(array, n=3) zwróci [4, 3, 1].

 294
Author: Peter Mortensen, 2011-08-02

15 answers

Najprostsze, jakie udało mi się wymyślić to:

In [1]: import numpy as np

In [2]: arr = np.array([1, 3, 2, 4, 5])

In [3]: arr.argsort()[-3:][::-1]
Out[3]: array([4, 3, 1])

To wymaga kompletnego rodzaju tablicy. Zastanawiam się, czy numpy zapewnia wbudowany sposób na częściowe sortowanie; do tej pory nie byłem w stanie go znaleźć.

Jeśli to rozwiązanie okaże się zbyt wolne (szczególnie dla małych n), może warto przyjrzeć się kodowaniu czegoś w Cythonie .

 213
Author: NPE,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2011-08-02 10:45:41

Nowsze wersje NumPy (1.8 i nowsze) posiadają funkcję o nazwie argpartition za to. Aby uzyskać indeksy czterech największych elementów, wykonaj

>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])
>>> a[ind]
array([4, 9, 6, 9])

W przeciwieństwie do argsort, Funkcja ta działa w czasie liniowym w najgorszym przypadku, ale zwracane indeksy nie są sortowane, co widać na podstawie wyniku oceny a[ind]. Jeśli tego też potrzebujesz, posortuj je później:

>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])

Aby uzyskać top- K elementy w porządku posortowanym w ten sposób przyjmuje O (n + K log K ) Czas.

 361
Author: Fred Foo,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-08-16 17:14:16

Prostsze jeszcze:

idx = (-arr).argsort()[:n]

Gdzie n jest liczbą wartości maksymalnych.

 31
Author: Ketan,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2014-12-11 22:13:05

Użycie:

>>> import heapq
>>> import numpy
>>> a = numpy.array([1, 3, 2, 4, 5])
>>> heapq.nlargest(3, range(len(a)), a.take)
[4, 3, 1]

Dla zwykłych list Pythona:

>>> a = [1, 3, 2, 4, 5]
>>> heapq.nlargest(3, range(len(a)), a.__getitem__)
[4, 3, 1]

Jeśli używasz Pythona 2, Użyj xrange zamiast range.

Źródło: heapq-algorytm kolejki sterty

 23
Author: anishpatel,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:49:39

Jeśli zdarzy ci się pracować z wielowymiarową tablicą, musisz spłaszczyć i rozwikłać indeksy:

def largest_indices(ary, n):
    """Returns the n largest indices from a numpy array."""
    flat = ary.flatten()
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)

Na przykład:

>>> xs = np.sin(np.arange(9)).reshape((3, 3))
>>> xs
array([[ 0.        ,  0.84147098,  0.90929743],
       [ 0.14112001, -0.7568025 , -0.95892427],
       [-0.2794155 ,  0.6569866 ,  0.98935825]])
>>> largest_indices(xs, 3)
(array([2, 0, 0]), array([2, 2, 1]))
>>> xs[largest_indices(xs, 3)]
array([ 0.98935825,  0.90929743,  0.84147098])
 18
Author: danvk,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2016-08-10 21:42:27

Jeśli nie zależy ci na porządku K-tych największych elementów możesz użyć argpartition, który powinien działać lepiej niż pełne sortowanie argsort.

K = 4 # We want the indices of the four largest values
a = np.array([0, 8, 0, 4, 5, 8, 8, 0, 4, 2])
np.argpartition(a,-K)[-K:]
array([4, 1, 5, 6])

Napisy idą do to pytanie .

Przeprowadziłem kilka testów i wygląda na to, że argpartition przewyższa argsort Jako rozmiar tablicy i wartość wzrostu K.

 5
Author: blue,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:52:38

Dla tablic wielowymiarowych można użyć słowa kluczowego axis, Aby zastosować partycjonowanie wzdłuż oczekiwanej osi.

# For a 2D array
indices = np.argpartition(arr, -N, axis=1)[:, -N:]

I za chwytanie przedmiotów:

x = arr.shape[0]
arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)

Ale zauważ, że to nie zwróci sortowanego wyniku. W takim przypadku możesz użyć np.argsort() wzdłuż zamierzonej osi:

indices = np.argsort(arr, axis=1)[:, -N:]

# Result
x = arr.shape[0]
arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)

Oto przykład:

In [42]: a = np.random.randint(0, 20, (10, 10))

In [44]: a
Out[44]:
array([[ 7, 11, 12,  0,  2,  3,  4, 10,  6, 10],
       [16, 16,  4,  3, 18,  5, 10,  4, 14,  9],
       [ 2,  9, 15, 12, 18,  3, 13, 11,  5, 10],
       [14,  0,  9, 11,  1,  4,  9, 19, 18, 12],
       [ 0, 10,  5, 15,  9, 18,  5,  2, 16, 19],
       [14, 19,  3, 11, 13, 11, 13, 11,  1, 14],
       [ 7, 15, 18,  6,  5, 13,  1,  7,  9, 19],
       [11, 17, 11, 16, 14,  3, 16,  1, 12, 19],
       [ 2,  4, 14,  8,  6,  9, 14,  9,  1,  5],
       [ 1, 10, 15,  0,  1,  9, 18,  2,  2, 12]])

In [45]: np.argpartition(a, np.argmin(a, axis=0))[:, 1:] # 1 is because the first item is the minimum one.
Out[45]:
array([[4, 5, 6, 8, 0, 7, 9, 1, 2],
       [2, 7, 5, 9, 6, 8, 1, 0, 4],
       [5, 8, 1, 9, 7, 3, 6, 2, 4],
       [4, 5, 2, 6, 3, 9, 0, 8, 7],
       [7, 2, 6, 4, 1, 3, 8, 5, 9],
       [2, 3, 5, 7, 6, 4, 0, 9, 1],
       [4, 3, 0, 7, 8, 5, 1, 2, 9],
       [5, 2, 0, 8, 4, 6, 3, 1, 9],
       [0, 1, 9, 4, 3, 7, 5, 2, 6],
       [0, 4, 7, 8, 5, 1, 9, 2, 6]])

In [46]: np.argpartition(a, np.argmin(a, axis=0))[:, -3:]
Out[46]:
array([[9, 1, 2],
       [1, 0, 4],
       [6, 2, 4],
       [0, 8, 7],
       [8, 5, 9],
       [0, 9, 1],
       [1, 2, 9],
       [3, 1, 9],
       [5, 2, 6],
       [9, 2, 6]])

In [89]: a[np.repeat(np.arange(x), 3), ind.ravel()].reshape(x, 3)
Out[89]:
array([[10, 11, 12],
       [16, 16, 18],
       [13, 15, 18],
       [14, 18, 19],
       [16, 18, 19],
       [14, 14, 19],
       [15, 18, 19],
       [16, 17, 19],
       [ 9, 14, 14],
       [12, 15, 18]])
 5
Author: Kasrâmvd,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:53:53

Będzie to szybsze niż pełne sortowanie w zależności od rozmiaru oryginalnej tablicy i rozmiaru Twojego wyboru:

>>> A = np.random.randint(0,10,10)
>>> A
array([5, 1, 5, 5, 2, 3, 2, 4, 1, 0])
>>> B = np.zeros(3, int)
>>> for i in xrange(3):
...     idx = np.argmax(A)
...     B[i]=idx; A[idx]=0 #something smaller than A.min()
...     
>>> B
array([0, 2, 3])
To oczywiście wiąże się z manipulacją oryginalną tablicą. Które można naprawić (w razie potrzeby), wykonując kopię lub zastępując z powrotem oryginalne wartości. ...która jest tańsza w przypadku użycia.
 4
Author: Paul,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2011-08-02 14:02:59

bottleneck ma funkcję sortowania częściowego, jeśli koszt sortowania całej tablicy tylko po to, aby uzyskać N największych wartości jest zbyt duży.

nic nie wiem o tym module, po prostu wygooglowałem numpy partial sort.

 3
Author: Katriel,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2011-08-02 10:37:08

Użycie:

from operator import itemgetter
from heapq import nlargest
result = nlargest(N, enumerate(your_list), itemgetter(1))

Teraz result lista będzie zawieraćN krotki (index, value) Gdzie value jest maksymalizowana.

 2
Author: off99555,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:50:56

Użycie:

def max_indices(arr, k):
    '''
    Returns the indices of the k first largest elements of arr
    (in descending order in values)
    '''
    assert k <= arr.size, 'k should be smaller or equal to the array size'
    arr_ = arr.astype(float)  # make a copy of arr
    max_idxs = []
    for _ in range(k):
        max_element = np.max(arr_)
        if np.isinf(max_element):
            break
        else:
            idx = np.where(arr_ == max_element)
        max_idxs.append(idx)
        arr_[idx] = -np.inf
    return max_idxs

Działa również z tablicami 2D. Na przykład,

In [0]: A = np.array([[ 0.51845014,  0.72528114],
                     [ 0.88421561,  0.18798661],
                     [ 0.89832036,  0.19448609],
                     [ 0.89832036,  0.19448609]])
In [1]: max_indices(A, 8)
Out[1]:
    [(array([2, 3], dtype=int64), array([0, 0], dtype=int64)),
     (array([1], dtype=int64), array([0], dtype=int64)),
     (array([0], dtype=int64), array([1], dtype=int64)),
     (array([0], dtype=int64), array([0], dtype=int64)),
     (array([2, 3], dtype=int64), array([1, 1], dtype=int64)),
     (array([1], dtype=int64), array([1], dtype=int64))]

In [2]: A[max_indices(A, 8)[0]][0]
Out[2]: array([ 0.89832036])
 2
Author: Andyk,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:57:30

Metoda np.argpartition zwraca tylko największe indeksy k, wykonuje sortowanie lokalne i jest szybsza niż np.argsort(wykonuje sortowanie pełne), gdy tablica jest dość duża. Ale zwracane indeksy są nie w porządku rosnącym/malejącym . Na przykład:

Tutaj wpisz opis obrazka

Widzimy, że jeśli chcesz ścisłego rosnącego porządku indeksów k, np.argpartition nie zwróci tego, co chcesz.

Oprócz ręcznego sortowania po np.argpartition, moim rozwiązaniem jest użycie Pytorcha, torch.topk, narzędzie do budowy sieci neuronowych, zapewniające API podobne do NumPy z obsługą zarówno procesora, jak i GPU. Jest tak szybki jak NumPy z MKL i oferuje przyspieszenie GPU, jeśli potrzebujesz dużych obliczeń macierzy/wektorów.

Ścisły kod ascend / descend górnego indeksu k będzie wynosił:

Tutaj wpisz opis obrazka

Zauważ, że torch.topk przyjmuje tensor pochodni i zwraca zarówno górne wartości k, jak i górne indeksy k w typie torch.Tensor. Podobnie z np, torch.topk akceptuje również argument osi, więc że można obsługiwać wielowymiarowe tablice/tensory.

 1
Author: futureer,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:56:09

Uważam, że jest najbardziej intuicyjny w obsłudze np.unique.

Chodzi o to, że unikalna metoda zwraca indeksy wartości wejściowych. Następnie z wartości Max unique i wskaźników można odtworzyć położenie oryginalnych wartości.

multi_max = [1,1,2,2,4,0,0,4]
uniques, idx = np.unique(multi_max, return_inverse=True)
print np.squeeze(np.argwhere(idx == np.argmax(uniques)))
>> [4 7]
 0
Author: phi,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-01-12 18:38:54

Myślę, że najbardziej efektywnym sposobem jest Manualna iteracja przez tablicę i zachowanie sterty min O rozmiarze K, jak inni ludzie wspominali.

I ja też wymyślam podejście brute force:

top_k_index_list = [ ]
for i in range(k):
    top_k_index_list.append(np.argmax(my_array))
    my_array[top_k_index_list[-1]] = -float('inf')

Ustaw największy element na dużą wartość ujemną po użyciu argmax, aby uzyskać jego indeks. Następne wywołanie argmax zwróci drugi co do wielkości element. Możesz też zapisać oryginalną wartość tych elementów i odzyskać je, jeśli chcesz.

 0
Author: Zhenghao Zhao,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:59:14

Poniżej znajduje się bardzo łatwy sposób, aby zobaczyć maksymalne elementy i ich pozycje. Tutaj axis jest domeną; axis = 0 oznacza maksymalną liczbę kolumn, a axis = 1 oznacza maksymalną liczbę wierszy dla przypadku 2D. A dla wyższych wymiarów to zależy od Ciebie.

M = np.random.random((3, 4))
print(M)
print(M.max(axis=1), M.argmax(axis=1))
 0
Author: liberal,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 03:01:12