[CUDA 공부] 1-1. elementwise.cu

NJ·2025년 8월 18일
0

CUDA공부

목록 보기
1/4

CUDA 공부를 시작한다.
우선 보고 있는 자료는 LeetCUDA (https://github.com/xlite-dev/LeetCUDA) 와 <CUDA 기반 GPU 병렬 처리 프로그래밍>이라는 책.

책은 이제 차근차근 읽어나갈 거고, 우선 LeetCUDA의 kernel 코드를 차근차근 따라가며, 모르는 부분은 GPT에게 물어가며 공부하는 게시글이 될 것 같다.

우선 첫 번 째는!!
https://github.com/xlite-dev/LeetCUDA/blob/main/kernels/elementwise/elementwise.cu <- 이 코드!

아예 아무 것도 몰라서 진짜 코드 한 줄 한 줄 따라가며 GPT에게 계속 물어봄.
나중을 (나중의 나를) 위하여 뭔가를 남겨보기로 함!


헤더 파일

#include <algorithm>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>
#include <torch/extension.h>
#include <torch/types.h>
#include <vector>

우선 헤더 파일들을 include 하는 부분은 그냥 그렇구나... 하고 넘어감.
헤더까지 살피기엔 아직 너무 초짜임! (그도 그럴 것이 원래 stdio.h, stdlib.h는 그냥 이게 있으니 된다! 하고 넘어가는 거자낭?)

define

WARP

#define WARP_SIZE 32

WARP_SIZE를 32로 지정했는데 WARP가 뭘까...
WARP는 GPU 병렬 처리의 가장 기본 단위라고 한다.

Warp는 32개의 쓰레드로 구성된 실행 단위입니다.
CUDA GPU는 쓰레드를 하나씩 실행하지 않고, 한 번에 32개 쓰레드를 묶어서 동시에 실행합니다. 이 32개의 쓰레드 묶음을 warp라고 합니다.

🧠 왜 Warp 단위로 실행할까?

📌 이유: SIMT (Single Instruction, Multiple Threads) 구조 때문

GPU는 일반 CPU와 달리 같은 명령어를 여러 데이터에 대해 동시에 실행하는 방식으로 효율을 극대화합니다.

  • CPU: 각 코어가 각기 다른 일을 할 수 있음
  • GPU: 여러 쓰레드가 같은 명령어를 공유하며 동시에 실행됨 → 성능 향상

CUDA에서는 각 warp 내의 32개 쓰레드가 동시에 같은 instruction을 실행합니다.
이를 SIMT라고 하며, warp가 실행 단위가 됩니다.

그렇군! CUDA의 기본 실행 단위! 32개의 쓰레드로 구성된 실행 단위이군!

reinterpret_cast

#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2 *>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

타입 캐스팅 define~
이 매크로들은 데이터를 벡터 타입으로 reinterpret하여 메모리 접근 성능을 향상시키기 위한 것!

✅ 이 매크로들을 쓰는 이유는?

메모리 병합(coalescing) 최적화: GPU에서는 연속된 쓰레드들이 연속된 주소에 접근할 때 한 번의 메모리 트랜잭션으로 처리하면 빠릅니다.

  • 128비트 단위 load/store: float4, half2 등을 이용하면 메모리 IO 성능이 상승함

reinterpret_cast<new_type>(value) 이 문법 자체가 타입 정보를 무시하고 메모리 주소만 강제로 해석하는 용도로 쓰이며, 보통 매우 저수준의 메모리 접근이나 하드웨어 최적화에서 사용된다고 함. reinterpret_cast는 C++ 언어 자체에 포함된 키워드라고 함.

이 부분 관련해서 GPT에 물어보면 메모리 접근 관련 정보/지식들을 얻을 수 있다.

여기에 관해서 티키타카 티키타카 해본 후 내가 이해한 바로는: 한 번에 더 많이 읽어서 메모리 I/O를 줄이기! <- 이것이 목적인듯 하다.

  • 일반적으로 128비트(16B)에 최적화 되어 있음.

어쨌든 이렇게 벡터화해서 읽는 게 유리함: 데이터를 여러 개 단위로 묶어 한 번에 처리하거나, 한 명령어로 여러 데이터를 동시에 처리하는 최적화 기법!
즉, float value 하나 하나... 처리하는 게 아니고, 128bit까지 마! 가져온나!! 팍팍! 한 번에 128bit 처리한닷!! 요런 느낌.

elementwise add kernel (f32_kernel)

// FP32
// ElementWise Add grid(N/256),
// block(256) a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32_kernel(float *a, float *b, float *c,
                                           int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N)
    c[idx] = a[idx] + b[idx];
}
  • __global__ CUDA 커널 함수임을 나타내는 키워드
  • int N 배열 길이

