A*(A Star)

장승현·2023년 3월 27일
0

알고리즘

목록 보기
10/11
post-thumbnail

개요

A* 알고리즘은 최단 경로 찾기 알고리즘으로, 다익스트라 알고리즘과 유사하지만 시작 노드에서 도착 노드까지의 최단 경로를 찾는다는 점에서 차이가 있다. 휴리스틱 추정값(특정 노드에서 도착 노드까지 가는데 필요한 추정 비용)을 통해 알고리즘을 개선할 수 있으며, 그 방식에 따라 성능이 결정된다.

A* Algorithm

시작 노드에서 도착 노드로 가는데 필요한 최소 비용을 계산한다. 그 방식은 출발 노드에서 경유 노드로 가는 비용(G)과 경유 노드에서 도착 노드로 가는데 필요한 추정 비용(H)을 합한 비용(F)들간의 비교를 통해 이루어진다. 수식으로 표현하면 다음과 같다.

수행 과정

  1. 먼저 시작 노드를 선정하고 closed list에 추가한다.
  2. 시작 노드 인근의 노드를 opened list에 추가한다. 여기서 시작 노드는 인근 노드들의 부모가 된다.
  3. opened list에 있는 노드들 중 F 값이 가장 작은 노드를 선택하고 closed list에 추가한다.
  4. 새롭게 추가한 노드를 기준으로 인접 노드들을 opened list에 추가한다. 여기서 closed list에 있는 노드는 제외하고, 이미 opened list에 존재하는 노드들은 부모 노드에서 오는 비용과 새롭게 추가한 노드를 경유해서 오는 비용을 비교해 최소 비용을 갱신한다. 이때, 부모 또한 갱신된다.
  5. 3~4번 과정을 목표 노드에 도착할 때까지 반복한다. 만약 opened list가 비어있게 된다면, 길이 없는 경우로 길 찾기에 실패하게 된다.
  6. 최종적으로 만들어진 closed list에서 도착 노드의 부모부터 거슬러 올라가 최단 경로를 찾아낸다.

예시

