LLM.int8() 개요
LLM.int8()은 Large Language Model (LLM)의 계산 성능을 개선하기 위한 8-bit 양자화 방법이다. 기존의 8-bit 양자화 방법은 성능 저하가 발생하는 문제점이 있었는데, LLM.int8()은 이를 해결하여 LLM의 성능을 유지하면서도 계산 성능을 크게 향상시킬 수 있다.
LLM.int8()의 핵심 요소는 vector-wise quantization과 mixed-precision decomposition이다.
vector-wise quantization은 텐서 당 여러 개의 scaling constant를 사용하여 outlier의 영향력을 줄이는 방법이다.
mixed-precision decomposition은 0.1%의 outlier만 16-bit로 나타내어지고, 99.9%의 값들은 8-bit로 matmul 계산이 되는 방법으로 성능에 영향을 최소화 한다.
LLM.int8()은 bitsandbytes 라이브러리를 통해 구현할 수 있다. bitsandbytes는 transformers, accelerate 등 여러 다른 라이브러리를 통해서 쓸 수 있도록 되어있기 때문에 확장성도 좋다고 할 수 있으므로, 본인에게 편한 라이브러리를 사용해보자.
간단한 모델을 int8로 변환하기
bitsandbytes를 사용하여 간단한 모델을 int8로 변환하는 방법은 다음과 같다
1 | pip install torch bitsandbytes | cs |
*주의 : bnb는 현재 GPU를 지원한다. 로컬 환경 구성이 어려운 분은, Colab에서 테스트하는 것을 추천한다.
필요한 라이브러리를 import한다.
1 2 3 4 | import torch import torch.nn as nn import bitsandbytes as bnb | cs |
테스트를 위해 간단한 linear 모델을 정의하도록 하자.
1 2 3 4 | fp16_model = nn.Sequential( nn.Linear(64, 64), nn.Linear(64, 64) ) | cs |
정의했던 FP16 모델을 기반으로, Linear8bitLt을 활용하여 int8 모델을 재정의한다.
1 2 3 4 | int8_model = nn.Sequential( bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False), bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False) ) | cs |
저장해둔 가중치를 int8 모델에 로드하면 끝이다.
1 2 | int8_model.load_state_dict(torch.load("model.pt")) int8_model = int8_model.to(0) # 여기에서 양자화가 진행 | cs |
이를 통해 int8로 양자화된 모델에 input을 넣어 inference를 진행할 수 있다.
1 2 | input_ = torch.randn((1, 64), dtype=torch.float16) hidden_states = int8_model(input_.to(torch.device('cuda', 0))) | cs |
전체 코드는 다음과 같다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | import torch import torch.nn as nn import bitsandbytes as bnb from bnb.nn import Linear8bitLt # 실행 전 fp16_model = nn.Sequential( nn.Linear(64, 64), nn.Linear(64, 64) ) # 훈련 코드 생략 torch.save(fp16_model.state_dict(), "model.pt") # 실행 후 int8_model = nn.Sequential( bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False), bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False) ) int8_model.load_state_dict(torch.load("model.pt")) int8_model = int8_model.to(0) # 여기에서 양자화가 진행 # 실행 결과 비교 print("FP16 모델 가중치") print(fp16_model[0].weight) print() print("INT8 모델 가중치") print(int8_model[0].weight) | cs |
이처럼 LLM.int8()은 LLM의 계산 성능을 크게 향상시킬 수 있는 효과적인 방법으로ㅡ bitsandbytes 라이브러리를 사용하여 간단하게 구현할 수 있다.
0 댓글