너무나도 당연하게 a, b를 더해서 c에 저장하는 함수임.

  • int idx blockIdx.x * blockDim.x + threadIdx.x;
  • 위의 코드는 현재 쓰레드가 처리할 인덱스를 계산하는 부분.
  • CUDA는 thread를 block 단위로 구성하고, block은 grid 단위로 구성되어 있다고 함.
  • 그래서 전체 글로벌 인덱스를 계산하려면 위처럼 계산해야 한다고..!!
    • blockIdx.x: 현재 블록의 x축 인덱스
    • blockDim.x: 블록당 쓰레드 수
    • threadIdx.x: 현재 블록 내에서의 쓰레드 인덱스
    • idx: 전체 배열 상에서 이 쓰레드가 담당할 위치

음... 이렇게 하나하나 위치를 계산해줘야 하는군.

너무나도 오랜만에 맡는 C의 향기.
그동안 너무 쉽게 코딩해왔나봐..!!

당연히 if 문은 배열 경계 보호하는 if문.

elementwise add kernel (f32x4_kernel)

// ElementWise Add + Vec4
// grid(N/256), block(256/4)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32x4_kernel(float *a, float *b, float *c,
                                             int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    float4 reg_c;
    reg_c.x = reg_a.x + reg_b.x;
    reg_c.y = reg_a.y + reg_b.y;
    reg_c.z = reg_a.z + reg_b.z;
    reg_c.w = reg_a.w + reg_b.w;
    FLOAT4(c[idx]) = reg_c;
  }
}

위와 조금 달라진 건 4개를 한 꺼번에 처리하고자 하는 이 욕망!
그래서 idx를 계산할 때도 4개 단위로 계산해줌.

위와 조금 다른 건, 로딩 부분!
4개를 한 번에 로딩해준다! a 배열에 있던 float 4개 일루와! b 배열에 있던 float 4개 일루와!
그리고 .x, .y, .z, .w로 한땀 한땀 더해주고. 마지막도 FLOAT4로 매크로 처리해서 이 4개의 값을 한.꺼.번.에 저장해줌.

여기서 이 xyzw는 어디서 나온거냐 하면: float4라는 것은 CUDA에서 제공하는 벡터 타입.
내부적으로 다음과 같이 정의되어 있다

struct __device_builtin__ __builtin_align__(16) float4 { float x, y, z, w; };

아하! 이래서 xyzw로 접근할 수 있구나!

elementwise add kernel (f16x8_kernel)

__global__ void elementwise_add_f16x8_kernel(half *a, half *b, half *c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  half2 reg_a_0 = HALF2(a[idx + 0]);
  half2 reg_a_1 = HALF2(a[idx + 2]);
  half2 reg_a_2 = HALF2(a[idx + 4]);
  half2 reg_a_3 = HALF2(a[idx + 6]);
  half2 reg_b_0 = HALF2(b[idx + 0]);
  half2 reg_b_1 = HALF2(b[idx + 2]);
  half2 reg_b_2 = HALF2(b[idx + 4]);
  half2 reg_b_3 = HALF2(b[idx + 6]);
  half2 reg_c_0, reg_c_1, reg_c_2, reg_c_3;
  reg_c_0.x = __hadd(reg_a_0.x, reg_b_0.x);
  reg_c_0.y = __hadd(reg_a_0.y, reg_b_0.y);
  reg_c_1.x = __hadd(reg_a_1.x, reg_b_1.x);
  reg_c_1.y = __hadd(reg_a_1.y, reg_b_1.y);
  reg_c_2.x = __hadd(reg_a_2.x, reg_b_2.x);
  reg_c_2.y = __hadd(reg_a_2.y, reg_b_2.y);
  reg_c_3.x = __hadd(reg_a_3.x, reg_b_3.x);
  reg_c_3.y = __hadd(reg_a_3.y, reg_b_3.y);
  if ((idx + 0) < N) {
    HALF2(c[idx + 0]) = reg_c_0;
  }
  if ((idx + 2) < N) {
    HALF2(c[idx + 2]) = reg_c_1;
  }
  if ((idx + 4) < N) {
    HALF2(c[idx + 4]) = reg_c_2;
  }
  if ((idx + 6) < N) {
    HALF2(c[idx + 6]) = reg_c_3;
  }
}

