Can an optimizer be shared between different models?

Develop My Life · January 26, 2023

PyTorch


Background

While working on global structured pruning, I needed to train a dense model and a pruned model separately. To train the pruned model, I tried to reuse the dense model's optimizer and scheduler as-is, but training made no progress. Even with the learning rate set to 0, the accuracy still changed.

Answer

No, it can't!!!

Why

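```python
import torch.optim as optim

# `model` is the network being trained (here, the dense ResNet-50).
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# The scheduler is attached to this specific optimizer instance.
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90)
```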

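When the optimizer is created, `model.parameters()` is passed in, so the optimizer is bound to that particular model's structure. It stores references to that model's parameter tensors in `optimizer.param_groups`, and `optimizer.step()` updates exactly those tensors. Its per-parameter state follows the same structure: `optimizer.state_dict()['state']` holds a `momentum_buffer` for each parameter, with the same shape as that parameter, and this buffer is used in every parameter update. So if the optimizer was created from the dense model's parameters, it never updates the pruned model's parameters; any update it performs is applied to the dense model's parameters.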
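These claims can be checked directly. Below is a minimal sketch that uses a hypothetical toy CNN (`make_model`, not part of the original post) in place of ResNet-50, with random inputs and labels. It confirms that `optimizer.param_groups` holds the model's own `Parameter` objects. It also prints the `momentum_buffer` shape of each entry in `optimizer.state_dict()['state']`, which is the same kind of listing shown below for ResNet-50.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def make_model(width: int) -> nn.Sequential:
    # Hypothetical toy CNN standing in for ResNet-50; pruning would shrink `width`.
    return nn.Sequential(
        nn.Conv2d(3, width, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(width),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(width, 10),
    )


model = make_model(64)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# param_groups holds the model's own Parameter objects (same identity, not copies).
opt_params = optimizer.param_groups[0]["params"]
assert all(p is q for p, q in zip(opt_params, model.parameters()))

# momentum_buffer tensors are created lazily on the first step(), shaped like each parameter.
x, y = torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))
nn.functional.cross_entropy(model(x), y).backward()
optimizer.step()

# Keys of state_dict()['state'] are integer indices into the optimizer's parameter list.
for idx, s in optimizer.state_dict()["state"].items():
    print(idx, ":", s["momentum_buffer"].shape)
```

The single training step is there because SGD only allocates `momentum_buffer` when `step()` is first called with gradients present.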

Dense ResNet-50: `momentum_buffer` shape for each entry of `optimizer.state_dict()['state']` (SGD)

0 : torch.Size([64, 3, 3, 3])
1 : torch.Size([64])
2 : torch.Size([64])
3 : torch.Size([64, 64, 1, 1])
4 : torch.Size([64])
5 : torch.Size([64])
6 : torch.Size([64, 64, 3, 3])
7 : torch.Size([64])
8 : torch.Size([64])
9 : torch.Size([256, 64, 1, 1])
10 : torch.Size([256])
11 : torch.Size([256])
12 : torch.Size([256, 64, 1, 1])
13 : torch.Size([256])
14 : torch.Size([256])
15 : torch.Size([64, 256, 1, 1])
16 : torch.Size([64])
17 : torch.Size([64])
18 : torch.Size([64, 64, 3, 3])
19 : torch.Size([64])
20 : torch.Size([64])
21 : torch.Size([256, 64, 1, 1])
22 : torch.Size([256])
23 : torch.Size([256])
24 : torch.Size([64, 256, 1, 1])
25 : torch.Size([64])
26 : torch.Size([64])
27 : torch.Size([64, 64, 3, 3])
28 : torch.Size([64])
29 : torch.Size([64])
30 : torch.Size([256, 64, 1, 1])
31 : torch.Size([256])
32 : torch.Size([256])
33 : torch.Size([128, 256, 1, 1])
34 : torch.Size([128])
35 : torch.Size([128])
36 : torch.Size([128, 128, 3, 3])
37 : torch.Size([128])
38 : torch.Size([128])
39 : torch.Size([512, 128, 1, 1])
40 : torch.Size([512])
41 : torch.Size([512])
42 : torch.Size([512, 256, 1, 1])
43 : torch.Size([512])
44 : torch.Size([512])
45 : torch.Size([128, 512, 1, 1])
46 : torch.Size([128])
47 : torch.Size([128])
48 : torch.Size([128, 128, 3, 3])
49 : torch.Size([128])
50 : torch.Size([128])
51 : torch.Size([512, 128, 1, 1])
52 : torch.Size([512])
53 : torch.Size([512])
54 : torch.Size([128, 512, 1, 1])
55 : torch.Size([128])
56 : torch.Size([128])
57 : torch.Size([128, 128, 3, 3])
58 : torch.Size([128])
59 : torch.Size([128])
60 : torch.Size([512, 128, 1, 1])
61 : torch.Size([512])
62 : torch.Size([512])
63 : torch.Size([128, 512, 1, 1])
64 : torch.Size([128])
65 : torch.Size([128])
66 : torch.Size([128, 128, 3, 3])
67 : torch.Size([128])
68 : torch.Size([128])
69 : torch.Size([512, 128, 1, 1])
70 : torch.Size([512])
71 : torch.Size([512])
72 : torch.Size([256, 512, 1, 1])
73 : torch.Size([256])
74 : torch.Size([256])
75 : torch.Size([256, 256, 3, 3])
76 : torch.Size([256])
77 : torch.Size([256])
78 : torch.Size([1024, 256, 1, 1])
79 : torch.Size([1024])
80 : torch.Size([1024])
81 : torch.Size([1024, 512, 1, 1])
82 : torch.Size([1024])
83 : torch.Size([1024])
84 : torch.Size([256, 1024, 1, 1])
85 : torch.Size([256])
86 : torch.Size([256])
87 : torch.Size([256, 256, 3, 3])
88 : torch.Size([256])
89 : torch.Size([256])
90 : torch.Size([1024, 256, 1, 1])
91 : torch.Size([1024])
92 : torch.Size([1024])
93 : torch.Size([256, 1024, 1, 1])
94 : torch.Size([256])
95 : torch.Size([256])
96 : torch.Size([256, 256, 3, 3])
97 : torch.Size([256])
98 : torch.Size([256])
99 : torch.Size([1024, 256, 1, 1])
100 : torch.Size([1024])
101 : torch.Size([1024])
102 : torch.Size([256, 1024, 1, 1])
103 : torch.Size([256])
104 : torch.Size([256])
105 : torch.Size([256, 256, 3, 3])
106 : torch.Size([256])
107 : torch.Size([256])
108 : torch.Size([1024, 256, 1, 1])
109 : torch.Size([1024])
110 : torch.Size([1024])
111 : torch.Size([256, 1024, 1, 1])
112 : torch.Size([256])
113 : torch.Size([256])
114 : torch.Size([256, 256, 3, 3])
115 : torch.Size([256])
116 : torch.Size([256])
117 : torch.Size([1024, 256, 1, 1])
118 : torch.Size([1024])
119 : torch.Size([1024])
120 : torch.Size([256, 1024, 1, 1])
121 : torch.Size([256])
122 : torch.Size([256])
123 : torch.Size([256, 256, 3, 3])
124 : torch.Size([256])
125 : torch.Size([256])
126 : torch.Size([1024, 256, 1, 1])
127 : torch.Size([1024])
128 : torch.Size([1024])
129 : torch.Size([512, 1024, 1, 1])
130 : torch.Size([512])
131 : torch.Size([512])
132 : torch.Size([512, 512, 3, 3])
133 : torch.Size([512])
134 : torch.Size([512])
135 : torch.Size([2048, 512, 1, 1])
136 : torch.Size([2048])
137 : torch.Size([2048])
138 : torch.Size([2048, 1024, 1, 1])
139 : torch.Size([2048])
140 : torch.Size([2048])
141 : torch.Size([512, 2048, 1, 1])
142 : torch.Size([512])
143 : torch.Size([512])
144 : torch.Size([512, 512, 3, 3])
145 : torch.Size([512])
146 : torch.Size([512])
147 : torch.Size([2048, 512, 1, 1])
148 : torch.Size([2048])
149 : torch.Size([2048])
150 : torch.Size([512, 2048, 1, 1])
151 : torch.Size([512])
152 : torch.Size([512])
153 : torch.Size([512, 512, 3, 3])
154 : torch.Size([512])
155 : torch.Size([512])
156 : torch.Size([2048, 512, 1, 1])
157 : torch.Size([2048])
158 : torch.Size([2048])
159 : torch.Size([10, 2048])
160 : torch.Size([10])

Pruned ResNet-50: `momentum_buffer` shape for each entry of `optimizer.state_dict()['state']` (SGD)

0 : torch.Size([64, 3, 3, 3])
1 : torch.Size([64])
2 : torch.Size([64])
3 : torch.Size([64, 64, 1, 1])
4 : torch.Size([64])
5 : torch.Size([64])
6 : torch.Size([64, 64, 3, 3])
7 : torch.Size([64])
8 : torch.Size([64])
9 : torch.Size([256, 64, 1, 1])
10 : torch.Size([256])
11 : torch.Size([256])
12 : torch.Size([256, 64, 1, 1])
13 : torch.Size([256])
14 : torch.Size([256])
15 : torch.Size([64, 256, 1, 1])
16 : torch.Size([64])
17 : torch.Size([64])
18 : torch.Size([64, 64, 3, 3])
19 : torch.Size([64])
20 : torch.Size([64])
21 : torch.Size([256, 64, 1, 1])
22 : torch.Size([256])
23 : torch.Size([256])
24 : torch.Size([64, 256, 1, 1])
25 : torch.Size([64])
26 : torch.Size([64])
27 : torch.Size([64, 64, 3, 3])
28 : torch.Size([64])
29 : torch.Size([64])
30 : torch.Size([256, 64, 1, 1])
31 : torch.Size([256])
32 : torch.Size([256])
33 : torch.Size([128, 256, 1, 1])
34 : torch.Size([128])
35 : torch.Size([128])
36 : torch.Size([128, 128, 3, 3])
37 : torch.Size([128])
38 : torch.Size([128])
39 : torch.Size([512, 128, 1, 1])
40 : torch.Size([512])
41 : torch.Size([512])
42 : torch.Size([512, 256, 1, 1])
43 : torch.Size([512])
44 : torch.Size([512])
45 : torch.Size([128, 512, 1, 1])
46 : torch.Size([128])
47 : torch.Size([128])
48 : torch.Size([128, 128, 3, 3])
49 : torch.Size([128])
50 : torch.Size([128])
51 : torch.Size([512, 128, 1, 1])
52 : torch.Size([512])
53 : torch.Size([512])
54 : torch.Size([128, 512, 1, 1])
55 : torch.Size([128])
56 : torch.Size([128])
57 : torch.Size([128, 128, 3, 3])
58 : torch.Size([128])
59 : torch.Size([128])
60 : torch.Size([512, 128, 1, 1])
61 : torch.Size([512])
62 : torch.Size([512])
63 : torch.Size([128, 512, 1, 1])
64 : torch.Size([128])
65 : torch.Size([128])
66 : torch.Size([128, 128, 3, 3])
67 : torch.Size([128])
68 : torch.Size([128])
69 : torch.Size([512, 128, 1, 1])
70 : torch.Size([512])
71 : torch.Size([512])
72 : torch.Size([256, 512, 1, 1])
73 : torch.Size([256])
74 : torch.Size([256])
75 : torch.Size([256, 256, 3, 3])
76 : torch.Size([256])
77 : torch.Size([256])
78 : torch.Size([1024, 256, 1, 1])
79 : torch.Size([1024])
80 : torch.Size([1024])
81 : torch.Size([1024, 512, 1, 1])
82 : torch.Size([1024])
83 : torch.Size([1024])
84 : torch.Size([72, 1024, 1, 1])
85 : torch.Size([72])
86 : torch.Size([72])
87 : torch.Size([256, 72, 3, 3])
88 : torch.Size([256])
89 : torch.Size([256])
90 : torch.Size([1024, 256, 1, 1])
91 : torch.Size([1024])
92 : torch.Size([1024])
93 : torch.Size([80, 1024, 1, 1])
94 : torch.Size([80])
95 : torch.Size([80])
96 : torch.Size([256, 80, 3, 3])
97 : torch.Size([256])
98 : torch.Size([256])
99 : torch.Size([1024, 256, 1, 1])
100 : torch.Size([1024])
101 : torch.Size([1024])
102 : torch.Size([48, 1024, 1, 1])
103 : torch.Size([48])
104 : torch.Size([48])
105 : torch.Size([256, 48, 3, 3])
106 : torch.Size([256])
107 : torch.Size([256])
108 : torch.Size([1024, 256, 1, 1])
109 : torch.Size([1024])
110 : torch.Size([1024])
111 : torch.Size([38, 1024, 1, 1])
112 : torch.Size([38])
113 : torch.Size([38])
114 : torch.Size([256, 38, 3, 3])
115 : torch.Size([256])
116 : torch.Size([256])
117 : torch.Size([1024, 256, 1, 1])
118 : torch.Size([1024])
119 : torch.Size([1024])
120 : torch.Size([44, 1024, 1, 1])
121 : torch.Size([44])
122 : torch.Size([44])
123 : torch.Size([256, 44, 3, 3])
124 : torch.Size([256])
125 : torch.Size([256])
126 : torch.Size([1024, 256, 1, 1])
127 : torch.Size([1024])
128 : torch.Size([1024])
129 : torch.Size([512, 1024, 1, 1])
130 : torch.Size([512])
131 : torch.Size([512])
132 : torch.Size([512, 512, 3, 3])
133 : torch.Size([512])
134 : torch.Size([512])
135 : torch.Size([512, 512, 1, 1])
136 : torch.Size([512])
137 : torch.Size([512])
138 : torch.Size([512, 1024, 1, 1])
139 : torch.Size([512])
140 : torch.Size([512])
141 : torch.Size([512, 512, 1, 1])
142 : torch.Size([512])
143 : torch.Size([512])
144 : torch.Size([512, 512, 3, 3])
145 : torch.Size([512])
146 : torch.Size([512])
147 : torch.Size([512, 512, 1, 1])
148 : torch.Size([512])
149 : torch.Size([512])
150 : torch.Size([512, 512, 1, 1])
151 : torch.Size([512])
152 : torch.Size([512])
153 : torch.Size([512, 512, 3, 3])
154 : torch.Size([512])
155 : torch.Size([512])
156 : torch.Size([512, 512, 1, 1])
157 : torch.Size([512])
158 : torch.Size([512])
159 : torch.Size([10, 512])
160 : torch.Size([10])

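➡️ The optimizer state of the dense model and the pruned model has a different structure. For example, entry 84 is [256, 1024, 1, 1] in the dense model but [72, 1024, 1, 1] in the pruned one, and the final layer (entry 159) shrinks from [10, 2048] to [10, 512].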
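To see what this means in practice, here is a minimal sketch that uses a hypothetical toy CNN (`make_model`, a stand-in for ResNet-50 where a smaller `width` plays the role of the pruned network) with random data. The helper `train_step` is also illustrative. Reusing the dense model's optimizer leaves the pruned model's weights untouched. An optimizer and scheduler rebuilt from `pruned.parameters()` do update them.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def make_model(width: int) -> nn.Sequential:
    # Hypothetical toy CNN; a smaller `width` stands in for the pruned network.
    return nn.Sequential(
        nn.Conv2d(3, width, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(width),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(width, 10),
    )


def train_step(model, optimizer, x, y):
    optimizer.zero_grad()  # clears grads only for the optimizer's own parameters
    nn.functional.cross_entropy(model(x), y).backward()
    optimizer.step()  # updates only the parameters registered in the optimizer


torch.manual_seed(0)
x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))

dense = make_model(64)
dense_opt = optim.SGD(dense.parameters(), lr=0.1, momentum=0.9)
dense_sched = optim.lr_scheduler.CosineAnnealingLR(dense_opt, T_max=90)
train_step(dense, dense_opt, x, y)

pruned = make_model(16)  # new module, hence new Parameter objects

# Wrong: the dense optimizer never sees the pruned tensors, so they do not move.
w_before = pruned[5].weight.detach().clone()
train_step(pruned, dense_opt, x, y)
stale_unchanged = torch.equal(w_before, pruned[5].weight)

# Right: rebuild both the optimizer and the scheduler from the pruned model's parameters.
pruned_opt = optim.SGD(pruned.parameters(), lr=0.1, momentum=0.9)
pruned_sched = optim.lr_scheduler.CosineAnnealingLR(pruned_opt, T_max=90)
w_before = pruned[5].weight.detach().clone()
train_step(pruned, pruned_opt, x, y)
pruned_sched.step()
fresh_changed = not torch.equal(w_before, pruned[5].weight)

print(stale_unchanged, fresh_changed)
```

The scheduler is rebuilt as well because a PyTorch LR scheduler keeps a reference to the optimizer it was constructed with (`dense_sched.optimizer is dense_opt`). Stepping the old scheduler would therefore only change the dense optimizer's learning rate. The new scheduler starts from the beginning of the cosine schedule. Continuing from where the dense run stopped would need extra handling that is not shown here.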

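One more note on momentum with a learning rate of 0. Some momentum formulations fold the learning rate into the velocity (v ← μ·v + lr·g, θ ← θ - v). In those, a non-zero velocity left over from earlier training keeps moving the weights even when lr = 0. PyTorch's `torch.optim.SGD`, however, multiplies the whole momentum buffer by `lr` at each step (θ ← θ - lr·b), so with `lr=0` it leaves the parameters unchanged. If accuracy still changes at `lr=0` in PyTorch, BatchNorm is a more likely cause. In `train()` mode, its `running_mean` and `running_var` buffers are updated on every forward pass, independently of the optimizer.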
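Whether momentum can move the weights at lr = 0 depends on how the update rule is written. The `torch.optim.SGD` documentation writes momentum as b_t = μ·b_{t-1} + g_t, θ_t = θ_{t-1} - lr·b_t. It notes that this differs from the Sutskever et al. form v_t = μ·v_{t-1} + lr·g_t, θ_t = θ_{t-1} - v_t. The scalar sketch below is plain Python and leaves out dampening, weight decay, and Nesterov; its numbers are made up. It applies one step of each rule with lr = 0 and a non-zero momentum state carried over from earlier training.

```python
def pytorch_style_step(theta, buf, grad, lr, momentum):
    # torch.optim.SGD form (no dampening, weight decay, or Nesterov):
    # the buffer accumulates raw gradients and lr scales the whole buffer.
    buf = momentum * buf + grad
    return theta - lr * buf, buf


def sutskever_style_step(theta, vel, grad, lr, momentum):
    # Sutskever et al. form: lr is folded into the velocity as it accumulates,
    # so old velocity still moves theta even when lr == 0.
    vel = momentum * vel + lr * grad
    return theta - vel, vel


theta0 = 1.0
grad = 0.5
carried = 2.0  # made-up momentum state left over from earlier training

theta_pt, _ = pytorch_style_step(theta0, carried, grad, lr=0.0, momentum=0.9)
theta_sk, _ = sutskever_style_step(theta0, carried, grad, lr=0.0, momentum=0.9)

print("PyTorch-style:", theta_pt)    # unchanged: lr = 0 zeroes the whole update
print("Sutskever-style:", theta_sk)  # moved by momentum * carried
```

Only the Sutskever-style rule moves the weight here, which is why the lr = 0 behavior differs between frameworks that use different momentum formulations.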
