가까운 장소끼리 클러스터링하기 ! (feat. KMEANS)

SOUTH DARIA·2022년 2월 10일
5

클러스터링

목록 보기
1/1
post-thumbnail

KMEANS

KMEANS 클러스터링 - 출처 : https://hleecaster.com/ml-kmeans-clustering-concept/

“K“는 데이터 세트에서 찾을 것으로 예상되는 클러스터(그룹) 수를 말한다. “Means“는 각 데이터로부터 그 데이터가 속한 클러스터의 중심까지의 평균 거리를 의미한다. (이 값을 최소화하는 게 알고리즘의 목표가 된다.) K-Means에서는 이걸 구현하기 위해 반복적인(iterative) 접근을 취한다.

일단 K개의 임의의 중심점(centroid)을 배치하고 각 데이터들을 가장 가까운 중심점으로 할당한다. (일종의 군집을 형성한다.) 군집으로 지정된 데이터들을 기반으로 해당 군집의 중심점을 업데이트한다. 2번, 3번 단계를 그래서 수렴이 될 때까지, 즉 더이상 중심점이 업데이트 되지 않을 때까지 반복한다. 그림으로 보면 아래와 같다.

여기서 일단 k 값은 2다. 그래서 (b)에서 일단 중심점 2개를 아무 데나 찍고, (c)에서는 각 데이터들을 두 개 점 중 가까운 곳으로 할당한다. (d)에서는 그렇게 군집이 지정된 상태로 중심점을 업데이트 한다. 그리고 (e)에서는 업데이트 된 중심점과 각 데이터들의 거리를 구해서 군집을 다시 할당하는 거다.

이걸 계속 반복~

아무튼 이렇게 군집화를 해놓으면 새로운 데이터가 들어와도 그게 어떤 군집에 속할지 할당해줄 수 있게 되는 셈이다.

Kmeans 클러스터링 코드

전체 코드는 깃허브 링크에서 확인 가능합니다!

장소를 List 형태로 넘기고, 가까운 장소끼리 묶어준다.
아래 장소들은 선릉역과 당산역에 있는 장소 4개씩 선별했다.

선릉역 (마담밍, 위워크, 농민백암순대, 연더 그레이)
당산역 (오피스텔, 다이소, 삼성래미안, 당산서중학교)

Request

//리스트형태로 lat, long, locationName을 넘겨준다.
[
    {"lat": 37.503624, "lon": 127.050337, "locationName": "마담밍"},
    {"lat": 37.503341, "lon": 127.049840, "locationName": "위워크"},
    {"lat": 37.503829, "lon": 127.052980, "locationName": "농민백암순대"},
    {"lat": 37.503261, "lon": 127.050715, "locationName": "연더그레이"},
    {"lat": 37.532918, "lon": 126.900196, "locationName": "오피스텔"},
    {"lat": 37.534216, "lon": 126.900980, "locationName": "다이소"},
    {"lat": 37.532254, "lon": 126.903022, "locationName": "삼성래미안"},
    {"lat": 37.532065, "lon": 126.898589, "locationName": "당산서중학교"}
]

Response

//클러스터링 결과를 돌려준다.
[
    {
        "groupId": 0,
        "clusteringSize": 4,
        "clusteringLocationList": [
            {
                "geoPoint": {
                    "lat": 37.532918,
                    "lon": 126.900196
                },
                "locationName": "오피스텔"
            },
            {
                "geoPoint": {
                    "lat": 37.534216,
                    "lon": 126.90098
                },
                "locationName": "다이소"
            },
            {
                "geoPoint": {
                    "lat": 37.532254,
                    "lon": 126.903022
                },
                "locationName": "삼성래미안"
            },
            {
                "geoPoint": {
                    "lat": 37.532065,
                    "lon": 126.898589
                },
                "locationName": "당산서중학교"
            }
        ]
    },
    {
        "groupId": 1,
        "clusteringSize": 4,
        "clusteringLocationList": [
            {
                "geoPoint": {
                    "lat": 37.503624,
                    "lon": 127.050337
                },
                "locationName": "마담밍"
            },
            {
                "geoPoint": {
                    "lat": 37.503341,
                    "lon": 127.04984
                },
                "locationName": "위워크"
            },
            {
                "geoPoint": {
                    "lat": 37.503829,
                    "lon": 127.05298
                },
                "locationName": "농민백암순대"
            },
            {
                "geoPoint": {
                    "lat": 37.503261,
                    "lon": 127.050715
                },
                "locationName": "연더그레이"
            }
        ]
    }
]

