Algorytm znajdowania szczytu 2D w czasie O (n) najgorszym przypadku?
RobiłemTen kurs na algorytmy z MIT. W pierwszym wykładzie profesor przedstawia następujący problem: -
Pik w tablicy 2D jest wartością taką, że wszystkie jej 4 sąsiadki są od niej mniejsze lub równe, tj. dla
a[i][j]
BYĆ lokalnym maksimum,
a[i+1][j] <= a[i][j]
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i+1][j-1] <= a[i][j]
Teraz mając tablicę NXN 2D, znajdź szczyt w tablicy.
To pytanie może być łatwo rozwiązane w O(N^2)
czasie przez iterację wszystkich elementów i zwracanie szczyt.
Jednak można go zoptymalizować do rozwiązania w czasie O(NlogN)
za pomocą rozwiązania podziel i podbij, jak wyjaśniono tutaj .
Ale powiedzieli, że istnieje algorytm czasu O(N)
, który rozwiązuje ten problem. Proszę zasugerować, jak możemy rozwiązać ten problem w O(N)
czas.
PS (dla tych, którzy znają Pythona) personel kursu wyjaśnił podejście tutaj (Problem 1-5. Peak-Finding Proof), a także dostarczył część kodu Pythona w swoich zestawach problemów. Ale wyjaśnione podejście jest całkowicie nieoczywiste i bardzo trudne do rozszyfrowania. Kod Pythona jest równie mylący. Dlatego skopiowałem główną część poniższego kodu dla tych, którzy znają Pythona i mogą powiedzieć, jaki algorytm jest używany z kodu.
def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
# if it's empty, we're done
if problem.numRow <= 0 or problem.numCol <= 0:
return None
subproblems = []
divider = []
if rowSplit:
# the recursive subproblem will involve half the number of rows
mid = problem.numRow // 2
# information about the two subproblems
(subStartR1, subNumR1) = (0, mid)
(subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
(subStartC, subNumC) = (0, problem.numCol)
subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
subproblems.append((subStartR2, subStartC, subNumR2, subNumC))
# get a list of all locations in the dividing column
divider = crossProduct([mid], range(problem.numCol))
else:
# the recursive subproblem will involve half the number of columns
mid = problem.numCol // 2
# information about the two subproblems
(subStartR, subNumR) = (0, problem.numRow)
(subStartC1, subNumC1) = (0, mid)
(subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))
subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
subproblems.append((subStartR, subStartC2, subNumR, subNumC2))
# get a list of all locations in the dividing column
divider = crossProduct(range(problem.numRow), [mid])
# find the maximum in the dividing row or column
bestLoc = problem.getMaximum(divider, trace)
neighbor = problem.getBetterNeighbor(bestLoc, trace)
# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
bestSeen = neighbor
if not trace is None: trace.setBestSeen(bestSeen)
# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
if not trace is None: trace.foundPeak(bestLoc)
return bestLoc
# figure out which subproblem contains the largest number we've seen so
# far, and recurse, alternating between splitting on rows and splitting
# on columns
sub = problem.getSubproblemContaining(subproblems, bestSeen)
newBest = sub.getLocationInSelf(problem, bestSeen)
if not trace is None: trace.setProblemDimensions(sub)
result = algorithm4(sub, newBest, not rowSplit, trace)
return problem.getLocationInSelf(sub, result)
#Helper Method
def crossProduct(list1, list2):
"""
Returns all pairs with one item from the first list and one item from
the second list. (Cartesian product of the two lists.)
The code is equivalent to the following list comprehension:
return [(a, b) for a in list1 for b in list2]
but for easier reading and analysis, we have included more explicit code.
"""
answer = []
for a in list1:
for b in list2:
answer.append ((a, b))
return answer
2 answers
- Załóżmy, że szerokość tablicy jest większa niż wysokość, w przeciwnym razie podzielimy ją w innym kierunku.
- podziel tablicę na trzy części: kolumnę centralną, lewą i prawą stronę.
- przejdź przez kolumnę centralną i dwie sąsiadujące kolumny i poszukaj maksimum.
- jeśli jest w kolumnie centralnej-to jest nasz szczyt
- jeśli jest po lewej stronie, uruchom ten algorytm na subarray
left_side + central_column
- jeśli jest po prawej stronie, uruchom ten algorytm na subarray
right_side + central_column
Dlaczego to działa:
W przypadkach, gdy element maksymalny znajduje się w kolumnie centralnej-oczywisty. Jeśli nie, możemy przejść od tego maksimum do elementów rosnących i na pewno nie przekroczymy środkowego rzędu, więc szczyt na pewno będzie istniał w odpowiadającej mu połowie.
Dlaczego To Jest O (n):
Krok # 3 zajmuje mniej lub równo max_dimension
iteracji i max_dimension
co najmniej połowę na każdym kroku algorytmu. Daje to n+n/2+n/4+...
, czyli O(n)
. Ważny szczegół: dzielimy przez maksymalny kierunek. Dla tablic kwadratowych oznacza to, że kierunki podziału będą przemienne. Jest to różnica w stosunku do ostatniej próby w pliku PDF, z którym się łączyłeś.
Uwaga: nie jestem pewien, czy dokładnie pasuje do algorytmu w kodzie, który podałeś, może to być inne podejście, ale nie musi być inne.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2014-04-19 06:42:48
Oto działający kod Javy , który implementuje algorytm @ maxim1000. Poniższy kod znajduje szczyt w tablicy 2D w czasie liniowym.
import java.util.*;
class Ideone{
public static void main (String[] args) throws java.lang.Exception{
new Ideone().run();
}
int N , M ;
void run(){
N = 1000;
M = 100;
// arr is a random NxM array
int[][] arr = randomArray();
long start = System.currentTimeMillis();
// for(int i=0; i<N; i++){ // TO print the array.
// System. out.println(Arrays.toString(arr[i]));
// }
System.out.println(findPeakLinearTime(arr));
long end = System.currentTimeMillis();
System.out.println("time taken : " + (end-start));
}
int findPeakLinearTime(int[][] arr){
int rows = arr.length;
int cols = arr[0].length;
return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
}
// helper function that splits on the middle Column
int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
if(loCol==hiCol){
int max = arr[loRow][loCol];
int foundRow = loRow;
for(int row = loRow; row<=hiRow; row++){
if(max < arr[row][loCol]){
max = arr[row][loCol];
foundRow = row;
}
}
if(!correctPeak(arr, foundRow, loCol)){
System.out.println("THIS PEAK IS WRONG");
}
return max;
}
int midCol = (loCol+hiCol)/2;
int max = arr[loRow][loCol];
for(int row=loRow; row<=hiRow; row++){
max = Math.max(max, arr[row][midCol]);
}
boolean centralMax = true;
boolean rightMax = false;
boolean leftMax = false;
if(midCol-1 >= 0){
for(int row = loRow; row<=hiRow; row++){
if(arr[row][midCol-1] > max){
max = arr[row][midCol-1];
centralMax = false;
leftMax = true;
}
}
}
if(midCol+1 < M){
for(int row=loRow; row<=hiRow; row++){
if(arr[row][midCol+1] > max){
max = arr[row][midCol+1];
centralMax = false;
leftMax = false;
rightMax = true;
}
}
}
if(centralMax) return max;
if(rightMax) return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
if(leftMax) return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
throw new RuntimeException("INCORRECT CODE");
}
// helper function that splits on the middle
int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
if(loRow==hiRow){
int ans = arr[loCol][loRow];
int foundCol = loCol;
for(int col=loCol; col<=hiCol; col++){
if(arr[loRow][col] > ans){
ans = arr[loRow][col];
foundCol = col;
}
}
if(!correctPeak(arr, loRow, foundCol)){
System.out.println("THIS PEAK IS WRONG");
}
return ans;
}
boolean centralMax = true;
boolean upperMax = false;
boolean lowerMax = false;
int midRow = (loRow+hiRow)/2;
int max = arr[midRow][loCol];
for(int col=loCol; col<=hiCol; col++){
max = Math.max(max, arr[midRow][col]);
}
if(midRow-1>=0){
for(int col=loCol; col<=hiCol; col++){
if(arr[midRow-1][col] > max){
max = arr[midRow-1][col];
upperMax = true;
centralMax = false;
}
}
}
if(midRow+1<N){
for(int col=loCol; col<=hiCol; col++){
if(arr[midRow+1][col] > max){
max = arr[midRow+1][col];
lowerMax = true;
centralMax = false;
upperMax = false;
}
}
}
if(centralMax) return max;
if(lowerMax) return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
if(upperMax) return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
throw new RuntimeException("Incorrect code");
}
int[][] randomArray(){
int[][] arr = new int[N][M];
for(int i=0; i<N; i++)
for(int j=0; j<M; j++)
arr[i][j] = (int)(Math.random()*1000000000);
return arr;
}
boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
if(row-1>=0 && arr[row-1][col]>arr[row][col]) return false;
if(row+1<N && arr[row+1][col]>arr[row][col]) return false;
if(col-1>=0 && arr[row][col-1]>arr[row][col]) return false;
if(col+1<M && arr[row][col+1]>arr[row][col]) return false;
return true;
}
}
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/doraprojects.net/template/agent.layouts/content.php on line 54
2014-04-19 19:12:37