MedSSS-8B-PRM

Introduction

MedSSS-PRM is the process reward model (PRM) designed for slow-thinking medical reasoning. It assigns a float value in [0, 1] to every internal reasoning step produced by MedSSS-Policy.

For more information, visit our GitHub repository: https://github.com/pixas/MedSSS.

Usage

We build the PRM as a LoRA adapter, which reduces the memory needed to run it. Because the adapter is trained on top of Meta-Llama-3.1-8B-Instruct, you first need to prepare the base model on your platform. The helper below computes one PRM value per reasoning step of a policy generation:


import torch
from itertools import chain

from peft import PeftModel
from transformers import AutoModelForTokenClassification, AutoTokenizer


def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
    # `outputs` is the step-wise generation produced by MedSSS-Policy
    response = outputs
    # Split on the step delimiter; split() strips the leading "Step",
    # so restore it on each piece that lost it
    completions = [
        completion if completion.startswith("Step") else "Step" + completion
        for completion in outputs.split("\n\nStep")
    ]
    
    messages = [
        {"role": "user", "content": inputs},
        {"role": "assistant", "content": response}
    ]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Locate the assistant response inside the chat-templated text
    response_begin_index = input_text.index(response)

    pre_response_input = input_text[:response_begin_index]
    after_response_input = input_text[response_begin_index + len(response):]

    # Tokenize each step separately so the boundary of every step is known
    completion_ids = [
        tokenizer(completion + "\n\n", add_special_tokens=False)['input_ids']
        for completion in completions
    ]
    
    response_id = list(chain(*completion_ids))
    pre_response_id = tokenizer(pre_response_input, add_special_tokens=False)['input_ids']
    after_response_id = tokenizer(after_response_input, add_special_tokens=False)['input_ids']

    
    input_ids = pre_response_id + response_id + after_response_id

    with torch.no_grad():
        value = value_model(
            input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device)
        ).logits.squeeze(-1)  # [1, N]; assumes the value head emits a single label per token
    
    # Index of the last token of each step within the full sequence
    completion_index = []
    for i, completion in enumerate(completion_ids):
        if i == 0:
            completion_index.append(len(completion) + len(pre_response_id) - 1)
        else:
            completion_index.append(completion_index[-1] + len(completion))
    
    step_value = value[0, completion_index].cpu().numpy().tolist()
    return step_value

With the helper in place, load the base model, attach the LoRA adapter, and score a step-wise generation:

base_model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto", device_map="auto"
)
model = PeftModel.from_pretrained(
    base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")

input_text = "How to stop a cough?"
step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"

value = obtain_prm_value_for_single_pair(tokenizer, model, input_text, step_wise_generation)
print(value)
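
Optionally (a standard peft feature, not something this model card prescribes), you can merge the LoRA weights into the base model for slightly faster inference:

# Merge the adapter weights into the base model; returns a plain transformers model
model = model.merge_and_unload()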

MedSSS-PRM uses "\n\nStep" to separate intermediate steps, so the value for each step is read from the token right before the next "Step k: " marker (or the end of the sequence for the final step).
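
If you need a single score for an entire trajectory, e.g. to rank candidate generations, one common recipe is to aggregate the per-step values. The sketch below is our illustration with a hypothetical helper name, not an API from the MedSSS repository:

# Hypothetical aggregation of per-step PRM values (illustration only)
def aggregate_step_values(step_values, how="min"):
    if how == "min":   # a trajectory is only as strong as its weakest step
        return min(step_values)
    if how == "mean":  # average step quality
        return sum(step_values) / len(step_values)
    raise ValueError(f"unknown aggregation: {how}")

score = aggregate_step_values(value, how="min")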
