Let's take a look at the Llama class.
class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,  # ?
        seed: int = 1,
    ) -> "Llama":
build is the first thing defined. This function loads the model checkpoint, initializes it, and builds a Llama instance.
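For orientation, here is a minimal sketch of how build is typically called; the paths and sizes below are hypothetical placeholders, not values from the repository.

llama = Llama.build(
    ckpt_dir="checkpoints/llama3-8b/",                       # placeholder: directory with *.pth shards and params.json
    tokenizer_path="checkpoints/llama3-8b/tokenizer.model",  # placeholder: tokenizer file
    max_seq_len=2048,
    max_batch_size=4,
)

The body of build is walked through below.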
assert 1 <= max_seq_len <= 8192, f"max_seq_len must be between 1 and 8192, got {max_seq_len}."  # This limit is for Llama 3 -- how is it set in 3.1?
assert os.path.isdir(ckpt_dir), f"Checkpoint directory '{ckpt_dir}' does not exist."
assert os.path.isfile(tokenizer_path), f"Tokenizer file '{tokenizer_path}' does not exist."
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
    if model_parallel_size is None:
        model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)  # look into this later
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(seed)
if local_rank > 0:
    sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
    checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
    params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
    **params,
)
These arguments come from class ModelArgs in llama/model.py:

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    max_batch_size: int = 32
    max_seq_len: int = 2048
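To make the **params unpacking above concrete, here is a rough sketch; the dictionary contents are illustrative placeholders for what a params.json might hold, not values copied from an actual release, and it assumes the ModelArgs dataclass shown above is in scope.

# Hypothetical sketch: values read from params.json override the dataclass defaults.
params = {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "norm_eps": 1e-5,
    "rope_theta": 500000.0,
}
model_args = ModelArgs(max_seq_len=2048, max_batch_size=4, **params)
print(model_args.vocab_size)   # 128256 -- taken from params, not the -1 default
print(model_args.multiple_of)  # 256 -- not in params, so the dataclass default is kept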
max_batch_size -> the maximum batch size usable at inference time, defined with a default of 32; max_seq_len -> the maximum input sequence length, defined with a default of 2048.

tokenizer = Tokenizer(model_path=tokenizer_path)
assert model_args.vocab_size == tokenizer.n_words
if torch.cuda.is_bf16_supported():
    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
The tokenizer is loaded as well. The number of words the tokenizer knows (presumably its vocab size?) must match the vocab size configured for the model. (This is obvious: the LLM predicts over the vocabulary defined by the tokenizer used in pretraining, so it makes no sense if the vocab size or the token meanings differ.)
torch.cuda.is_bf16_supported(): if BF16 is supported, the default tensor type is switched to BFloat16Tensor. In general, using BFloat16 for training appears to be preferable because of its larger exponent range (8 exponent bits, -126 to +127), while at inference time FP16 (half-precision float, 5 exponent bits, -14 to +15) is often sufficient, and it is what the code falls back to when BF16 is not supported.
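The range/precision trade-off is easy to inspect directly in PyTorch:

import torch

# Dynamic range: bfloat16 covers roughly the same magnitudes as float32,
# while float16 tops out at 65504.
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
print(torch.finfo(torch.float32).max)   # ~3.40e38
print(torch.finfo(torch.float16).max)   # 65504.0

# Precision: eps is the gap between 1.0 and the next representable value.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125    (2^-7, 7 explicit significand bits)
print(torch.finfo(torch.float16).eps)   # 0.0009765625 (2^-10, 10 explicit significand bits)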
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer)
def __init__(self, model: Transformer, tokenizer: Tokenizer):
    self.model = model
    self.tokenizer = tokenizer
    self.formatter = ChatFormat(tokenizer)
ChatFormat is covered in a separate note.
Explanation of @staticmethod
@staticmethod decorator: the @staticmethod decorator defines a method in a class that does not operate on an instance of the class, i.e., it needs access to neither the instance (self) nor the class (cls). It behaves just like a regular function, but it belongs to the class's namespace and can be called using the class name rather than on an instance.

class MathOperations:
    @staticmethod
    def add(a, b):
        return a + b

result = MathOperations.add(5, 3)  # Called directly on the class
Splitting the constructor from the factory method keeps the two kinds of setup separate:

- __init__: handles basic setup, like storing parameters that are always required.
- build: handles complex setup, like loading a pre-trained model from a checkpoint, configuring the tokenizer, and setting up any optional configurations.
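A stripped-down sketch of this constructor-vs-factory split (a hypothetical toy class, not the real Llama code):

import json

class TinyModel:
    def __init__(self, weights: dict, vocab: list):
        # __init__ only stores what is always required.
        self.weights = weights
        self.vocab = vocab

    @staticmethod
    def build(ckpt_path: str, vocab_path: str) -> "TinyModel":
        # build does the heavy setup: read files, validate,
        # then hand the finished pieces to __init__.
        with open(ckpt_path) as f:
            weights = json.load(f)
        with open(vocab_path) as f:
            vocab = f.read().split()
        assert len(vocab) > 0, "empty vocab file"
        return TinyModel(weights, vocab)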
BFloat16 (Brain Float 16-bit):
- Exponent bits: 8
- Exponent range: -126 to +127
- Effective range: approximately 1.2e-38 to 3.4e38 (normal values)

BFloat16 has the same exponent range as FP32 (standard 32-bit floating-point), allowing it to represent a wide range of values, but with less precision in the significand (mantissa).
FP16 (Half-Precision Float):
- Exponent bits: 5
- Exponent range: -14 to +15
- Effective range: approximately 6.1e-5 to 65504 (normal values)
FP16 has a narrower exponent range, meaning it can't represent as large or as small values as BFloat16 or FP32. However, it has more precision than BFloat16 within this smaller range.
In FP16: with 10 bits in the significand, FP16 can closely approximate the value 1.1 (it represents it as 1.099609375).
In BFloat16: with only 7 bits in the significand, BFloat16 has less precision and represents 1.1 as 1.1015625 (or 1.09375 if truncated), which is less accurate.
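These exact values can be checked in PyTorch:

import torch

print(torch.tensor(1.1, dtype=torch.float16).item())   # 1.099609375
print(torch.tensor(1.1, dtype=torch.bfloat16).item())  # 1.1015625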
The precision in the significand (mantissa) of a floating-point number is related to how many bits are used to represent the fractional part of the number. The exact decimal values that can be represented are determined by binary fractions like 1/2, 1/4, 1/8, 1/16, and so on.
The 10-bit significand in FP16 represents values as sums of such binary fractions.
Each bit in the significand represents a negative power of 2: the first bit is 2^-1, the second 2^-2, and so on down to 2^-10.
Let's say you have a value whose binary form is 1.0001100110:
- The leading 1. is implicit and represents the hidden leading 1.
- The 0001100110 part is the 10-bit significand.

In decimal, this binary number represents:

1 + 2^-4 + 2^-5 + 2^-8 + 2^-9 = 1 + 0.0625 + 0.03125 + 0.00390625 + 0.001953125

Which simplifies to:

1.099609375
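A quick sanity check of the expansion in plain Python:

# Bits set in 1.0001100110 sit at positions 4, 5, 8, and 9 after the binary point.
value = 1 + 2**-4 + 2**-5 + 2**-8 + 2**-9
print(value)  # 1.099609375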