import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("atharvamundada99/bert-large-question-answering-finetuned-legal", cache_dir="/E/HUG_Models")
model = AutoModelForQuestionAnswering.from_pretrained("atharvamundada99/bert-large-question-answering-finetuned-legal", cache_dir="/E/HUG_Models")

def get_answer(question, context):
    # Tokenize the question/context pair into model inputs.
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # The predicted span runs from the highest-scoring start token
    # to the highest-scoring end token.
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()

    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    answer = tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
    return answer

print(get_answer("What is your name", "My name is JACK"))
# output: JACK
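
# A possible extension (not part of the original snippet): BERT-large caps
# inputs at 512 tokens, so long legal documents would be silently truncated
# by get_answer above. The sketch below splits the context into overlapping
# windows using the tokenizer's `stride` / `return_overflowing_tokens`
# options and keeps the span with the best combined start+end logit.
# The window size (384) and stride (128) are illustrative assumptions,
# not values taken from the model card.
def get_answer_long(question, context, max_length=384, stride=128):
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",   # truncate the context, never the question
        max_length=max_length,
        stride=stride,              # overlap between consecutive windows
        return_overflowing_tokens=True,
    )
    inputs.pop("overflow_to_sample_mapping")  # mapping is not a model input
    with torch.no_grad():
        outputs = model(**inputs)
    # Score each window by its best start+end logit sum and keep the winner.
    best_score, best_answer = float("-inf"), ""
    for i in range(inputs.input_ids.shape[0]):
        start = outputs.start_logits[i].argmax()
        end = outputs.end_logits[i].argmax()
        score = (outputs.start_logits[i, start] + outputs.end_logits[i, end]).item()
        if start <= end and score > best_score:
            best_score = score
            best_answer = tokenizer.decode(
                inputs.input_ids[i, start : end + 1], skip_special_tokens=True
            )
    return best_answer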
