numpy API들의 axis, keepdims 인자

김상윤·2023년 3월 1일
0

행렬 ndarray에서의 axis, keepdims

import numpy as np

a = np.arange(5) [0 1 2 3 4]
print("np.sum: ", np.sum(a)) np.sum: 10
print("ndarray.sum: ",a.sum()) ndarray.sum : 10

Axis

어떤 축을 기준으로 연산을 진행할 것인지 나타낸다.
sum_ = a.sum(axis=0)

흔한 오해 : axis 0을 기준으로 더했으면 axis 0이 남겠네?

아니다! 오히려 axis 0의 차원이 사라진다. 그 이유는 axis 0에 있는 원소들끼리 연산 때문에 남아있지 않는 것이다.

(3,4) axis = 0 -> (4,) axis = 1 -> (3,)

import numpy as np

a = np.arange(12).reshape((3,-1)) ->(4,3)
sum_class = np.sum(a, axis =0) -> (4,) 과목별 합!
sum_student = np.sum(a, axis =1) -> (3,) 학생의 점수합

sum_class는 broadcasting이 가능하지만 sum_student는 broadcasting이 불가하다. Broadcasting은 높은 차원의 Matrix의 안쪽 차원을 기준으로 이루어지기 때문에 (4,)만 가능하다. sum_student를 broadcasting 하기 위해서는 (3,) 을 (3,1)로 바꿔주어야 한다. (3,1)은 (4,3)과 차원이 같고 1이라는 차원값이 있기 때문에 broadcasting 가능함.

차원이 줄어들지 않기를 원한다면?
Keepdims 옵션을 사용해보자
sum_class = np.sum(a, axis = 0, keepdims = True)

3차원 행렬에 대한 axis, keepdims

profile
AI 대학원 지망생

0개의 댓글