Optimizing Inference Time for Chat Conversations on Falcon

#104
by humza-sami - opened

How can I optimize the inference time of my 4-bit quantized Falcon 7B, which I fine-tuned on a chat dataset using QLoRA + PEFT? For inference, I load the model in 4-bit with the bitsandbytes library. The model's answers are good, but I've noticed that inference time increases significantly as the chat grows longer. Here is an example of how I'm using the model:

First message prompt:
< user>: Hi ..
< bot>:

Second message prompt:
< user>: Hi ..
< bot>: How are you ?
< user>: I am good thanks. I need your help !
< bot>:
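Each new turn is appended to the prompt, so the number of input tokens, and the attention cost of every generated token, keeps growing with the conversation. One common mitigation (not something described in this thread) is to keep only the most recent turns. A minimal sketch, assuming the `< user>:` / `< bot>:` format shown above; `build_prompt` and `max_turns` are hypothetical names:

```python
from typing import List, Tuple

# Markers copied from the example prompts above (assumed to match the training format)
USER_TAG = "< user>:"
BOT_TAG = "< bot>:"


def build_prompt(history: List[Tuple[str, str]], user_message: str, max_turns: int = 4) -> str:
    """Hypothetical helper: build a chat prompt from past (user, bot) turns.

    Only the last `max_turns` completed exchanges are kept, so the prompt stops
    growing once the conversation is longer than that.
    """
    # history[-0:] would return the whole list, so handle max_turns <= 0 explicitly
    recent = history[-max_turns:] if max_turns > 0 else []
    lines = []
    for user_text, bot_text in recent:
        lines.append(f"{USER_TAG} {user_text}")
        lines.append(f"{BOT_TAG} {bot_text}")
    # New user message, then an open bot tag for the model to complete
    lines.append(f"{USER_TAG} {user_message}")
    lines.append(BOT_TAG)
    return "\n".join(lines)


history = [("Hi ..", "How are you ?")]
print(build_prompt(history, "I am good thanks. I need your help !"))
```

The trade-off is that the model no longer sees turns older than the window, so `max_turns` has to be chosen against how much context the conversations actually need.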

As the length of the chat increases, the inference time sometimes doubles and can take up to 2-3 minutes per prompt. I'm using an NVIDIA RTX 3090 Ti for inference. Below is the code snippet I'm using for prediction:

```python
import torch
import transformers
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "tiiuae/falcon-7b"
SAVED_MODEL = "Saved_models/model_path"

# 4-bit NF4 quantization with double quantization; computation runs in bfloat16
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True,
                                bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)

# Falcon has no pad token, so reuse the EOS token for padding
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Load the quantized base model named in the saved adapter's config
saved_model_config = PeftConfig.from_pretrained(SAVED_MODEL)
saved_model = AutoModelForCausalLM.from_pretrained(saved_model_config.base_model_name_or_path,
                                                   return_dict=True,
                                                   quantization_config=bnb_config,
                                                   device_map="auto",
                                                   trust_remote_code=True)

# Attach the fine-tuned QLoRA adapter on top of the base model
saved_model = PeftModel.from_pretrained(saved_model, SAVED_MODEL)

pipeline = transformers.pipeline(
    "text-generation",
    model=saved_model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)

# A multi-line string literal needs triple quotes
prompt = """< user>: Hi ..
< bot>: How are you ?
< user>: I am good thanks. I need your help !
< bot>:"""

response = pipeline(
    prompt,
    bos_token_id=11,
    max_length=2000,  # total length: prompt tokens + generated tokens
    temperature=0.7,
    top_p=0.7,
    do_sample=True,
    num_return_sequences=1,
    eos_token_id=[15564],
)[0]["generated_text"]
```
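One detail in the snippet worth checking: in transformers, `max_length` caps the prompt plus the generated tokens, while `max_new_tokens` caps only the generated tokens and overrides `max_length` when both are set. With `max_length=2000`, if the model never emits token 15564, even a short prompt can sample close to 2000 tokens, one forward pass each. The function below is a hypothetical, pure-Python illustration of that documented budget rule; it does not call transformers, and the prompt lengths in the loop are made-up examples, not measurements:

```python
from typing import Optional


def new_token_budget(prompt_tokens: int, max_length: Optional[int] = None,
                     max_new_tokens: Optional[int] = None) -> int:
    """Hypothetical helper: worst-case number of sampled tokens for one prompt.

    Mirrors the documented semantics for decoder-only models:
    max_new_tokens ignores the prompt and takes precedence over max_length;
    max_length counts prompt tokens + new tokens.
    """
    if max_new_tokens is not None:
        return max_new_tokens
    if max_length is not None:
        return max(0, max_length - prompt_tokens)
    raise ValueError("set max_length or max_new_tokens")


# Illustrative prompt lengths only
for n in (20, 300, 1200):
    print(n, new_token_budget(n, max_length=2000), new_token_budget(n, max_new_tokens=256))
```

With `max_new_tokens` the worst case is fixed no matter how long the chat gets, which makes latency easier to reason about than a total-length cap.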

I would greatly appreciate any insights or suggestions on how to improve the model's inference time for longer chat interactions. Thank you in advance for your help!

In my experience with Falcon-40B, it gets really slow, especially compared with competitors such as Llama 7B.

However, there are a few steps you can take to speed it up for production use:

  1. Use SageMaker JumpStart and fine-tune the model again with its built-in fine-tuning feature. Its inference setup is highly optimized and responds within seconds.
  2. Use Text Generation Inference (https://huggingface.co./text-generation-inference) and deploy it yourself.
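Option 2 can be sketched as a small client. Text Generation Inference serves an HTTP `POST /generate` route that takes an `inputs` string plus a `parameters` object (`max_new_tokens`, `temperature`, `top_p`, `do_sample`, `stop`, among others) and returns JSON with a `generated_text` field. The sketch uses only the Python standard library and assumes a TGI server is already serving the fine-tuned model at `http://localhost:8080`; that URL, the `< user>:` stop string, and the helper names are assumptions, not part of this thread:

```python
import json
import urllib.request

# Assumption: a TGI server is already running locally on port 8080
TGI_URL = "http://localhost:8080/generate"


def build_tgi_payload(prompt: str, max_new_tokens: int = 256) -> dict:
    """Hypothetical helper: JSON body for TGI's /generate route.

    Reuses the sampling settings from the question's pipeline call.
    """
    return {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": 0.7,
            "top_p": 0.7,
            "do_sample": True,
            # Assumption: stop once the model starts writing the next user turn
            "stop": ["< user>:"],
        },
    }


def generate(prompt: str, url: str = TGI_URL, timeout: float = 120.0) -> str:
    """POST the prompt to the TGI server and return its generated_text field."""
    body = json.dumps(build_tgi_payload(prompt)).encode("utf-8")
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))["generated_text"]
```

With a server running, `generate("< user>: Hi ..\n< bot>:")` returns the model's continuation. How the server is launched and quantized is outside this sketch.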

@chelouche9 Thanks! I can't use SageMaker for now, but I will try the second option and let you know.
