# DeciLM-7B / benchmark_hf_model.py
# Uploaded by tomer-deci, commit b943e32 ("Upload benchmark_hf_model.py")
import json
from argparse import ArgumentParser
import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, BatchEncoding
# Usage examples for this script. Because this string literal comes after the
# imports, Python does not treat it as the module docstring (``__doc__``); it
# is documentation for readers of the source only.
"""
Usage examples (with the best batch sizes on A100-80GB-400W)
============================================================
python -m benchmark_hf_model --model_name_or_path="Deci/DeciLM-7B" --batch_size=352
python -m benchmark_hf_model --model_name_or_path="mistralai/Mistral-7B-v0.1" --batch_size=192 --model_kwargs_json='{"use_flash_attention_2": true}'
python -m benchmark_hf_model --model_name_or_path="meta-llama/Llama-2-7b-hf" --batch_size=48 --model_kwargs_json='{"use_flash_attention_2": true}'
"""
def parse_args():
    """Build the benchmark's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_name_or_path`` (str,
        required), ``warmup_iters``, ``iterations``, ``batch_size``,
        ``prompt_length``, ``max_new_tokens`` (all int), ``precision`` (str,
        default "bf16") and ``model_kwargs_json`` (str or None).
    """
    cli = ArgumentParser()
    cli.add_argument("--model_name_or_path", type=str, required=True)

    # Integer knobs that size the benchmark; registered in this exact order so
    # the ``--help`` listing is unchanged.
    integer_options = (
        ("--warmup_iters", 10),
        ("--iterations", 5),
        ("--batch_size", 32),
        ("--prompt_length", 512),
        ("--max_new_tokens", 512),
    )
    for flag, default_value in integer_options:
        cli.add_argument(flag, type=int, default=default_value)

    cli.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="Model precision, from: fp32, fp16 or bf16",
    )
    # Extra keyword arguments for ``from_pretrained``, passed as a JSON string.
    cli.add_argument("--model_kwargs_json", type=str, default=None)
    return cli.parse_args()
def main():
    """Benchmark greedy ``generate`` throughput of a Hugging Face causal LM on CUDA.

    Reads all settings from the command line (see ``parse_args``), loads the
    model with ``trust_remote_code=True`` in the requested dtype, moves it to
    CUDA, and then runs:

    * ``--warmup_iters`` untimed warmup calls that each generate one new token;
    * ``--iterations`` timed calls that each request ``--max_new_tokens`` new
      tokens for a synthetic batch of ``--batch_size`` prompts, every prompt
      being ``--prompt_length`` copies of token id 1.

    Per-iteration and average wall time (measured with CUDA events) and
    generated tokens/sec are printed to stdout.

    Raises:
        ValueError: if ``--precision`` is not one of fp32/fp16/bf16, if
            ``--iterations`` is smaller than 1, or if ``--model_kwargs_json``
            does not decode to a JSON object.
    """
    args = parse_args()
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16"
        )
    # Fail fast, before the slow model load: with zero timed iterations the
    # average at the end would otherwise raise ZeroDivisionError only after
    # all of the expensive work had been done.
    if args.iterations < 1:
        raise ValueError(f"--iterations must be at least 1, got {args.iterations}")
    dtype = dict_precisions[args.precision]

    model_kwargs = {}
    if args.model_kwargs_json is not None:
        model_kwargs = json.loads(args.model_kwargs_json)
        # The decoded value is splatted as keyword arguments, so it must be a dict.
        if not isinstance(model_kwargs, dict):
            raise ValueError("--model_kwargs_json must decode to a JSON object")

    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        **model_kwargs,
    )
    # Best-effort diagnostic of the first layer's attention module. Not every
    # architecture exposes ``model.model.layers[i].self_attn``, so only the
    # errors such a lookup can raise are reported; anything else (including
    # KeyboardInterrupt) is no longer silently swallowed.
    try:
        print(model.model.layers[0].self_attn)
    except (AttributeError, LookupError, TypeError):
        print("couldn't print the model's attention module")

    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    model.cuda()
    model.eval()

    # Synthetic batch: batch_size identical rows of prompt_length ones.
    prompt = torch.ones(args.prompt_length, dtype=torch.long)
    inputs = BatchEncoding({"input_ids": prompt.repeat(args.batch_size, 1)})
    inputs = inputs.to(model.device)

    # " x <batch_size> batch" suffix used in the header line (empty for batch 1).
    batch_note = f" x {args.batch_size} batch" if args.batch_size > 1 else ""

    # Both loops below use greedy decoding and eos_token_id=-1234, which is
    # presumably an id no tokenizer emits, so generation is not cut short by
    # EOS; the tokens/sec figure assumes every timed run yields max_new_tokens
    # new tokens per row.
    with torch.no_grad():
        # Untimed warmup: each call generates a single new token, so it mainly
        # exercises the prefill path.
        print(f"warming up for {args.warmup_iters} iterations...")
        for _ in range(args.warmup_iters):
            model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=-1234,
            )
        print('finished warmup')
        torch.cuda.synchronize()

        print(
            f"prefill ({args.prompt_length} tokens{batch_note}) + "
            f"generation ({args.max_new_tokens} tokens{batch_note}):"
        )
        tokens_generated = args.max_new_tokens * args.batch_size
        prefill_and_generation = []
        for gen_iter in range(args.iterations):
            starter.record()
            model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                eos_token_id=-1234,
            )
            ender.record()
            # CUDA events are recorded asynchronously; wait for the GPU to
            # finish before reading the elapsed time between them.
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender) / 1000  # elapsed_time is in ms
            prefill_and_generation.append(t)
            print(f" iter {gen_iter + 1}: {t:.03f} sec total, {tokens_generated / t:.02f} generated tokens/sec")

    # Safe: args.iterations >= 1 was validated above, so the list is non-empty.
    aver = sum(prefill_and_generation) / len(prefill_and_generation)
    print(f" average: {aver:.03f} sec total, {tokens_generated / aver:.02f} generated tokens/sec")
    print(f"These results are obtained for model '{args.model_name_or_path}' with {args.batch_size=}.")
# Run the benchmark only when executed as a script
# (e.g. ``python -m benchmark_hf_model ...``), not when this module is imported.
if __name__ == "__main__":
    main()