자 바로 half 드간다.
__hadd 이거는 CUDA에서 제공하는 half 타입 전용 덧셈 함수. 위 코드들과 같이 그냥 + 연산자로는 덧셈이 안 되는 듯 하다. half용 덧셈 연산자(function)~

HALF2는 half 두 개라서 4바이트(32bit)가 된다.
.x, .y로 더하는 부분도 위에서 쭈욱 내려왔다면 대충 이해갈 거라 생각.
여기서 궁금한 건 왜 128bit 씩 처리하지 않는가?
내 생각엔 아마도 다음 커널과 비교하기 위해 있는 것 같음.

elementwise add kernel (f16x8_pack_kernel)

__global__ void elementwise_add_f16x8_pack_kernel(half *a, half *b, half *c,
                                                  int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  
  // temporary register(memory), .local space in ptx, addressable
  half pack_a[8], pack_b[8], pack_c[8]; // 8x16 bits=128 bits.
  
  // reinterpret as float4 and load 128 bits in 1 memory issue.
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits

#pragma unroll
  for (int i = 0; i < 8; i += 2) {
    // __hadd2 for half2 x 4
    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
  }
  // reinterpret as float4 and store 128 bits in 1 memory issue.
  if ((idx + 7) < N) {
    LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]);
  }
}

팩!!! 패킹!!!
LDST128BITS 라는 아이가 나와버렸다! 이 아이는... 128 bit씩 접근하는 아이!

🧠 핵심 목적

128비트 패킹된 half[8] 배열을 사용해

  • GPU memory I/O 횟수 줄이고 (1번에 128비트)
  • half2 기반 병렬 연산(__hadd2)으로 연산량도 최적화

인덱스 계산은: half 8개씩 처리하겠다! (=128 bits)

그리고, 로컬 레지스터 공간에 half[8]씩 메모리를 잡고 이 배열을 float4로 interprete해서 128 비트 단위로 load/store.
이렇게 버퍼를 만드는 이유를 잘 모르겠어서 GPT에게 계속 꼬치꼬치 물어봄.
메모리 I/O 최소화 + 벡터화 연산을 깔끔하게 하기 위해서라고 함. LDST128BITS를 통해 일괄 load/store가 가능.
만약 로컬 버퍼 없이 HALF2(a[idx]), HALF2(a[idx+2])… 처럼 바로 글로벌에서 뽑아 쓰면 반복적으로 글로벌 메모리 접근이 발생할 수 있다고 함.
그러니까, 내가 이해하기로는 로컬 메모리로 들고 와야 더 빠르게 연산이 가능하니까! (마치 L1, L2 cache 처럼. 더 가까운 메모리로 이주시키기.)
음. 이건 이해할 수 있어.
이렇게 하지 않으면 바로 위의 코드(f16x8_kernel)와 같은 형태가 되는 거겠지.
오키. 우선 넘어가자.
a, b 에 있는 아이들을 128 bits로 가져와서 pack_a, pack_b에 담아주고.

#pragma unroll (= 이 루프 짧으니까 컴파일 타임에 미리 다 풀어서 반복문 없이 실행해줘)

처음 보는 아이.
이 아이는 CUDA에서 루프 전개(loop unrolling)를 컴파일러에 지시하는 명령!

#pragma unroll
for (int i = 0; i < 4; i++) {
    sum += arr[i];
}

은 컴파일 후에 아래처럼 변환됨

sum += arr[0];
sum += arr[1];
sum += arr[2];
sum += arr[3];

🔹 CUDA에서 중요한 이유!

  • GPU는 SIMT 구조라서 루프 제어보다는 계속된 연산 명령이 pipeline에 쭉 이어지는 게 유리함
  • 특히 여기처럼 4번만 도는 작은 루프는 unroll하면 실행 흐름이 단순해짐
  • 레지스터 사용 패턴과 메모리 접근을 컴파일러가 더 잘 최적화 가능

🔹 여기서 적용된 이유

