Jak uzyskać indeksy N maksymalnych wartości w tablicy NumPy?
NumPy proponuje sposób na uzyskanie indeksu maksymalnej wartości tablicy poprzez np.argmax
.
Chciałbym podobną rzecz, ale zwracając indeksy N maksymalnych wartości.
Na przykład, jeśli mam tablicę, [1, 3, 2, 4, 5]
, function(array, n=3)
zwróci [4, 3, 1]
.
15 answers
Najprostsze, jakie udało mi się wymyślić to:
In [1]: import numpy as np
In [2]: arr = np.array([1, 3, 2, 4, 5])
In [3]: arr.argsort()[-3:][::-1]
Out[3]: array([4, 3, 1])
To wymaga kompletnego rodzaju tablicy. Zastanawiam się, czy numpy
zapewnia wbudowany sposób na częściowe sortowanie; do tej pory nie byłem w stanie go znaleźć.
Jeśli to rozwiązanie okaże się zbyt wolne (szczególnie dla małych n
), może warto przyjrzeć się kodowaniu czegoś w Cythonie .
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2011-08-02 10:45:41
Nowsze wersje NumPy (1.8 i nowsze) posiadają funkcję o nazwie argpartition
za to. Aby uzyskać indeksy czterech największych elementów, wykonaj
>>> a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> a
array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])
>>> ind = np.argpartition(a, -4)[-4:]
>>> ind
array([1, 5, 8, 0])
>>> a[ind]
array([4, 9, 6, 9])
W przeciwieństwie do argsort
, Funkcja ta działa w czasie liniowym w najgorszym przypadku, ale zwracane indeksy nie są sortowane, co widać na podstawie wyniku oceny a[ind]
. Jeśli tego też potrzebujesz, posortuj je później:
>>> ind[np.argsort(a[ind])]
array([1, 8, 5, 0])
Aby uzyskać top- K elementy w porządku posortowanym w ten sposób przyjmuje O (n + K log K ) Czas.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-08-16 17:14:16
Prostsze jeszcze:
idx = (-arr).argsort()[:n]
Gdzie n jest liczbą wartości maksymalnych.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2014-12-11 22:13:05
Użycie:
>>> import heapq
>>> import numpy
>>> a = numpy.array([1, 3, 2, 4, 5])
>>> heapq.nlargest(3, range(len(a)), a.take)
[4, 3, 1]
Dla zwykłych list Pythona:
>>> a = [1, 3, 2, 4, 5]
>>> heapq.nlargest(3, range(len(a)), a.__getitem__)
[4, 3, 1]
Jeśli używasz Pythona 2, Użyj xrange
zamiast range
.
Źródło: heapq-algorytm kolejki sterty
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:49:39
Jeśli zdarzy ci się pracować z wielowymiarową tablicą, musisz spłaszczyć i rozwikłać indeksy:
def largest_indices(ary, n):
"""Returns the n largest indices from a numpy array."""
flat = ary.flatten()
indices = np.argpartition(flat, -n)[-n:]
indices = indices[np.argsort(-flat[indices])]
return np.unravel_index(indices, ary.shape)
Na przykład:
>>> xs = np.sin(np.arange(9)).reshape((3, 3))
>>> xs
array([[ 0. , 0.84147098, 0.90929743],
[ 0.14112001, -0.7568025 , -0.95892427],
[-0.2794155 , 0.6569866 , 0.98935825]])
>>> largest_indices(xs, 3)
(array([2, 0, 0]), array([2, 2, 1]))
>>> xs[largest_indices(xs, 3)]
array([ 0.98935825, 0.90929743, 0.84147098])
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2016-08-10 21:42:27
Jeśli nie zależy ci na porządku K-tych największych elementów możesz użyć argpartition
, który powinien działać lepiej niż pełne sortowanie argsort
.
K = 4 # We want the indices of the four largest values
a = np.array([0, 8, 0, 4, 5, 8, 8, 0, 4, 2])
np.argpartition(a,-K)[-K:]
array([4, 1, 5, 6])
Napisy idą do to pytanie .
Przeprowadziłem kilka testów i wygląda na to, że argpartition
przewyższa argsort
Jako rozmiar tablicy i wartość wzrostu K.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:52:38
Dla tablic wielowymiarowych można użyć słowa kluczowego axis
, Aby zastosować partycjonowanie wzdłuż oczekiwanej osi.
# For a 2D array
indices = np.argpartition(arr, -N, axis=1)[:, -N:]
I za chwytanie przedmiotów:
x = arr.shape[0]
arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)
Ale zauważ, że to nie zwróci sortowanego wyniku. W takim przypadku możesz użyć np.argsort()
wzdłuż zamierzonej osi:
indices = np.argsort(arr, axis=1)[:, -N:]
# Result
x = arr.shape[0]
arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)
Oto przykład:
In [42]: a = np.random.randint(0, 20, (10, 10))
In [44]: a
Out[44]:
array([[ 7, 11, 12, 0, 2, 3, 4, 10, 6, 10],
[16, 16, 4, 3, 18, 5, 10, 4, 14, 9],
[ 2, 9, 15, 12, 18, 3, 13, 11, 5, 10],
[14, 0, 9, 11, 1, 4, 9, 19, 18, 12],
[ 0, 10, 5, 15, 9, 18, 5, 2, 16, 19],
[14, 19, 3, 11, 13, 11, 13, 11, 1, 14],
[ 7, 15, 18, 6, 5, 13, 1, 7, 9, 19],
[11, 17, 11, 16, 14, 3, 16, 1, 12, 19],
[ 2, 4, 14, 8, 6, 9, 14, 9, 1, 5],
[ 1, 10, 15, 0, 1, 9, 18, 2, 2, 12]])
In [45]: np.argpartition(a, np.argmin(a, axis=0))[:, 1:] # 1 is because the first item is the minimum one.
Out[45]:
array([[4, 5, 6, 8, 0, 7, 9, 1, 2],
[2, 7, 5, 9, 6, 8, 1, 0, 4],
[5, 8, 1, 9, 7, 3, 6, 2, 4],
[4, 5, 2, 6, 3, 9, 0, 8, 7],
[7, 2, 6, 4, 1, 3, 8, 5, 9],
[2, 3, 5, 7, 6, 4, 0, 9, 1],
[4, 3, 0, 7, 8, 5, 1, 2, 9],
[5, 2, 0, 8, 4, 6, 3, 1, 9],
[0, 1, 9, 4, 3, 7, 5, 2, 6],
[0, 4, 7, 8, 5, 1, 9, 2, 6]])
In [46]: np.argpartition(a, np.argmin(a, axis=0))[:, -3:]
Out[46]:
array([[9, 1, 2],
[1, 0, 4],
[6, 2, 4],
[0, 8, 7],
[8, 5, 9],
[0, 9, 1],
[1, 2, 9],
[3, 1, 9],
[5, 2, 6],
[9, 2, 6]])
In [89]: a[np.repeat(np.arange(x), 3), ind.ravel()].reshape(x, 3)
Out[89]:
array([[10, 11, 12],
[16, 16, 18],
[13, 15, 18],
[14, 18, 19],
[16, 18, 19],
[14, 14, 19],
[15, 18, 19],
[16, 17, 19],
[ 9, 14, 14],
[12, 15, 18]])
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:53:53
Będzie to szybsze niż pełne sortowanie w zależności od rozmiaru oryginalnej tablicy i rozmiaru Twojego wyboru:
>>> A = np.random.randint(0,10,10)
>>> A
array([5, 1, 5, 5, 2, 3, 2, 4, 1, 0])
>>> B = np.zeros(3, int)
>>> for i in xrange(3):
... idx = np.argmax(A)
... B[i]=idx; A[idx]=0 #something smaller than A.min()
...
>>> B
array([0, 2, 3])
To oczywiście wiąże się z manipulacją oryginalną tablicą. Które można naprawić (w razie potrzeby), wykonując kopię lub zastępując z powrotem oryginalne wartości. ...która jest tańsza w przypadku użycia.Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2011-08-02 14:02:59
bottleneck
ma funkcję sortowania częściowego, jeśli koszt sortowania całej tablicy tylko po to, aby uzyskać N największych wartości jest zbyt duży.
nic nie wiem o tym module, po prostu wygooglowałem numpy partial sort
.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2011-08-02 10:37:08
Użycie:
from operator import itemgetter
from heapq import nlargest
result = nlargest(N, enumerate(your_list), itemgetter(1))
Teraz result
lista będzie zawieraćN krotki (index
, value
) Gdzie value
jest maksymalizowana.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:50:56
Użycie:
def max_indices(arr, k):
'''
Returns the indices of the k first largest elements of arr
(in descending order in values)
'''
assert k <= arr.size, 'k should be smaller or equal to the array size'
arr_ = arr.astype(float) # make a copy of arr
max_idxs = []
for _ in range(k):
max_element = np.max(arr_)
if np.isinf(max_element):
break
else:
idx = np.where(arr_ == max_element)
max_idxs.append(idx)
arr_[idx] = -np.inf
return max_idxs
Działa również z tablicami 2D. Na przykład,
In [0]: A = np.array([[ 0.51845014, 0.72528114],
[ 0.88421561, 0.18798661],
[ 0.89832036, 0.19448609],
[ 0.89832036, 0.19448609]])
In [1]: max_indices(A, 8)
Out[1]:
[(array([2, 3], dtype=int64), array([0, 0], dtype=int64)),
(array([1], dtype=int64), array([0], dtype=int64)),
(array([0], dtype=int64), array([1], dtype=int64)),
(array([0], dtype=int64), array([0], dtype=int64)),
(array([2, 3], dtype=int64), array([1, 1], dtype=int64)),
(array([1], dtype=int64), array([1], dtype=int64))]
In [2]: A[max_indices(A, 8)[0]][0]
Out[2]: array([ 0.89832036])
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:57:30
Metoda np.argpartition
zwraca tylko największe indeksy k, wykonuje sortowanie lokalne i jest szybsza niż np.argsort
(wykonuje sortowanie pełne), gdy tablica jest dość duża. Ale zwracane indeksy są nie w porządku rosnącym/malejącym . Na przykład:
Widzimy, że jeśli chcesz ścisłego rosnącego porządku indeksów k, np.argpartition
nie zwróci tego, co chcesz.
Oprócz ręcznego sortowania po np.argpartition, moim rozwiązaniem jest użycie Pytorcha, torch.topk
, narzędzie do budowy sieci neuronowych, zapewniające API podobne do NumPy z obsługą zarówno procesora, jak i GPU. Jest tak szybki jak NumPy z MKL i oferuje przyspieszenie GPU, jeśli potrzebujesz dużych obliczeń macierzy/wektorów.
Ścisły kod ascend / descend górnego indeksu k będzie wynosił:
Zauważ, że torch.topk
przyjmuje tensor pochodni i zwraca zarówno górne wartości k, jak i górne indeksy k w typie torch.Tensor
. Podobnie z np, torch.topk akceptuje również argument osi, więc że można obsługiwać wielowymiarowe tablice/tensory.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:56:09
Uważam, że jest najbardziej intuicyjny w obsłudze np.unique
.
Chodzi o to, że unikalna metoda zwraca indeksy wartości wejściowych. Następnie z wartości Max unique i wskaźników można odtworzyć położenie oryginalnych wartości.
multi_max = [1,1,2,2,4,0,0,4]
uniques, idx = np.unique(multi_max, return_inverse=True)
print np.squeeze(np.argwhere(idx == np.argmax(uniques)))
>> [4 7]
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-01-12 18:38:54
Myślę, że najbardziej efektywnym sposobem jest Manualna iteracja przez tablicę i zachowanie sterty min O rozmiarze K, jak inni ludzie wspominali.
I ja też wymyślam podejście brute force:
top_k_index_list = [ ]
for i in range(k):
top_k_index_list.append(np.argmax(my_array))
my_array[top_k_index_list[-1]] = -float('inf')
Ustaw największy element na dużą wartość ujemną po użyciu argmax, aby uzyskać jego indeks. Następne wywołanie argmax zwróci drugi co do wielkości element. Możesz też zapisać oryginalną wartość tych elementów i odzyskać je, jeśli chcesz.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 02:59:14
Poniżej znajduje się bardzo łatwy sposób, aby zobaczyć maksymalne elementy i ich pozycje. Tutaj axis
jest domeną; axis
= 0 oznacza maksymalną liczbę kolumn, a axis
= 1 oznacza maksymalną liczbę wierszy dla przypadku 2D. A dla wyższych wymiarów to zależy od Ciebie.
M = np.random.random((3, 4))
print(M)
print(M.max(axis=1), M.argmax(axis=1))
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2018-06-28 03:01:12