omkar-surve126 commited on
Commit
ff49c2d
·
verified ·
1 Parent(s): 4c14295

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
2
+ import streamlit as st
3
+ import os
4
+ from PIL import Image
5
+ import torch
6
+ import re
7
+
8
+ @st.cache_resource
9
+ def init_model():
10
+ tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
11
+ model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
12
+ model = model.eval()
13
+ return model, tokenizer
14
+
15
+ # def init_gpu_model():
16
+ # tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
17
+ # model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
18
+ # model = model.eval().cuda()
19
+ # return model, tokenizer
20
+
21
+ def init_qwen_model():
22
+ model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
23
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
24
+ return model, processor
25
+
26
+ def get_quen_op(image_file, model, processor):
27
+ try:
28
+ image = Image.open(image_file).convert('RGB')
29
+ conversation = [
30
+ {
31
+ "role":"user",
32
+ "content":[
33
+ {
34
+ "type":"image",
35
+ },
36
+ {
37
+ "type":"text",
38
+ "text":"Extract text from this image."
39
+ }
40
+ ]
41
+ }
42
+ ]
43
+ text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
44
+ inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
45
+ inputs = {k: v.to(torch.float32) if torch.is_floating_point(v) else v for k, v in inputs.items()}
46
+
47
+ generation_config = {
48
+ "max_new_tokens": 32,
49
+ "do_sample": False,
50
+ "top_k": 20,
51
+ "top_p": 0.90,
52
+ "temperature": 0.4,
53
+ "num_return_sequences": 1,
54
+ "pad_token_id": processor.tokenizer.pad_token_id,
55
+ "eos_token_id": processor.tokenizer.eos_token_id,
56
+ }
57
+
58
+ output_ids = model.generate(**inputs, **generation_config)
59
+ if 'input_ids' in inputs:
60
+ generated_ids = output_ids[:, inputs['input_ids'].shape[1]:]
61
+ else:
62
+ generated_ids = output_ids
63
+
64
+ output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
65
+
66
+ return output_text[:] if output_text else "No text extracted from the image."
67
+
68
+ except Exception as e:
69
+ return f"An error occurred: {str(e)}"
70
+
71
+ @st.cache_data
72
+ def get_text(image_file, _model, _tokenizer):
73
+ res = _model.chat(_tokenizer, image_file, ocr_type='ocr')
74
+ return res
75
+
76
+ def highlight_text(text, search_term):
77
+ if not search_term:
78
+ return text
79
+ pattern = re.compile(re.escape(search_term), re.IGNORECASE)
80
+ return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
81
+
82
+ st.title("GOT-OCR2.0")
83
+ st.write("Upload an image")
84
+
85
+ MODEL, PROCESSOR = init_model()
86
+
87
+ image_file = st.file_uploader("Upload Image", type=['jpg', 'png', 'jpeg'])
88
+
89
+ if image_file:
90
+ if not os.path.exists("images"):
91
+ os.makedirs("images")
92
+ with open(f"images/{image_file.name}", "wb") as f:
93
+ f.write(image_file.getbuffer())
94
+
95
+ image_file = f"images/{image_file.name}"
96
+
97
+ text = get_text(image_file, MODEL, PROCESSOR)
98
+
99
+ print(text)
100
+
101
+ # Add search functionality
102
+ search_term = st.text_input("Enter a word or phrase to search:")
103
+ highlighted_text = highlight_text(text, search_term)
104
+
105
+ st.markdown(highlighted_text, unsafe_allow_html=True)