Ad code

LLM.int8() - Large Language Model (LLM)의 계산 성능을 개선하기 위한 양자화 방법

LLM.int8() 개요

LLM.int8()은 Large Language Model (LLM)의 계산 성능을 개선하기 위한 8-bit 양자화 방법이다. 기존의 8-bit 양자화 방법은 성능 저하가 발생하는 문제점이 있었는데, LLM.int8()은 이를 해결하여 LLM의 성능을 유지하면서도 계산 성능을 크게 향상시킬 수 있다.

LLM.int8()의 핵심 요소는 vector-wise quantization mixed-precision decomposition이다.

vector-wise quantization은 텐서 당 여러 개의 scaling constant를 사용하여 outlier의 영향력을 줄이는 방법이다.

mixed-precision decomposition은 0.1%의 outlier만 16-bit로 나타내어지고, 99.9%의 값들은 8-bit로 matmul 계산이 되는 방법으로 성능에 영향을 최소화 한다.

LLM.int8()은 bitsandbytes 라이브러리를 통해 구현할 수 있다. bitsandbytes는 transformers, accelerate 등 여러 다른 라이브러리를 통해서 쓸 수 있도록 되어있기 때문에 확장성도 좋다고 할 수 있으므로, 본인에게 편한 라이브러리를 사용해보자.

간단한 모델을 int8로 변환하기

bitsandbytes를 사용하여 간단한 모델을 int8로 변환하는 방법은 다음과 같다

1
pip install torch bitsandbytes
cs

*주의 : bnb는 현재 GPU를 지원한다. 로컬 환경 구성이 어려운 분은, Colab에서 테스트하는 것을 추천한다.

필요한 라이브러리를 import한다.

1
2
3
4
import torch
import torch.nn as nn
 
import bitsandbytes as bnb
cs

테스트를 위해 간단한 linear 모델을 정의하도록 하자.

1
2
3
4
fp16_model = nn.Sequential(
    nn.Linear(6464),
    nn.Linear(6464)
)
cs

정의했던 FP16 모델을 기반으로, Linear8bitLt을 활용하여 int8 모델을 재정의한다.

1
2
3
4
int8_model = nn.Sequential(
    bnb.nn.Linear8bitLt(6464, has_fp16_weights=False),
    bnb.nn.Linear8bitLt(6464, has_fp16_weights=False)
)
cs

저장해둔 가중치를 int8 모델에 로드하면 끝이다.

1
2
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0# 여기에서 양자화가 진행
cs

이를 통해 int8로 양자화된 모델에 input을 넣어 inference를 진행할 수 있다.

1
2
input_ = torch.randn((164), dtype=torch.float16)
hidden_states = int8_model(input_.to(torch.device('cuda'0)))
cs

전체 코드는 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
 
import bitsandbytes as bnb
from bnb.nn import Linear8bitLt
# 실행 전
 
fp16_model = nn.Sequential(
    nn.Linear(6464),
    nn.Linear(6464)
)
 
# 훈련 코드 생략
 
torch.save(fp16_model.state_dict(), "model.pt")
 
# 실행 후
 
int8_model = nn.Sequential(
    bnb.nn.Linear8bitLt(6464, has_fp16_weights=False),
    bnb.nn.Linear8bitLt(6464, has_fp16_weights=False)
)
 
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0# 여기에서 양자화가 진행
 
# 실행 결과 비교
 
print("FP16 모델 가중치")
print(fp16_model[0].weight)
print()
print("INT8 모델 가중치")
print(int8_model[0].weight)
cs

이처럼 LLM.int8()은 LLM의 계산 성능을 크게 향상시킬 수 있는 효과적인 방법으로ㅡ bitsandbytes 라이브러리를 사용하여 간단하게 구현할 수 있다.







댓글 쓰기

0 댓글