CNN ๐ŸŽž๏ธ

์„œ์€์„œยท2023๋…„ 8์›” 15์ผ
0

PyTorch

๋ชฉ๋ก ๋ณด๊ธฐ
3/5
post-thumbnail

CNN

ํ•ฉ์„ฑ๊ณฑ ์‹ ๊ฒฝ๋ง์€ ์ด๋ฏธ์ง€ ์ „์ฒด๋ฅผ ํ•œ ๋ฒˆ์— ๊ณ„์‚ฐํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹Œ ์ด๋ฏธ์ง€์˜ ๊ตญ์†Œ์  ๋ถ€๋ถ„์„ ๊ณ„์‚ฐํ•จ์œผ๋กœ์จ ์‹œ๊ฐ„๊ณผ ์ž์›์„ ์ ˆ์•ฝํ•˜์—ฌ ์ด๋ฏธ์ง€์˜ ์„ธ๋ฐ€ํ•œ ๋ถ€๋ถ„๊นŒ์ง€ ๋ถ„์„ํ•  ์ˆ˜ ์žˆ๋Š” ์‹ ๊ฒฝ๋ง์ด๋‹ค.

ํ•ฉ์„ฑ๊ณฑ์ธต์˜ ํ•„์š”์„ฑ

์ด๋ฏธ์ง€ ๋ถ„์„์€ ์•„๋ž˜์˜ ๊ทธ๋ฆผ ์™ผ์ชฝ๊ณผ ๊ฐ™์€ 3x3 ๋ฐฐ์—ด์„ ์˜ค๋ฅธ์ชฝ๊ณผ ๊ฐ™์ด ํŽผ์ณ์„œ ๊ฐ ํ”ฝ์…€์— ๊ฐ€์ค‘์น˜๋ฅผ ๊ณฑํ•ด ์€๋‹‰์ธต์œผ๋กœ ์ „๋‹ฌํ•˜๊ฒŒ ๋œ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๊ทธ๋ฆผ์—์„œ ๋ณด์ด๋Š” ๊ฒƒ์ฒ˜๋Ÿผ ์ด๋ฏธ์ง€๋ฅผ ํŽผ์ณ์„œ ๋ถ„์„ํ•˜๊ฒŒ ๋˜๋ฉด ๋ฐ์ดํ„ฐ์˜ ๊ณต๊ฐ„์  ๊ตฌ์กฐ๋ฅผ ๋ฌด์‹œํ•˜๊ฒŒ ๋œ๋‹ค. ์ด๋ฅผ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด ํ•ฉ์„ฑ๊ณฑ์ธต์ด ๋„์ž…๋œ๋‹ค.

CNN ๊ตฌ์กฐ

ํ•ฉ์„ฑ๊ณฑ ์‹ ๊ฒฝ๋ง(Convolutional Neural Network, CNN ๋˜๋Š” ConvNet)์€ ์Œ์„ฑ ์ธ์‹์ด๋‚˜ ์ด๋ฏธ์ง€/์˜์ƒ ์ธ์‹์—์„œ ์ฃผ๋กœ ์‚ฌ์šฉ๋˜๋Š” ์‹ ๊ฒฝ๋ง์ด๋‹ค. ๋‹ค์ฐจ์› ๋ฐฐ์—ด ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋„๋ก ๊ตฌ์„ฑ๋˜์–ด ์ปฌ๋Ÿฌ ์ด๋ฏธ์ง€ ๊ฐ™์€ ๋‹ค์ฐจ์› ๋ฐฐ์—ด ์ฒ˜๋ฆฌ์— ํŠนํ™”๋˜์–ด ์žˆ์œผ๋ฉฐ ๋‹ค์„ฏ๊ฐœ์˜ ๊ณ„์ธต์œผ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.