위 그림에서 출발 노드는 초록색, 도착 노드는 빨간색, 장애물은 파란색을 의미한다. 이동에서 가로, 세로는 10의 비용이, 대각선은 14의 비용(sqrt(pow(10,2) + pow(10,2))이 든다.
여기서 출발 노드인 초록색을 closed list에 추가하고, 인접 노드들은 opened list에 추가한다. 인접 노드들의 비용은 다음과 같다.

F는 사각형의 왼쪽 위, G는 왼쪽 아래, H는 오른쪽 아래에 표시하였다. H는 맨해튼 거리로 구하였으며, 대각선 이동은 포함하지 않았다. 이 중 F가 가장 작은 노드는 출발 노드의 오른쪽 노드로, 이를 closed list에 추가한다. 이 노드를 기준으로 opened list를 갱신할 때, closed list에 존재하는 출발 노드는 제외한다. 또한, 오른쪽과 오른쪽 위아래 노드는 장애물이므로 무시한다. 남은 노드들은 이미 opened list에 있는 노드들로, 각 노드의 부모에서 오는 비용(G)과 현재 노드를 경유해서 가는 비용을 비교해 갱신한다. 여기서 갱신 가능한 노드는 없으므로 각 노드의 부모 변화 없이 F가 가장 작은 노드를 선택한다. 현재 노드 기준 위 아래의 F 값이 같기 때문에 어느 노드를 골라도 무방하다. 만약 위 노드를 선택하여 closed list에 추가하고 opened list를 갱신하면 다음과 같다.

아래와 왼쪽 아래는 closed list로 제외하며, 오른쪽과 오른쪽 위 아래는 장애물로 인해 지나갈 수 없으므로 제외한다. 왼쪽 노드는 기존 opened list에 있었지만, 10 : 14 + 10으로 최소 비용의 변동은 없다. 위와 왼쪽 위는 새로 opened list에 추가하며 비용을 계산한다.

다음 opened list에서 가장 작은 수는 F가 54인 출발 노드 오른쪽 아래이며, 위와 같이 나타낼 수 있다.

이러한 과정을 목표 노드에 도착할 때까지 반복하다 보면 다음과 같은 결과가 나오게 된다.

이때 반복 과정에서 같은 F 값을 가졌던 노드 중 어떤 노드를 먼저 선택하느냐에 따라 경로는 조금씩 변할 수 있으며, 그렇다 할지라도 비용은 모두 같을 것이다.

코드 구현

#include <iostream>
#include <vector>
#include <array>

std::array<std::array<int, 7>, 5> map = {{ //1은 장애물
    {0, 0, 0, 0, 0, 0, 0},
    {0, 0, 0, 1, 0, 0, 0},
    {0, 0, 0, 1, 0, 0, 0},
    {0, 0, 0, 1, 0, 0, 0},
    {0, 0, 0, 0, 0, 0, 0}
}};

std::array<std::array<int, 7>, 5> h = {{ //도착 노드로 가는 휴리스틱 추정값
    {70, 60, 50, 40, 30, 20, 30},
    {60, 50, 40, 30, 20, 10, 20},
    {50, 40, 30, 20, 10, 0, 10},
    {60, 50, 40, 30, 20, 10, 20},
    {70, 60, 50, 40, 30, 20, 30}
}};

std::array<std::array<int, 2>, 9> offset = {{ //인접 노드들로 가기 위한 offset값
    {-1,-1}, {-1,0}, {-1,1},
    {0,-1}, {0,0}, {0,1},
    {1,-1}, {1,0}, {1,1}
}};

std::array<int, 9> g = {{ //출발 노드에서 경유 노드로 가는 비용
    14, 10, 14,
    10, 0, 10,
    14, 10, 14
}};

class node
{
public:
    int f{0}, g{0}, h{0};
    std::array<int, 2> pos, parent;
};

class AStar
{
public:
    bool isIncluded(std::vector<node> list, node target){
        bool is_included{false};

        for (int i=0;i<list.size();i++){
            if (list[i].pos == target.pos){
                is_included = true;
            }
        }
        return is_included;
    }

    bool isObstacle(node target){
        bool is_obstacle{false};

        if (map[target.pos[0]][target.pos[1]] == 1){
            is_obstacle = true;
        }

        return is_obstacle;
    }

    bool isMapOut(node target){
        bool is_mapout{false};

        if (target.pos[0] < 0 || target.pos[1] < 0 || target.pos[0] > 4 || target.pos[1] > 6){
            is_mapout = true;
        }

        return is_mapout;
    }

    std::vector<int> cantPass(node curr_node){
        std::vector<int> cant_pass;
        int array[4] = {1,3,5,7};

        for (int i=0;i<4;i++){
            node temp_node;
            temp_node.pos = {curr_node.pos[0] + offset[array[i]][0], curr_node.pos[1] + offset[array[i]][1]};
            bool is_obstacle = isObstacle(temp_node);

            if (is_obstacle){
                if (array[i] == 1) {
                    cant_pass.emplace_back(0);
                    cant_pass.emplace_back(2);
                }
                else if (array[i] == 3) {
                    cant_pass.emplace_back(0);
                    cant_pass.emplace_back(6);
                }
                else if (array[i] == 5) {
                    cant_pass.emplace_back(2);
                    cant_pass.emplace_back(8);
                }
                else {
                    cant_pass.emplace_back(6);
                    cant_pass.emplace_back(8);
                }
            }
        }
        return cant_pass;
    }

    std::vector<node> renewG(std::vector<node> opened_list, node temp_node){
        for (int i=0;i<opened_list.size();i++){
            if (opened_list[i].pos == temp_node.pos){
                if (opened_list[i].g > temp_node.g){
                    opened_list[i] = temp_node;
                }
            }
        }
        return opened_list;
    } 

    std::vector<node> getOpenedList(node curr_node, std::vector<node> opened_list, std::vector<node> closed_list, std::vector<int> cant_pass){
        for (int i=0;i<9;i++){
            if (i != 4){ //4번 인덱스는 자기 자신
                node temp_node;
                temp_node.pos = {curr_node.pos[0] + offset[i][0], curr_node.pos[1] + offset[i][1]};
                temp_node.g = curr_node.g + g[i];
                temp_node.h = h[temp_node.pos[0]][temp_node.pos[1]];
                temp_node.f = temp_node.g + temp_node.h;
                temp_node.parent = {curr_node.pos[0], curr_node.pos[1]};
                
                bool open_included = isIncluded(opened_list, temp_node);
                bool close_included = isIncluded(closed_list, temp_node);
                bool is_obstacle = isObstacle(temp_node);
                bool is_mapout = isMapOut(temp_node);
                bool is_cant_pass{false};
                for (int j=0;j<cant_pass.size();j++){
                    if (i == cant_pass[j]){
                        is_cant_pass = true;
                    }
                }

                if (open_included){
                    opened_list = renewG(opened_list, temp_node);
                }

                if (!(open_included || close_included || is_obstacle || is_mapout || is_cant_pass)){
                    opened_list.emplace_back(temp_node);
                }
            }
        }
        return opened_list;
    }

    int getCurrNodeIndex(std::vector<node> opened_list, std::vector<node> closed_list){
        int min_f = 9999, min_index;

        for (int i=0;i<opened_list.size();i++){
            bool is_included = isIncluded(closed_list, opened_list[i]);
            if (min_f > opened_list[i].f && !is_included){
                min_f = opened_list[i].f;
                min_index = i;
            }
        }
        return min_index;
    }

    std::vector<std::array<int, 2>> getShortPath(std::vector<node> closed_list, node start){
        std::vector<std::array<int, 2>> short_path;
        std::array<int, 2> parent;

        parent = closed_list[closed_list.size()-1].parent;
        short_path.insert(short_path.begin(), closed_list[closed_list.size()-1].pos);
        short_path.insert(short_path.begin(), parent);
        while (parent != start.pos){
            for (int i=0;i<closed_list.size();i++){
                if (closed_list[i].pos == parent){
                    parent = closed_list[i].parent;
                    short_path.insert(short_path.begin(), parent);
                }
            }
        }
        return short_path;
    }
};

int main(){
    std::vector<node> opened_list, closed_list;
    node start, end, curr_node;
    AStar astar;
    int min_index;
    std::vector<std::array<int, 2>> short_path;
    std::vector<int> cant_pass;

    start.pos = {2, 1};

    end.pos = {2, 5};

    curr_node = start;
    closed_list.emplace_back(start);

    while (curr_node.pos != end.pos){
        cant_pass = astar.cantPass(curr_node);
        opened_list = astar.getOpenedList(curr_node, opened_list, closed_list, cant_pass);
        min_index = astar.getCurrNodeIndex(opened_list, closed_list);
        curr_node = opened_list[min_index];
        closed_list.emplace_back(curr_node);
        opened_list.erase(opened_list.begin() + min_index);

        if (opened_list.size() == 0) {
            std::cout<<"Fail"<<std::endl;
            break;
        }
    }
    
    short_path = astar.getShortPath(closed_list, start);
    for (int i=0;i<short_path.size();i++){
        std::cout<<short_path[i][0]<<' '<<short_path[i][1]<<std::endl;
    }

    return 0;
}

Reference

http://cozycoz.egloos.com/m/9748811

profile
늦더라도 끝이 강한 내가 되자

0개의 댓글