KMEANS 라이브러리를 사용하기 위해서는 build.gradle에 해당 코드를 추가해줘야한다.
감사하게도 이 분이 개발을 해놔주셨다. ((다음에는 KMEANS 로직을 직접 짜는 것을 목표로 해야겠다))

dependencies {
    implementation 'com.github.haifengl:smile-core:2.5.3' // clustering 알고리즘
    }

핵심 서비스 로직

가까운 장소끼리 묶어주는 핵심 service 로직이다.


import com.daria.clustering.dto.ClusteringRequest;
import com.daria.clustering.dto.ClusteringResult;
import com.daria.clustering.dto.GeoPoint;
import com.daria.clustering.exception.CustomIllegalArgumentException;
import com.daria.clustering.kmeans.service.KmeansClusteringService;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import smile.clustering.KMeans;
import smile.clustering.PartitionClustering;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Service
public class KmeansClusteringServiceImpl implements KmeansClusteringService {

    @Override
    public List<ClusteringResult> createClustering(List<ClusteringRequest> clusteringRequestList) {
        if(clusteringRequestList.size() < 2){
            throw CustomIllegalArgumentException.message("현재 좌표는 2개 이상이어야합니다.");
        }
        
        List<GeoPoint> geoPointList =  new ArrayList<>();
        Map<GeoPoint, String> geoPointLocationMap = new HashMap<>();
        
        for(ClusteringRequest clusteringRequest : clusteringRequestList){
            clusteringRequest.validCheck(); //유효값 검사
            GeoPoint geoPoint = GeoPoint.of(clusteringRequest.getLat(), clusteringRequest.getLon());
            geoPointList.add(geoPoint);
            geoPointLocationMap.put(geoPoint, clusteringRequest.getLocationName());
        }
        
        double[][] geoPointArray = getGeoPointArray(geoPointList);

        int groupSize = clusteringRequestList.size()/4;
        KMeans clusters = PartitionClustering.run(20, () -> KMeans.fit(geoPointArray, groupSize < 2 ? 2 : groupSize));
        //kmeans 최소개수를 2개 이하로 할 때 에러가 나기 때문.

        Map<Integer, List<GeoPoint>> groupIdGeoPointMap = new HashMap<>();
        for (int i = 0, yLength = clusters.y.length; i < yLength; i++) {
            int groupId = clusters.y[i];
            GeoPoint geoPoint = geoPointList.get(i);
            if(geoPoint != null){
                groupIdGeoPointMap.computeIfAbsent(groupId, k -> new ArrayList()).add(geoPoint);
            }
        }

        List<ClusteringResult> clusteringResultList = new ArrayList<>();
        for(Map.Entry<Integer, List<GeoPoint>> entry : groupIdGeoPointMap.entrySet()){
            clusteringResultList.add(ClusteringResult.of(entry.getKey(), entry.getValue(), geoPointLocationMap));
        }
        return clusteringResultList;
    }


    private double[][] getGeoPointArray(List<GeoPoint> geoPointList) {
        if (CollectionUtils.isEmpty(geoPointList)) {
            return new double[0][];
        }
        double[][] geoPointArray = new double[geoPointList.size()][];
        int index = 0;
        for (GeoPoint geoPoint : geoPointList) {
            geoPointArray[index++] = new double[]{geoPoint.getLat(), geoPoint.getLon()};
        }
        return geoPointArray;
    }
}

마치며

아직은 request, response 구조 밖에 없는 로직 1탄이지만
지도를 붙여 좌표를 넘기고 마커를 그리는 작업 진행 중이다.
2탄으로.. 가지고 올 예정,,,

많이 하고 있는 물류 쪽 클러스터링의 모습을 갖추게 되는 그 날까지...
계속 해서 클러스터링 스터디를 진행해보려고한다...

할 수 있겠지.. 할 수 있을거야

profile
고양이와 함께 - 끄적끄적 개발하고 이씁니다 ~!

0개의 댓글