MedSSS-8B-PRM
Introduction
MedSSS-PRM is the process reward model (PRM) designed for slow-thinking medical reasoning. It assigns a float value in [0, 1] to every internal reasoning step generated by MedSSS-Policy.
For more information, visit our GitHub repository: https://github.com/pixas/MedSSS.
Usage
We release the PRM as a LoRA adapter, which reduces the memory needed to use it.
Because this LoRA adapter is built on Meta-Llama-3.1-8B-Instruct, you first need to prepare the base model on your platform.
from itertools import chain

import torch

def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
    # `outputs` is the step-wise response generated by MedSSS-Policy
    response = outputs
    # Split the response on "\n\nStep" and restore the "Step" prefix removed by the split
    completions = [
        completion if completion.startswith("Step") else "Step" + completion
        for completion in outputs.split("\n\nStep")
    ]
    messages = [
        {"role": "user", "content": inputs},
        {"role": "assistant", "content": response},
    ]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    # Locate the response inside the chat-formatted text so each step can be
    # tokenized separately while the surrounding template stays intact
    response_begin_index = input_text.index(response)
    pre_response_input = input_text[:response_begin_index]
    after_response_input = input_text[response_begin_index + len(response):]
    completion_ids = [
        tokenizer(completion + "\n\n", add_special_tokens=False)["input_ids"]
        for completion in completions
    ]
    response_id = list(chain(*completion_ids))
    pre_response_id = tokenizer(pre_response_input, add_special_tokens=False)["input_ids"]
    after_response_id = tokenizer(after_response_input, add_special_tokens=False)["input_ids"]
    input_ids = pre_response_id + response_id + after_response_id
    # Per-token values from the token-classification head, assuming a single-value head
    # so that the squeezed logits have shape [1, N]
    value = value_model(
        input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device)
    ).logits.squeeze(-1)
    # The value of each step is read at the position of its last token
    completion_index = []
    for i, completion_id in enumerate(completion_ids):
        if i == 0:
            completion_index.append(len(completion_id) + len(pre_response_id) - 1)
        else:
            completion_index.append(completion_index[-1] + len(completion_id))
    step_value = value[0, completion_index].cpu().numpy().tolist()
    return step_value
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto", device_map="auto"
)
model = PeftModel.from_pretrained(
    base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")
input_text = "How to stop a cough?"
step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"
value = obtain_prm_value_for_single_pair(tokenizer, model, input_text, step_wise_generation)
print(value)
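In search-based inference, the per-step values can be aggregated into a single score to compare candidate solutions. The snippet below is a minimal sketch, not the official MedSSS selection strategy: the candidate strings and the minimum aggregation are assumptions for illustration.

# Sketch: rank candidate generations by an aggregated PRM score.
# The candidate texts and the `min` aggregation are illustrative assumptions.
candidates = [
    "Step 0: Identify the likely cause of the cough.\n\nStep 1: [omitted]",
    "Step 0: Recommend an antibiotic immediately.\n\nStep 1: [omitted]",
]

def score_candidate(candidate):
    step_values = obtain_prm_value_for_single_pair(tokenizer, model, input_text, candidate)
    # Treat a solution as only as good as its weakest step
    return min(step_values)

best = max(candidates, key=score_candidate)
print(best)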
MedSSS-PRM uses "\n\nStep" to separate intermediate steps, so each step's value is read from the last token before the next "Step k:" marker (or the end of the sequence).
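For example, splitting a two-step generation on "\n\nStep" yields one chunk per step; the "Step" prefix is restored before tokenization, exactly as in obtain_prm_value_for_single_pair above.

text = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"
chunks = text.split("\n\nStep")
steps = [c if c.startswith("Step") else "Step" + c for c in chunks]
print(steps)
# ["Step 0: Let's break down this problem step by step.", 'Step 1: First [omitted]']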