CNN์€ ํ•ฉ์„ฑ๊ณฑ์ธต๊ณผ ํ’€๋ง์ธต์„ ๊ฑฐ์น˜๋ฉด์„œ ์ž…๋ ฅ ์ด๋ฏธ์ง€์˜ ์ฃผ์š” ํŠน์„ฑ ๋ฒกํ„ฐ(feature vector)๋ฅผ ์ถ”์ถœํ•œ๋‹ค. ์ฃผ์š” ํŠน์„ฑ ๋ฒกํ„ฐ๋“ค์€ ์™„์ „์—ฐ๊ฒฐ์ธต์„ ๊ฑฐ์น˜๋ฉด์„œ 1์ฐจ์› ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜๋˜๋ฉฐ, ์ถœ๋ ฅ์ธต์—์„œ ํ™œ์„ฑํ™” ํ•จ์ˆ˜๋ฅผ ๊ฑฐ์ณ ์ตœ์ข… ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•œ๋‹ค.

1๏ธโƒฃ ์ž…๋ ฅ์ธต

๐Ÿ‘‰๐Ÿป ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๊ฐ€ ์ตœ์ดˆ๋กœ ๊ฑฐ์น˜๊ฒŒ ๋˜๋Š” ๊ณ„์ธต

์ด๋ฏธ์ง€๋Š” ๋†’์ด(height), ๋„ˆ๋น„(width), ์ฑ„๋„(channel)์˜ ๊ฐ’์„ ๊ฐ–๋Š” 3์ฐจ์› ๋ฐ์ดํ„ฐ์ด๋‹ค.
์ด๋ฏธ์ง€ ์ฑ„๋„์€ ํ‘๋ฐฑ์ด๋ฉด 1, ์ปฌ๋Ÿฌ์ด๋ฉด 3 ๊ฐ’์„ ๊ฐ–๋Š”๋‹ค.

2๏ธโƒฃ ํ•ฉ์„ฑ๊ณฑ์ธต

๐Ÿ‘‰๐Ÿป ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์—์„œ ํŠน์„ฑ์„ ์ถ”์ถœํ•˜๋Š” ์—ญํ• ์„ ์ˆ˜ํ–‰

์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ํŠน์„ฑ์„ ๊ฐ์ง€ํ•˜๊ธฐ ์œ„ํ•ด ์ปค๋„(kernel)์ด๋‚˜ ํ•„ํ„ฐ๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.

์ปค๋„/ํ•„ํ„ฐ๋Š” ์ด๋ฏธ์ง€์˜ ๋ชจ๋“  ์˜์—ญ์„ ํ›‘์œผ๋ฉฐ ํŠน์„ฑ์„ ์ถ”์ถœํ•˜๋ฉฐ ๋‚˜์˜จ ๊ฒฐ๊ณผ๋ฌผ์„ ํŠน์„ฑ ๋งต(feature map)์ด๋ผ๊ณ  ํ•œ๋‹ค.
โ–ถ๏ธŽ ์ผ๋ฐ˜์ ์œผ๋กœ 3x3,5x5 ํฌ๊ธฐ์˜ ์ปค๋„์„ ์ ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ผ๋ฐ˜์ ์ด๋ฉฐ, ์ŠคํŠธ๋ผ์ด๋“œ(stride)๋ผ๋Š” ์ง€์ •๋œ ๊ฐ„๊ฒฉ์œผ๋กœ ์ด๋™ํ•œ๋‹ค.

์˜ˆ์‹œ 1 (์ฑ„๋„ = 1)

1๋‹จ๊ณ„)์ž…๋ ฅ ์ด๋ฏธ์ง€์— 3x3 ํ•„ํ„ฐ ์ ์šฉ

2๋‹จ๊ณ„) ์ŠคํŠธ๋ผ์ด๋“œ( = 1 )๋งŒํผ ์ด๋™

3๋‹จ๊ณ„) ๋‘๋ฒˆ์งธ ์ด๋™

๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„) ๋งˆ์ง€๋ง‰ ์ด๋™์œผ๋กœ feature map ๋„์ถœ


์˜ˆ์‹œ 2 (์ฑ„๋„ = 3)

์ด์ „ ์˜ˆ์‹œ์™€๋Š” ๋‹ค๋ฅด๊ฒŒ RGB ๊ฐ๊ฐ์— ์„œ๋กœ ๋‹ค๋ฅธ ๊ฐ€์ค‘์น˜๋กœ ํ•ฉ์„ฑ๊ณฑ์„ ์ ์šฉํ•œ ํ›„ ๊ฒฐ๊ณผ๋ฅผ ๋”ํ•ด์ค€๋‹ค.