#pragma unroll
for (int i = 0; i < 8; i += 2) {
    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
}

위의 코드는 컴파일을 하면 아래처럼 바뀜

HALF2(pack_c[0]) = __hadd2(HALF2(pack_a[0]), HALF2(pack_b[0]));
HALF2(pack_c[2]) = __hadd2(HALF2(pack_a[2]), HALF2(pack_b[2]));
HALF2(pack_c[4]) = __hadd2(HALF2(pack_a[4]), HALF2(pack_b[4]));
HALF2(pack_c[6]) = __hadd2(HALF2(pack_a[6]), HALF2(pack_b[6]));
  • 루프 분기 오버헤드가 사라지고, 명령이 연속적으로 나열되니까 GPU warp 실행 효율이 올라감

휴 이렇게 간단한 elementwise add하는 kernel code를 살펴보았다.
되게 간단한데 모르는 게 너무 많아서 하나하나 물어보고 이해해가면서 하느라 오래 걸렸음!

이렇게 커널 코드 이후에 매크로 코드가 나온다.
PyTorch C++/CUDA 확장 바인딩과 커널 런처(wrapper) 생성을 매크로로 자동화한 부분!

간단 매크로

#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func)                                   \
  m.def(STRINGFY(func), &func, STRINGFY(func));

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8)
  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8_pack)
}

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                                   \
  if (((T).options().dtype() != (th_type))) {                                  \
    std::cout << "Tensor Info:" << (T).options() << std::endl;                 \
    throw std::runtime_error("values must be " #th_type);                      \
  }
  
  • 토큰을 문자열로 바꾸는 STRINGFY
    • 예) STRINGFY(elementwise_add_f32) → "elementwise_add_f32"
  • TORCH_BINDING_COMMON_EXTENSION(func)은 pybind11로 C++ 함수 func를 그 이름 그대로 python에 등록
    • 예) TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32) →
      m.def("elementwise_add_f32", &elementwise_add_f32, "elementwise_add_f32");
  • pybind11 모듈 정의
    • python에서 lib.elementwise_add_f16x8_pack(a, b, c) 처럼 직접 호출 가능.
    • 빌드 시 지정한 이름으로 python 모듈이 만들어지고, 위에서 생성된 6개의 래퍼 함수가 동일한 이름으로 python에 노출됨.
  • dtype 체크
    • 텐서 T의 dtype이 기대 타입과 다르면 에러를 던짐.

핵심: 커널 런처(wrapper) 생성 매크로

#define TORCH_BINDING_ELEM_ADD(packed_type, th_type, element_type, n_elements) \
  void elementwise_add_##packed_type(torch::Tensor a, torch::Tensor b,         \
                                     torch::Tensor c) {                        \
    CHECK_TORCH_TENSOR_DTYPE(a, (th_type))                                     \
    CHECK_TORCH_TENSOR_DTYPE(b, (th_type))                                     \
    CHECK_TORCH_TENSOR_DTYPE(c, (th_type))                                     \
    const int ndim = a.dim();                                                  \
    if (ndim != 2) {                                                           \
      int N = 1;                                                               \
      for (int i = 0; i < ndim; ++i) {                                         \
        N *= a.size(i);                                                        \
      }                                                                        \
      dim3 block(256 / (n_elements));                                          \
      dim3 grid((N + 256 - 1) / 256);                                          \
      elementwise_add_##packed_type##_kernel<<<grid, block>>>(                 \
          reinterpret_cast<element_type *>(a.data_ptr()),                      \
          reinterpret_cast<element_type *>(b.data_ptr()),                      \
          reinterpret_cast<element_type *>(c.data_ptr()), N);                  \
    } else {                                                                   \
      const int S = a.size(0);                                                 \
      const int K = a.size(1);                                                 \
      const int N = S * K;                                                     \
      if ((K / (n_elements)) <= 1024) {                                        \
        dim3 block(K / (n_elements));                                          \
        dim3 grid(S);                                                          \
        elementwise_add_##packed_type##_kernel<<<grid, block>>>(               \
            reinterpret_cast<element_type *>(a.data_ptr()),                    \
            reinterpret_cast<element_type *>(b.data_ptr()),                    \
            reinterpret_cast<element_type *>(c.data_ptr()), N);                \
      } else {                                                                 \
        int N = 1;                                                             \
        for (int i = 0; i < ndim; ++i) {                                       \
          N *= a.size(i);                                                      \
        }                                                                      \
        dim3 block(256 / (n_elements));                                        \
        dim3 grid((N + 256 - 1) / 256);                                        \
        elementwise_add_##packed_type##_kernel<<<grid, block>>>(               \
            reinterpret_cast<element_type *>(a.data_ptr()),                    \
            reinterpret_cast<element_type *>(b.data_ptr()),                    \
            reinterpret_cast<element_type *>(c.data_ptr()), N);                \
      }                                                                        \
    }                                                                          \
  }