์˜ˆ์‹œ 3 (ํ•„ํ„ฐ๊ฐ€ ์—ฌ๋Ÿฌ๊ฐœ)

โ–ถ๏ธŽ ์˜ˆ์‹œ์—์„œ๋Š” ํ•„ํ„ฐ๊ฐ€ 4๊ฐœ
  • ์ž…๋ ฅ ๋ฐ์ดํ„ฐ : W x H x D (W : ๊ฐ€๋กœ, H : ์„ธ๋กœ, D : ๊นŠ์ด ๋˜๋Š” ์ฑ„๋„)
  • ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ
    • ํ•„ํ„ฐ ๊ฐœ์ˆ˜ : K
    • ํ•„ํ„ฐ ํฌ๊ธฐ : F
    • ์ŠคํŠธ๋ผ์ด๋“œ : S
    • ํŒจ๋”ฉ : P
  • ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ
    • W = (W - F + 2P) / S + 1
    • H = (H - F + 2P) / S + 1
    • D = K
  • ํŒจ๋”ฉ(Padding)
    ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ์˜ ๊ฒฐ๊ณผ๋กœ ์–ป์€ ํŠน์„ฑ ๋งต์€ ์ž…๋ ฅ๋ณด๋‹ค ํฌ๊ธฐ๊ฐ€ ์ž‘์•„์ง„๋‹ค๋Š” ํŠน์ง•์ด ์žˆ๋‹ค. ๋งŒ์•ฝ, ํ•ฉ์„ฑ๊ณฑ ์ธต์„ ์—ฌ๋Ÿฌ๊ฐœ ์Œ“์•˜๋‹ค๋ฉด ์ตœ์ข…์ ์œผ๋กœ ์–ป์€ ํŠน์„ฑ ๋งต์€ ์ดˆ๊ธฐ ์ž…๋ ฅ๋ณด๋‹ค ๋งค์šฐ ์ž‘์•„์ง„ ์ƒํƒœ๊ฐ€ ๋˜๊ธฐ ๋•Œ๋ฌธ์— ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ ์ดํ›„์—๋„ ํŠน์„ฑ ๋งต์˜ ํฌ๊ธฐ๊ฐ€ ์ž…๋ ฅ์˜ ํฌ๊ธฐ์™€ ๋™์ผํ•˜๊ฒŒ ์œ ์ง€๋˜๋„๋ก ํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ํŒจ๋”ฉ(padding)์„ ์‚ฌ์šฉํ•˜๋ฉด ๋œ๋‹ค.

3๏ธโƒฃ ํ’€๋ง์ธต

๐Ÿ‘‰๐Ÿป ํŠน์„ฑ ๋งต์˜ ์ฐจ์›์„ ๋‹ค์šด ์ƒ˜ํ”Œ๋งํ•˜์—ฌ ์—ฐ์‚ฐ๋Ÿ‰์„ ๊ฐ์†Œ์‹œํ‚ค๊ณ , ์ฃผ์š”ํ•œ ํŠน์„ฑ ๋ฒกํ„ฐ๋ฅผ ์ถ”์ถœํ•˜์—ฌ ํ•™์Šต์„ ํšจ๊ณผ์ ์œผ๋กœ ํ•  ์ˆ˜ ์žˆ๊ฒŒ ํ•œ๋‹ค.

๋‹ค์šด ์ƒ˜ํ”Œ๋ง์ด๋ž€ ์ด๋ฏธ์ง€๋ฅผ ์ถ•์†Œํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

  • ์ตœ๋Œ€ ํ’€๋ง(Max Pooling) : ๋Œ€์ƒ ์˜์—ญ์—์„œ ์ตœ๋Œ“๊ฐ’์„ ์ถ”์ถœ
  • ํ‰๊ท  ํ’€๋ง(Average Pooling) : ๋Œ€์ƒ ์˜์—ญ์—์„œ ํ‰๊ท ์„ ๋ฐ˜ํ™˜