TORCH_BINDING_ELEM_ADD(f32, torch::kFloat32, float, 1)
TORCH_BINDING_ELEM_ADD(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_ELEM_ADD(f16, torch::kHalf, half, 1)
TORCH_BINDING_ELEM_ADD(f16x2, torch::kHalf, half, 2)
TORCH_BINDING_ELEM_ADD(f16x8, torch::kHalf, half, 8)
TORCH_BINDING_ELEM_ADD(f16x8_pack, torch::kHalf, half, 8)

TORCH_BINDING_ELEM_ADD(packed_type, th_type, element_type, n_elements)

  • 이 매크로는 PyTorch에서 호출할 래퍼 함수를 만든다. 내부에서:
    1. 입력/출력 텐서의 dtype 검사
    2. 텐서 차원에 따라 그리드/블록 계산
    3. 원시 포인터로 캐스팅 후 해당 CUDA 커널 호출

만들어지는 함수 시그니처

void elementwise_add_<packed_type>(torch::Tensor a, torch::Tensor b, torch::Tensor c)

예) packed_type=f16x8_pack → elementwise_add_f16x8_pack(...)

내부 로직 요약

  • ndim = a.dim()을 보고 분기:

    • 비(非) 2D: 전체 요소 수 N = ∏ a.size(i)로 평탄화해서 처리

      block = 256 / n_elements
      grid  = (N + 256 - 1) / 256
    • 2D (S×K): 먼저 “행 단위” 실행을 시도

      • 한 쓰레드가 n_elements개(예: float4=4, half2=2, half8=8)를 처리하므로, block.x = K / n_elements를 잡습니다.
      • 이 값이 1024 이하(스레드 최대)면 grid = S, block = K / n_elements로 설정 → 각 블록이 한 행을 처리(메모리 접근 정합성이 좋아 성능 유리).
      • 그렇지 않으면 다시 평탄화 일반 경로(위의 256 스레드 기준)로 폴백.
  • 커널 호출:

elementwise_add_<packed_type>_kernel<<<grid, block>>>(
    reinterpret_cast<element_type*>(a.data_ptr()),
    reinterpret_cast<element_type*>(b.data_ptr()),
    reinterpret_cast<element_type*>(c.data_ptr()),
    N);
  • element_type은 커널이 기대하는 원소 타입: float/half 등
  • n_elements는 벡터 패킹 폭: 1, 2, 4, 8 (float4→4, half2→2, half8→8 같은 개념)

요컨대, 같은 커널 패턴을 패킹 폭에 따라 재사용하고, 텐서가 2D이면 행 단위 실행으로 최적화를 먼저 시도하는 런처를 매크로 한 방으로 생성.

  • 그리고 TORCH_BINDING_ELEM_ADD를 불러서 6개의 래퍼 함수 생성
    • 예) elementwise_add_f32x4(a,b,c)는 내부에서 elementwise_add_f32x4_kernel을 호출하고, block/grid는 n_elements=4 기준으로 설정.

마지막 커널 런처 생성 매크로 부분은 열심히 이해 안 하고 대충 건너뜀... 나중에 다시 꼼꼼히 살펴보자!
휴, 우선 이렇게 아주아주 간단 elementwise add 커널 코드 읽어봄!
그 다음 이 코드를 부르는 python 코드도 살펴보고 실행 결과도 보려고 한다!
(흑흑.. ㅜㅜ 빡세...)

그 다음엔 이제 하나하나.. histogram도 보고, 책도 읽고 해야 하는데.
갈 길이 멀게 느껴지지만 화이팅 하자!
우선 책 한 번 읽고, histogram까지도 좀 파악하고 나면 앞으로 공부 어떻게 해야할지 감이 좀 잡혔으면 좋겠음! 공부 스케줄도!

profile
Studying NLP

0개의 댓글