ํ‰๊ท  ํ’€๋ง์€ ๊ฐ ์ปค๋„ ๊ฐ’์„ ํ‰๊ท ํ™”์‹œ์ผœ ์ค‘์š”ํ•œ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ–๋Š” ๊ฐ’์˜ ํŠน์„ฑ์ด ํฌ๋ฏธํ•ด์งˆ ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋Œ€๋ถ€๋ถ„์˜ ํ•ฉ์„ฑ๊ณฑ ์‹ ๊ฒฝ๋ง์—์„œ๋Š” ์ตœ๋Œ€ ํ’€๋ง์ด ์‚ฌ์šฉ๋œ๋‹ค.

  • ์ž…๋ ฅ ๋ฐ์ดํ„ฐ : W x H x D
  • ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ
    • ํ•„ํ„ฐ ํฌ๊ธฐ : F
    • ์ŠคํŠธ๋ผ์ด๋“œ : S
  • ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ
    • W = (W - F ) / S + 1
    • H = (H - F) / S + 1
    • D = D

4๏ธโƒฃ ์™„์ „์—ฐ๊ฒฐ์ธต

๐Ÿ‘‰๐Ÿป ํ•ฉ์„ฑ๊ณฑ์ธต๊ณผ ํ’€๋ง์ธต์„ ๊ฑฐ์น˜๋ฉด์„œ ์ฐจ์›์ด ์ถ•์†Œ๋œ ํŠน์„ฑ ๋งต์€ ์ตœ์ข…์ ์œผ๋กœ ์™„์ „์—ฐ๊ฒฐ์ธต(fully connected layer)์œผ๋กœ ์ „๋‹ฌ๋œ๋‹ค. ์ด ๊ณผ์ •์—์„œ ์ด๋ฏธ์ง€๋Š” 3์ฐจ์› ๋ฒกํ„ฐ์—์„œ 1์ฐจ์› ๋ฒกํ„ฐ๋กœ ํŽผ์ณ์น˜๊ฒŒ(flatten) ๋œ๋‹ค.

5๏ธโƒฃ ์ถœ๋ ฅ์ธต

๐Ÿ‘‰๐Ÿป ์ถœ๋ ฅ์ธต์—์„œ๋Š” ์†Œํ”„ํŠธ๋งฅ์Šค ํ™œ์„ฑํ™” ํ•จ์ˆ˜๊ฐ€ ์‚ฌ์šฉ๋˜๋Š”๋ฐ ์ž…๋ ฅ๋ฐ›์€ ๊ฐ’์„ 0 ~ 1 ์‚ฌ์ด์˜ ๊ฐ’์œผ๋กœ ์ถœ๋ ฅํ•œ๋‹ค. ๋งˆ์ง€๋ง‰ ์ถœ๋ ฅ์ธต์˜ ์†Œํ”„ํŠธ๋งฅ์Šค ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ด๋ฏธ์ง€๊ฐ€ ๊ฐ ๋ ˆ์ด๋ธ”(label)์— ์†ํ•  ํ™•๋ฅ  ๊ฐ’์ด ์ถœ๋ ฅ๋˜๋ฉฐ, ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ  ๊ฐ’์„ ๊ฐ–๋Š” ๋ ˆ์ด๋ธ”์ด ์ตœ์ข… ๊ฐ’์œผ๋กœ ์„ ์ •๋œ๋‹ค.

class CNN(nn.Module):
    def __init__(self):
        super(FashionCNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
        self.drop = nn.Dropout2d(0.25)
        self.fc2 = nn.Linear(in_features=600, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.drop(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

1D, 2D, 3D ํ•ฉ์„ฑ๊ณฑ

1D ํ•ฉ์„ฑ๊ณฑ

ํ•„ํ„ฐ๊ฐ€ ์‹œ๊ฐ„์„ ์ถ•์œผ๋กœ ์ขŒ์šฐ๋กœ๋งŒ ์ด๋™ํ•  ์ˆ˜ ์žˆ๋‹ค.

  • ์ž…๋ ฅ : W
  • ํ•„ํ„ฐ : K
  • ์ถœ๋ ฅ : W

์ถœ๋ ฅ ํ˜•ํƒœ๋Š” 1D์˜ ๋ฐฐ์—ด์ด ๋˜๋ฉฐ, ๊ทธ๋ž˜ํ”„ ๊ณก์„ ์„ ์™„ํ™”ํ•  ๋•Œ ๋งŽ์ด ์‚ฌ์šฉ๋œ๋‹ค.

2D ํ•ฉ์„ฑ๊ณฑ

ํ•„ํ„ฐ๊ฐ€ ๋ฐฉํ–ฅ ๋‘ ๊ฐœ๋กœ ์›€์ง์ด๋Š” ํ˜•ํƒœ์ด๋‹ค.

  • ์ž…๋ ฅ : (W,H)
  • ํ•„ํ„ฐ : (k,k)
  • ์ถœ๋ ฅ : (W,H)

3D ํ•ฉ์„ฑ๊ณฑ

ํ•„ํ„ฐ๊ฐ€ ์›€์ง์ด๋Š” ๋ฐฉํ–ฅ์ด ์„ธ ๊ฐœ ์žˆ๋‹ค. ์ด ๋•Œ, d < L์„ ์œ ์ง€ํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.

  • ์ž…๋ ฅ : (W,H,L)
  • ํ•„ํ„ฐ : (k,k,d)
  • ์ถœ๋ ฅ : (W,H,L)

3D ์ž…๋ ฅ์„ ๊ฐ–๋Š” 2D ํ•ฉ์„ฑ๊ณฑ

์ž…๋ ฅ์ด 3D ํ˜•ํƒœ์ž„์—๋„ ์ถœ๋ ฅ์ด 2D ํ–‰๋ ฌ์ผ ๋•Œ '3D ์ž…๋ ฅ์„ ๊ฐ–๋Š” 2D ํ•ฉ์„ฑ๊ณฑ'์ด๋ผ๊ณ  ํ•œ๋‹ค. ํ•„ํ„ฐ์— ๋Œ€ํ•œ ๊ธธ์ด(L)๊ฐ€ ์ž…๋ ฅ ์ฑ„๋„์˜ ๊ธธ์ด(L)์™€ ๊ฐ™์•„์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด์™€ ๊ฐ™์€ ํ•ฉ์„ฑ๊ณฑ ํ˜•ํƒœ๊ฐ€ ๋งŒ๋“ค์–ด์ง„๋‹ค.

  • ์ž…๋ ฅ : (W,H,L)
  • ํ•„ํ„ฐ : (k,k,L)
  • ์ถœ๋ ฅ : (W,H)

1x1 ํ•ฉ์„ฑ๊ณฑ

3D ํ˜•ํƒœ๋กœ ์ž…๋ ฅ๋œ๋‹ค. 1x1 ํ•ฉ์„ฑ๊ณฑ์—์„œ ์ฑ„๋„ ์ˆ˜๋ฅผ ์กฐ์ •ํ•ด์„œ ์—ฐ์‚ฐ๋Ÿ‰์ด ๊ฐ์†Œ๋˜๋Š” ํšจ๊ณผ๊ฐ€ ์žˆ๋‹ค.

  • ์ž…๋ ฅ : (W,H,L)
  • ํ•„ํ„ฐ : (1,1,L)
  • ์ถœ๋ ฅ : (W,H)

์ „์ด ํ•™์Šต

์ „์ด ํ•™์Šต์ด๋ž€ ์ด๋ฏธ์ง€๋„ท์ฒ˜๋Ÿผ ์•„์ฃผ ํฐ ๋ฐ์ดํ„ฐ์…‹์„ ์จ์„œ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ€์ ธ์™€ ์šฐ๋ฆฌ๊ฐ€ ํ•ด๊ฒฐํ•˜๋ ค๋Š” ๊ณผ์ œ์— ๋งž๊ฒŒ ๋ณด์ •ํ•ด์„œ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค. ๋น„๊ต์  ์ ์€ ์ˆ˜์˜ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ง€๊ณ ๋„ ์šฐ๋ฆฌ๊ฐ€ ์›ํ•˜๋Š” ๊ณผ์ œ๋ฅผ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ๋‹ค.

ํŠน์„ฑ ์ถ”์ถœ ๊ธฐ๋ฒ•

ํŠน์„ฑ ์ถ”์ถœ(feature extractor)์€ ImageNet ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ ๊ฐ€์ ธ์˜จ ํ›„ ๋งˆ์ง€๋ง‰์— ์™„์ „์—ฐ๊ฒฐ์ธต ๋ถ€๋ถ„๋งŒ ์ƒˆ๋กœ ๋งŒ๋“ ๋‹ค. ํ•™์Šตํ•  ๋•Œ๋Š” ๋งˆ์ง€๋ง‰ ์™„์ „์—ฐ๊ฒฐ์ธต(์ถ”์ถœ๋œ ํŠน์„ฑ์„ ์ž…๋ ฅ๋ฐ›์•„ ์ตœ์ข…์ ์œผ๋กœ ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ํด๋ž˜์Šค๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ถ€๋ถ„)๋งŒ ํ•™์Šตํ•˜๊ณ  ๋‚˜๋จธ์ง€ ๊ณ„์ธต๋“ค์€ ํ•™์Šต๋˜์ง€ ์•Š๋„๋ก ํ•œ๋‹ค.

model = models.resnet18(pretrained=True) #ResNet18๋ชจ๋ธ์„ ์‚ฌ์šฉ, ์‚ฌ์ „ ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.

for param in model.parameters():
	param.requires_grad = False # ์™„์ „์—ฐ๊ฒฐ์ธต์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋“ค๋งŒ ํ•™์Šต์‹œํ‚ฌ ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์— ํ•ฉ์„ฑ๊ณฑ์ธต๊ณผ ํ’€๋ง์ธต์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋Š” ๊ณ ์ •

model.fc = nn.Linear(512, 2)
for param in model.fc.parameters(): # ์™„์ „์—ฐ๊ฒฐ์ธต์€ ํ•™์Šต์‹œํ‚ฌ ๊ฒƒ
    param.requires_grad = True

optimizer = torch.optim.Adam(model.fc.parameters())
cost = torch.nn.CrossEntropyLoss()

๋ฏธ์„ธ ์กฐ์ • ๊ธฐ๋ฒ•

๋ฏธ์„ธ ์กฐ์ •(fine-tuning) ๊ธฐ๋ฒ•์€ ํŠน์„ฑ ์ถ”์ถœ ๊ธฐ๋ฒ•์—์„œ ๋” ๋‚˜์•„๊ฐ€ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ํ•ฉ์„ฑ๊ณฑ์ธต, ๋ฐ์ดํ„ฐ ๋ถ„๋ฅ˜๊ธฐ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์—…๋ฐ์ดํŠธํ•˜์—ฌ ํ›ˆ๋ จ์‹œํ‚ค๋Š” ๋ฐฉ์‹์ด๋‹ค. ์‚ฌ์ „ ํ•™์Šต๋œ ๋ชจ๋ธ์„ ๋ชฉ์ ์— ๋งž๊ฒŒ ์žฌํ•™์Šต์‹œํ‚ค๊ฑฐ๋‚˜ ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜์˜ ์ผ๋ถ€๋ฅผ ์žฌํ•™์Šต์‹œํ‚ค๋Š” ๊ฒƒ์ด๋ฉฐ ๋ฐ์ดํ„ฐ์…‹์— ์ž˜ ๋งž๋„๋ก ๋ชจ๋ธ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์กฐ์ •ํ•˜๋Š” ๊ธฐ๋ฒ•์ด๋‹ค.

๋ฏธ์„ธ ์กฐ์ • ๊ธฐ๋ฒ•์€ ํ›ˆ๋ จ์‹œํ‚ค๋ ค๋Š” ๋ฐ์ดํ„ฐ์…‹์˜ ํฌ๊ธฐ์™€ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์— ๋”ฐ๋ผ ๋‹ค๋ฅธ ์ „๋žต์„ ์„ธ์šธ ์ˆ˜ ์žˆ๋‹ค.

  • ๋ฐ์ดํ„ฐ์…‹์ด ํฌ๊ณ  ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ์œ ์‚ฌ์„ฑ์ด ์ž‘์„ ๊ฒฝ์šฐ
    ๋ชจ๋ธ ์ „์ฒด๋ฅผ ์žฌํ•™์Šต์‹œํ‚จ๋‹ค.
  • ๋ฐ์ดํ„ฐ์…‹์ด ํฌ๊ณ  ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ์œ ์‚ฌ์„ฑ์ด ํ‹€ ๊ฒฝ์šฐ
    ํ•ฉ์„ฑ๊ณฑ์ธต์˜ ๋’ฌ๋ถ€๋ถ„๊ณผ ๋ฐ์ดํ„ฐ ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ํ•™์Šต์‹œํ‚จ๋‹ค.
  • ๋ฐ์ดํ„ฐ์…‹์ด ์ž‘๊ณ  ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ์œ ์‚ฌ์„ฑ์ด ์ž‘์„ ๊ฒฝ์šฐ
    ํ•ฉ์„ฑ๊ณฑ์ธต์˜ ์ผ๋ถ€๋ถ„๊ณผ ๋ฐ์ดํ„ฐ ๋ถ„๋ฅ˜๊ธฐ๋ฅผ ํ•™์Šต์‹œํ‚จ๋‹ค.
  • ๋ฐ์ดํ„ฐ์…‹์ด ์ž‘๊ณ  ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ๊ณผ ์œ ์‚ฌ์„ฑ์ด ํด ๊ฒฝ์šฐ
    ๋ฐ์ดํ„ฐ ๋ถ„๋ฅ˜๊ธฐ๋งŒ ํ•™์Šต์‹œํ‚จ๋‹ค. ๋ฐ์ดํ„ฐ๊ฐ€ ์ ๊ธฐ ๋•Œ๋ฌธ์— ๋งŽ์€ ๊ณ„์ธต์— ๋ฏธ์„ธ ์กฐ์ • ๊ธฐ๋ฒ•์„ ์ ์šฉํ•˜๋ฉด ๊ณผ์ ํ•ฉ์ด ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋‹ค.

์„ค๋ช… ๊ฐ€๋Šฅํ•œ CNN

์„ค๋ช… ๊ฐ€๋Šฅํ•œ CNN์€ ๋”ฅ๋Ÿฌ๋‹ ์ฒ˜๋ฆฌ ๊ฒฐ๊ณผ๋ฅผ ์‚ฌ๋žŒ์ด ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ์‹์œผ๋กœ ํ•˜๋Š” ๊ธฐ์ˆ ์ด๋‹ค.

ํŠน์„ฑ ๋งต ์‹œ๊ฐํ™”

ํŠน์„ฑ ๋งต์€ ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๋˜๋Š” ๋‹ค๋ฅธ ํŠน์„ฑ ๋งต์ฒ˜๋Ÿผ ํ•„ํ„ฐ๋ฅผ ์ž…๋ ฅ์— ์ ์šฉํ•œ ๊ฒฐ๊ณผ์ด๋‹ค. ํŠน์„ฑ ์ž…๋ ฅ ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ํŠน์„ฑ ๋งต์„ ์‹œ๊ฐํ™”ํ•œ๋‹ค๋Š” ์˜๋ฏธ๋Š” ํŠน์„ฑ ๋งต์—์„œ ์ž…๋ ฅ ํŠน์„ฑ์„ ๊ฐ์ง€ํ•œ ๋ฐฉ๋ฒ•์„ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋„๋ก ๋•๋Š” ๊ฒƒ์ด๋‹ค.
์„ค๋ช… ๊ฐ€๋Šฅํ•œ CNN [Colab]

๊ทธ๋ž˜ํ”„ ํ•ฉ์„ฑ๊ณฑ ๋„คํŠธ์›Œํฌ

๊ทธ๋ž˜ํ”„ ํ•ฉ์„ฑ๊ณฑ ๋„คํŠธ์›Œํฌ๋Š” ๊ทธ๋ž˜ํ”„ ๋ฐ์ดํ„ฐ๋ฅผ ์œ„ํ•œ ์‹ ๊ฒฝ๋ง์ด๋‹ค.

๊ทธ๋ž˜ํ”„๋Š” ๋ฐฉํ–ฅ์„ฑ์ด ์žˆ๊ฑฐ๋‚˜ ์—†๋Š” edge๋กœ ์—ฐ๊ฒฐ๋œ node์˜ ์ง‘ํ•ฉ์ด๋‹ค.

๊ทธ๋ž˜ํ”„ ์‹ ๊ฒฝ๋ง

๊ทธ๋ž˜ํ”„ ์‹ ๊ฒฝ๋ง(Graph Neural Network, GNN)์€ ๊ทธ๋ž˜ํ”„ ๊ตฌ์กฐ์—์„œ ์‚ฌ์šฉํ•˜๋Š” ์‹ ๊ฒฝ๋ง์ด๋‹ค.

1๋‹จ๊ณ„) ์ธ์ ‘ ํ–‰๋ ฌ

  • ๋…ธ๋“œ n๊ฐœ๋ฅผ nxn ํ–‰๋ ฌ๋กœ ํ‘œํ˜„ํ•œ๋‹ค. (์ธ์ ‘ ํ–‰๋ ฌ)
  • ์ž…์ ‘ ํ–‰๋ ฌ ๋‚ด์˜ ๊ฐ’์€ i์™€ j์˜ ๊ด€๋ จ์„ฑ ์—ฌ๋ถ€๋ฅผ ๋งŒ์กฑํ•˜๋Š” ๊ฐ’์œผ๋กœ ์ฑ„์›Œ์ค€๋‹ค.

2๋‹จ๊ณ„) ํŠน์„ฑ ํ–‰๋ ฌ

์ธ์ ‘ ํ–‰๋ ฌ๋งŒ์œผ๋กœ๋Š” ํŠน์„ฑ์„ ํŒŒ์•…ํ•˜๊ธฐ ์–ด๋ ต๊ธฐ ๋•Œ๋ฌธ์— ๋‹จ์œ„ ํ–‰๋ ฌ์„ ์ ์šฉํ•œ๋‹ค.
  • ๊ฐ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์—์„œ ์ด์šฉํ•  ํŠน์„ฑ์„ ์„ ํƒํ•œ๋‹ค.
  • ํŠน์„ฑ ํ–‰๋ ฌ์—์„œ ๊ฐ ํ–‰์€ ์„ ํƒ๋œ ํŠน์„ฑ์— ๋Œ€ํ•ด ๊ฐ ๋…ธ๋“œ๊ฐ€ ๊ฐ–๋Š” ๊ฐ’์„ ์˜๋ฏธํ•œ๋‹ค.

๊ทธ๋ž˜ํ”„ ํ•ฉ์„ฑ๊ณฑ ๋„คํŠธ์›Œํฌ

๊ทธ๋ž˜ํ”„ ํ•ฉ์„ฑ๊ณฑ ๋„คํŠธ์›Œํฌ(Graph Convolutional Network, GCN)๋Š” ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ํ•ฉ์„ฑ๊ณฑ์„ ๊ทธ๋ž˜ํ”„ ๋ฐ์ดํ„ฐ๋กœ ํ™•์žฅํ•œ ์•Œ๊ณ ๋ฆฌ์ฆ˜์ด๋‹ค.


์ถœ์ฒ˜
profile
๋‚ด์ผ์˜ ๋‚˜๋Š” ์˜ค๋Š˜๋ณด๋‹ค ๋” ๋‚˜์•„์ง€๊ธฐ๋ฅผ :D

1๊ฐœ์˜ ๋Œ“๊ธ€

comment-user-thumbnail
2023๋…„ 8์›” 15์ผ

์ •๋ฆฌ๊ฐ€ ์ž˜ ๋œ ๊ธ€์ด๋„ค์š”. ๋„์›€์ด ๋์Šต๋‹ˆ๋‹ค.

๋‹ต๊ธ€ ๋‹ฌ๊ธฐ