Spaces:
Running
Running
geekyrakshit
commited on
Commit
•
170d9a9
1
Parent(s):
e6f968c
update: app
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +89 -72
- docs/app.md +0 -61
- docs/assistant/figure_annotation.md +0 -3
- docs/assistant/llm_client.md +0 -3
- docs/assistant/medqa_assistant.md +0 -3
- docs/chunking.md +0 -3
- docs/document_loader/image_loader/base_img_loader.md +0 -3
- docs/document_loader/image_loader/fitzpil_img_loader.md +0 -22
- docs/document_loader/image_loader/marker_img_loader.md +0 -21
- docs/document_loader/image_loader/pdf2image_img_loader.md +0 -26
- docs/document_loader/image_loader/pdfplumber_img_loader.md +0 -22
- docs/document_loader/image_loader/pymupdf_img_loader.md +0 -23
- docs/document_loader/text_loader/base_text_loader.md +0 -3
- docs/document_loader/text_loader/marker_text_loader.md +0 -23
- docs/document_loader/text_loader/pdfplumber_text_loader.md +0 -22
- docs/document_loader/text_loader/pymupdf4llm_text_loader.md +0 -23
- docs/document_loader/text_loader/pypdf2_text_loader.md +0 -23
- docs/index.md +0 -40
- docs/installation/development.md +0 -40
- docs/installation/install.md +0 -9
- docs/retreival/bm25s.md +0 -3
- docs/retreival/colpali.md +0 -3
- docs/retreival/contriever.md +0 -3
- docs/retreival/medcpt.md +0 -3
- docs/retreival/nv_embed_2.md +0 -3
- install.sh +0 -30
- medrag_multi_modal/assistant/figure_annotation.py +4 -13
- medrag_multi_modal/assistant/llm_client.py +19 -11
- medrag_multi_modal/assistant/medqa_assistant.py +94 -28
- medrag_multi_modal/assistant/schema.py +27 -0
- medrag_multi_modal/cli.py +54 -3
- medrag_multi_modal/document_loader/image_loader/base_img_loader.py +80 -29
- medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py +16 -16
- medrag_multi_modal/document_loader/image_loader/marker_img_loader.py +15 -26
- medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py +7 -16
- medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py +16 -16
- medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py +16 -16
- medrag_multi_modal/document_loader/text_loader/base_text_loader.py +58 -20
- medrag_multi_modal/document_loader/text_loader/marker_text_loader.py +8 -15
- medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py +7 -13
- medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py +7 -15
- medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py +7 -13
- medrag_multi_modal/metrics/__init__.py +3 -0
- medrag_multi_modal/metrics/base.py +108 -0
- medrag_multi_modal/metrics/mmlu.py +24 -0
- medrag_multi_modal/retrieval/__init__.py +1 -13
- medrag_multi_modal/retrieval/colpali_retrieval.py +1 -1
- medrag_multi_modal/retrieval/common.py +0 -23
- medrag_multi_modal/retrieval/text_retrieval/__init__.py +11 -0
- medrag_multi_modal/retrieval/{bm25s_retrieval.py → text_retrieval/bm25s_retrieval.py} +87 -61
app.py
CHANGED
@@ -1,26 +1,20 @@
|
|
1 |
-
import os
|
2 |
-
import wandb
|
3 |
-
|
4 |
-
wandb.login(relogin=True, key=os.getenv("WANDB_API_KEY"))
|
5 |
-
|
6 |
-
|
7 |
import streamlit as st
|
8 |
-
import weave
|
9 |
|
10 |
-
from medrag_multi_modal.assistant import
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
GOOGLE_MODELS,
|
17 |
-
MISTRAL_MODELS,
|
18 |
-
OPENAI_MODELS,
|
19 |
)
|
20 |
-
from medrag_multi_modal.retrieval import MedCPTRetriever
|
21 |
|
22 |
# Define constants
|
23 |
-
ALL_AVAILABLE_MODELS =
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
# Sidebar for configuration settings
|
26 |
st.sidebar.title("Configuration Settings")
|
@@ -30,68 +24,91 @@ project_name = st.sidebar.text_input(
|
|
30 |
placeholder="wandb project name",
|
31 |
help="format: wandb_username/wandb_project_name",
|
32 |
)
|
33 |
-
|
34 |
-
label="
|
35 |
-
|
36 |
-
placeholder="wandb dataset name",
|
37 |
-
help="format: wandb_dataset_name:version",
|
38 |
)
|
39 |
-
|
40 |
-
label="
|
41 |
-
|
42 |
-
placeholder="wandb artifact address",
|
43 |
-
help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
|
44 |
)
|
45 |
-
|
46 |
-
label="
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
)
|
51 |
-
|
52 |
-
label="
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
)
|
57 |
-
|
58 |
-
label="
|
59 |
-
|
60 |
-
index=ALL_AVAILABLE_MODELS.index("pixtral-12b-2409"),
|
61 |
-
help="select a model from the list",
|
62 |
)
|
63 |
-
|
64 |
-
label="
|
65 |
-
options=
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
68 |
)
|
69 |
|
70 |
-
|
71 |
-
st.title("MedQA Assistant App")
|
72 |
|
73 |
-
|
74 |
-
weave.init(project_name=project_name)
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
)
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
)
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
query = st.chat_input("Enter your question here")
|
92 |
-
if query:
|
93 |
-
with st.chat_message("user"):
|
94 |
-
st.markdown(query)
|
95 |
-
response = medqa_assistant.predict(query=query)
|
96 |
with st.chat_message("assistant"):
|
97 |
-
st.markdown(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
|
|
2 |
|
3 |
+
from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
|
4 |
+
from medrag_multi_modal.retrieval.text_retrieval import (
|
5 |
+
BM25sRetriever,
|
6 |
+
ContrieverRetriever,
|
7 |
+
MedCPTRetriever,
|
8 |
+
NVEmbed2Retriever,
|
|
|
|
|
|
|
9 |
)
|
|
|
10 |
|
11 |
# Define constants
|
12 |
+
ALL_AVAILABLE_MODELS = [
|
13 |
+
"gemini-1.5-flash-latest",
|
14 |
+
"gemini-1.5-pro-latest",
|
15 |
+
"gpt-4o",
|
16 |
+
"gpt-4o-mini",
|
17 |
+
]
|
18 |
|
19 |
# Sidebar for configuration settings
|
20 |
st.sidebar.title("Configuration Settings")
|
|
|
24 |
placeholder="wandb project name",
|
25 |
help="format: wandb_username/wandb_project_name",
|
26 |
)
|
27 |
+
chunk_dataset_id = st.sidebar.selectbox(
|
28 |
+
label="Chunk Dataset ID",
|
29 |
+
options=["ashwiniai/medrag-text-corpus-chunks"],
|
|
|
|
|
30 |
)
|
31 |
+
llm_model = st.sidebar.selectbox(
|
32 |
+
label="LLM Model",
|
33 |
+
options=ALL_AVAILABLE_MODELS,
|
|
|
|
|
34 |
)
|
35 |
+
top_k_chunks_for_query = st.sidebar.slider(
|
36 |
+
label="Top K Chunks for Query",
|
37 |
+
min_value=1,
|
38 |
+
max_value=20,
|
39 |
+
value=5,
|
40 |
)
|
41 |
+
top_k_chunks_for_options = st.sidebar.slider(
|
42 |
+
label="Top K Chunks for Options",
|
43 |
+
min_value=1,
|
44 |
+
max_value=20,
|
45 |
+
value=3,
|
46 |
)
|
47 |
+
rely_only_on_context = st.sidebar.checkbox(
|
48 |
+
label="Rely Only on Context",
|
49 |
+
value=False,
|
|
|
|
|
50 |
)
|
51 |
+
retriever_type = st.sidebar.selectbox(
|
52 |
+
label="Retriever Type",
|
53 |
+
options=[
|
54 |
+
"",
|
55 |
+
"BM25S",
|
56 |
+
"Contriever",
|
57 |
+
"MedCPT",
|
58 |
+
"NV-Embed-v2",
|
59 |
+
],
|
60 |
)
|
61 |
|
62 |
+
if retriever_type != "":
|
|
|
63 |
|
64 |
+
llm_model = LLMClient(model_name=llm_model)
|
|
|
65 |
|
66 |
+
retriever = None
|
67 |
+
|
68 |
+
if retriever_type == "BM25S":
|
69 |
+
retriever = BM25sRetriever.from_index(
|
70 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
|
71 |
+
)
|
72 |
+
elif retriever_type == "Contriever":
|
73 |
+
retriever = ContrieverRetriever.from_index(
|
74 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
|
75 |
+
chunk_dataset_id=chunk_dataset_id,
|
76 |
+
)
|
77 |
+
elif retriever_type == "MedCPT":
|
78 |
+
retriever = MedCPTRetriever.from_index(
|
79 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
|
80 |
+
chunk_dataset_id=chunk_dataset_id,
|
81 |
+
)
|
82 |
+
elif retriever_type == "NV-Embed-v2":
|
83 |
+
retriever = NVEmbed2Retriever.from_index(
|
84 |
+
index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
|
85 |
+
chunk_dataset_id=chunk_dataset_id,
|
86 |
+
)
|
87 |
+
|
88 |
+
medqa_assistant = MedQAAssistant(
|
89 |
+
llm_client=llm_model,
|
90 |
+
retriever=retriever,
|
91 |
+
top_k_chunks_for_query=top_k_chunks_for_query,
|
92 |
+
top_k_chunks_for_options=top_k_chunks_for_options,
|
93 |
+
)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
95 |
with st.chat_message("assistant"):
|
96 |
+
st.markdown(
|
97 |
+
"""
|
98 |
+
Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
|
99 |
+
I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.
|
100 |
+
|
101 |
+
**Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
|
102 |
+
Please consult a medical professional for any medical advice.
|
103 |
+
|
104 |
+
In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
|
105 |
+
""",
|
106 |
+
unsafe_allow_html=True,
|
107 |
+
)
|
108 |
+
query = st.chat_input("Enter your question here")
|
109 |
+
if query:
|
110 |
+
with st.chat_message("user"):
|
111 |
+
st.markdown(query)
|
112 |
+
response = medqa_assistant.predict(query=query)
|
113 |
+
with st.chat_message("assistant"):
|
114 |
+
st.markdown(response.response)
|
docs/app.md
DELETED
@@ -1,61 +0,0 @@
|
|
1 |
-
# MedQA Assistant App
|
2 |
-
|
3 |
-
The MedQA Assistant App is a Streamlit-based application designed to provide a chat interface for medical question answering. It leverages advanced language models (LLMs) and retrieval augmented generation (RAG) techniques to deliver accurate and informative responses to medical queries.
|
4 |
-
|
5 |
-
## Features
|
6 |
-
|
7 |
-
- **Interactive Chat Interface**: Engage with the app through a user-friendly chat interface.
|
8 |
-
- **Configurable Settings**: Customize model selection and data sources via the sidebar.
|
9 |
-
- **Retrieval-Augmented Generation**: Ensures precise and contextually relevant responses.
|
10 |
-
- **Figure Annotation Capabilities**: Extracts and annotates figures from medical texts.
|
11 |
-
|
12 |
-
## Usage
|
13 |
-
|
14 |
-
1. Install the package using:
|
15 |
-
```bash
|
16 |
-
uv pip install .
|
17 |
-
```
|
18 |
-
1. **Launch the App**: Start the application using Streamlit:
|
19 |
-
```bash
|
20 |
-
medrag run
|
21 |
-
```
|
22 |
-
2. **Configure Settings**: Adjust configuration settings in the sidebar to suit your needs.
|
23 |
-
3. **Ask a Question**: Enter your medical question in the chat input field.
|
24 |
-
4. **Receive a Response**: Get a detailed answer from the MedQA Assistant.
|
25 |
-
|
26 |
-
## Configuration
|
27 |
-
|
28 |
-
The app allows users to customize various settings through the sidebar:
|
29 |
-
|
30 |
-
- **Project Name**: Specify the WandB project name.
|
31 |
-
- **Text Chunk WandB Dataset Name**: Define the dataset containing text chunks.
|
32 |
-
- **WandB Index Artifact Address**: Provide the address of the index artifact.
|
33 |
-
- **WandB Image Artifact Address**: Provide the address of the image artifact.
|
34 |
-
- **LLM Client Model Name**: Choose a language model for generating responses.
|
35 |
-
- **Figure Extraction Model Name**: Select a model for extracting figures from images.
|
36 |
-
- **Structured Output Model Name**: Choose a model for generating structured outputs.
|
37 |
-
|
38 |
-
## Technical Details
|
39 |
-
|
40 |
-
The app is built using the following components:
|
41 |
-
|
42 |
-
- **Streamlit**: For the user interface.
|
43 |
-
- **Weave**: For project initialization and artifact management.
|
44 |
-
- **MedQAAssistant**: For processing queries and generating responses.
|
45 |
-
- **LLMClient**: For interacting with language models.
|
46 |
-
- **MedCPTRetriever**: For retrieving relevant text chunks.
|
47 |
-
- **FigureAnnotatorFromPageImage**: For annotating figures in medical texts.
|
48 |
-
|
49 |
-
## Development and Deployment
|
50 |
-
|
51 |
-
- **Environment Setup**: Ensure all dependencies are installed as per the `pyproject.toml`.
|
52 |
-
- **Running the App**: Use Streamlit to run the app locally.
|
53 |
-
- **Deployment**: coming soon...
|
54 |
-
|
55 |
-
## Additional Resources
|
56 |
-
|
57 |
-
For more detailed information on the components and their usage, refer to the following documentation sections:
|
58 |
-
|
59 |
-
- [MedQA Assistant](/assistant/medqa_assistant)
|
60 |
-
- [LLM Client](/assistant/llm_client)
|
61 |
-
- [Figure Annotation](/assistant/figure_annotation)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/assistant/figure_annotation.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Figure Annotation
|
2 |
-
|
3 |
-
::: medrag_multi_modal.assistant.figure_annotation
|
|
|
|
|
|
|
|
docs/assistant/llm_client.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# LLM Client
|
2 |
-
|
3 |
-
::: medrag_multi_modal.assistant.llm_client
|
|
|
|
|
|
|
|
docs/assistant/medqa_assistant.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# MedQA Assistant
|
2 |
-
|
3 |
-
::: medrag_multi_modal.assistant.medqa_assistant
|
|
|
|
|
|
|
|
docs/chunking.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Chunking
|
2 |
-
|
3 |
-
::: medrag_multi_modal.semantic_chunking
|
|
|
|
|
|
|
|
docs/document_loader/image_loader/base_img_loader.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
## Load images from PDF files
|
2 |
-
|
3 |
-
::: medrag_multi_modal.document_loader.image_loader.base_img_loader
|
|
|
|
|
|
|
|
docs/document_loader/image_loader/fitzpil_img_loader.md
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
# Load images from PDF files (using Fitz & PIL)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `fitz` & `pillow`
|
5 |
-
|
6 |
-
Extract images from PDF files using `fitz` and `pillow`.
|
7 |
-
|
8 |
-
Use it in our library with:
|
9 |
-
```python
|
10 |
-
from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
|
11 |
-
```
|
12 |
-
|
13 |
-
For more details, please refer to the sources below.
|
14 |
-
|
15 |
-
**Sources:**
|
16 |
-
|
17 |
-
- [Docs](https://pymupdf.readthedocs.io/en/latest/intro.html)
|
18 |
-
- [GitHub](https://github.com/kastman/fitz)
|
19 |
-
- [PyPI](https://pypi.org/project/fitz/)
|
20 |
-
- [PyPI](https://pypi.org/project/pillow/)
|
21 |
-
|
22 |
-
::: medrag_multi_modal.document_loader.image_loader.fitzpil_img_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/image_loader/marker_img_loader.md
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
# Load images from PDF files (using Marker)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `marker-pdf`
|
5 |
-
|
6 |
-
Extract images from PDF files using `marker-pdf`.
|
7 |
-
|
8 |
-
Use it in our library with:
|
9 |
-
```python
|
10 |
-
from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
|
11 |
-
```
|
12 |
-
|
13 |
-
For details, please refer to the sources below.
|
14 |
-
|
15 |
-
**Sources:**
|
16 |
-
|
17 |
-
- [DataLab](https://www.datalab.to)
|
18 |
-
- [GitHub](https://github.com/VikParuchuri/marker)
|
19 |
-
- [PyPI](https://pypi.org/project/marker-pdf/)
|
20 |
-
|
21 |
-
::: medrag_multi_modal.document_loader.image_loader.marker_img_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/image_loader/pdf2image_img_loader.md
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
# Load images from PDF files (using PDF2Image)
|
2 |
-
|
3 |
-
!!! danger "Warning"
|
4 |
-
Unlike other image extraction methods in `document_loader.image_loader`, this loader does not extract embedded images from the PDF.
|
5 |
-
Instead, it creates a snapshot image version of each selected page from the PDF.
|
6 |
-
|
7 |
-
??? note "Note"
|
8 |
-
**Underlying Library:** `pdf2image`
|
9 |
-
|
10 |
-
Extract images from PDF files using `pdf2image`.
|
11 |
-
|
12 |
-
|
13 |
-
Use it in our library with:
|
14 |
-
```python
|
15 |
-
from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
|
16 |
-
```
|
17 |
-
|
18 |
-
For details and available `**kwargs`, please refer to the sources below.
|
19 |
-
|
20 |
-
**Sources:**
|
21 |
-
|
22 |
-
- [DataLab](https://www.datalab.to)
|
23 |
-
- [GitHub](https://github.com/VikParuchuri/marker)
|
24 |
-
- [PyPI](https://pypi.org/project/marker-pdf/)
|
25 |
-
|
26 |
-
::: medrag_multi_modal.document_loader.image_loader.pdf2image_img_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/image_loader/pdfplumber_img_loader.md
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
# Load images from PDF files (using PDFPlumber)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `pdfplumber`
|
5 |
-
|
6 |
-
Extract images from PDF files using `pdfplumber`.
|
7 |
-
|
8 |
-
You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
|
9 |
-
|
10 |
-
Use it in our library with:
|
11 |
-
```python
|
12 |
-
from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
|
13 |
-
```
|
14 |
-
|
15 |
-
For details, please refer to the sources below.
|
16 |
-
|
17 |
-
**Sources:**
|
18 |
-
|
19 |
-
- [GitHub](https://github.com/jsvine/pdfplumber)
|
20 |
-
- [PyPI](https://pypi.org/project/pdfplumber/)
|
21 |
-
|
22 |
-
::: medrag_multi_modal.document_loader.image_loader.pdfplumber_img_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/image_loader/pymupdf_img_loader.md
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
# Load images from PDF files (using PyMuPDF)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `pymupdf`
|
5 |
-
|
6 |
-
PyMuPDF is a high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents.
|
7 |
-
|
8 |
-
You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
|
9 |
-
|
10 |
-
Use it in our library with:
|
11 |
-
```python
|
12 |
-
from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
|
13 |
-
```
|
14 |
-
|
15 |
-
For details, please refer to the sources below.
|
16 |
-
|
17 |
-
**Sources:**
|
18 |
-
|
19 |
-
- [Docs](https://pymupdf.readthedocs.io/en/latest/)
|
20 |
-
- [GitHub](https://github.com/pymupdf/PyMuPDF)
|
21 |
-
- [PyPI](https://pypi.org/project/PyMuPDF/)
|
22 |
-
|
23 |
-
::: medrag_multi_modal.document_loader.image_loader.pymupdf_img_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/text_loader/base_text_loader.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
## Load text from PDF files
|
2 |
-
|
3 |
-
::: medrag_multi_modal.document_loader.text_loader.base_text_loader
|
|
|
|
|
|
|
|
docs/document_loader/text_loader/marker_text_loader.md
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
## Load text from PDF files (using Marker)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `marker-pdf`
|
5 |
-
|
6 |
-
Convert PDF to markdown quickly and accurately using a pipeline of deep learning models.
|
7 |
-
|
8 |
-
You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
|
9 |
-
|
10 |
-
Use it in our library with:
|
11 |
-
```python
|
12 |
-
from medrag_multi_modal.document_loader.text_loader import MarkerTextLoader
|
13 |
-
```
|
14 |
-
|
15 |
-
For details and available `**kwargs`, please refer to the sources below.
|
16 |
-
|
17 |
-
**Sources:**
|
18 |
-
|
19 |
-
- [DataLab](https://www.datalab.to)
|
20 |
-
- [GitHub](https://github.com/VikParuchuri/marker)
|
21 |
-
- [PyPI](https://pypi.org/project/marker-pdf/)
|
22 |
-
|
23 |
-
::: medrag_multi_modal.document_loader.text_loader.marker_text_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/text_loader/pdfplumber_text_loader.md
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
## Load text from PDF files (using PDFPlumber)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `pdfplumber`
|
5 |
-
|
6 |
-
Plumb a PDF for detailed information about each char, rectangle, line, et cetera — and easily extract text and tables.
|
7 |
-
|
8 |
-
You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
|
9 |
-
|
10 |
-
Use it in our library with:
|
11 |
-
```python
|
12 |
-
from medrag_multi_modal.document_loader.text_loader import PDFPlumberTextLoader
|
13 |
-
```
|
14 |
-
|
15 |
-
For details and available `**kwargs`, please refer to the sources below.
|
16 |
-
|
17 |
-
**Sources:**
|
18 |
-
|
19 |
-
- [GitHub](https://github.com/jsvine/pdfplumber)
|
20 |
-
- [PyPI](https://pypi.org/project/pdfplumber/)
|
21 |
-
|
22 |
-
::: medrag_multi_modal.document_loader.text_loader.pdfplumber_text_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/text_loader/pymupdf4llm_text_loader.md
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
## Load text from PDF files (using PyMuPDF4LLM)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `pymupdf4llm`
|
5 |
-
|
6 |
-
PyMuPDF is a high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents.
|
7 |
-
|
8 |
-
You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
|
9 |
-
|
10 |
-
Use it in our library with:
|
11 |
-
```python
|
12 |
-
from medrag_multi_modal.document_loader.text_loader import PyMuPDF4LLMTextLoader
|
13 |
-
```
|
14 |
-
|
15 |
-
For details and available `**kwargs`, please refer to the sources below.
|
16 |
-
|
17 |
-
**Sources:**
|
18 |
-
|
19 |
-
- [Docs](https://pymupdf.readthedocs.io/en/latest/pymupdf4llm/)
|
20 |
-
- [GitHub](https://github.com/pymupdf/PyMuPDF)
|
21 |
-
- [PyPI](https://pypi.org/project/pymupdf4llm/)
|
22 |
-
|
23 |
-
::: medrag_multi_modal.document_loader.text_loader.pymupdf4llm_text_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/document_loader/text_loader/pypdf2_text_loader.md
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
## Load text from PDF files (using PyPDF2)
|
2 |
-
|
3 |
-
??? note "Note"
|
4 |
-
**Underlying Library:** `pypdf2`
|
5 |
-
|
6 |
-
A pure-python PDF library capable of splitting, merging, cropping, and transforming the pages of PDF files
|
7 |
-
|
8 |
-
You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
|
9 |
-
|
10 |
-
Use it in our library with:
|
11 |
-
```python
|
12 |
-
from medrag_multi_modal.document_loader.text_loader import PyPDF2TextLoader
|
13 |
-
```
|
14 |
-
|
15 |
-
For details and available `**kwargs`, please refer to the sources below.
|
16 |
-
|
17 |
-
**Sources:**
|
18 |
-
|
19 |
-
- [Docs](https://pypdf2.readthedocs.io/en/3.x/)
|
20 |
-
- [GitHub](https://github.com/py-pdf/pypdf)
|
21 |
-
- [PyPI](https://pypi.org/project/PyPDF2/)
|
22 |
-
|
23 |
-
::: medrag_multi_modal.document_loader.text_loader.pypdf2_text_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/index.md
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
# MedRAG Multi-Modal
|
2 |
-
|
3 |
-
Multi-modal RAG for medical docmain.
|
4 |
-
|
5 |
-
## Installation
|
6 |
-
|
7 |
-
### For Development
|
8 |
-
|
9 |
-
For MacOS, you need to run
|
10 |
-
|
11 |
-
```bash
|
12 |
-
brew install poppler
|
13 |
-
```
|
14 |
-
|
15 |
-
For Debian/Ubuntu, you need to run
|
16 |
-
|
17 |
-
```bash
|
18 |
-
sudo apt-get install -y poppler-utils
|
19 |
-
```
|
20 |
-
|
21 |
-
Then, you can install the dependencies using uv in the virtual environment `.venv` using
|
22 |
-
|
23 |
-
```bash
|
24 |
-
git clone https://github.com/soumik12345/medrag-multi-modal
|
25 |
-
cd medrag-multi-modal
|
26 |
-
pip install -U pip uv
|
27 |
-
uv sync
|
28 |
-
```
|
29 |
-
|
30 |
-
After this, you need to activate the virtual environment using
|
31 |
-
|
32 |
-
```bash
|
33 |
-
source .venv/bin/activate
|
34 |
-
```
|
35 |
-
|
36 |
-
In the activated virtual environment, you can optionally install Flash Attention (required for ColPali) using
|
37 |
-
|
38 |
-
```bash
|
39 |
-
uv pip install flash-attn --no-build-isolation
|
40 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/installation/development.md
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
# Setting up the development environment
|
2 |
-
|
3 |
-
## Install Poppler
|
4 |
-
|
5 |
-
For MacOS, you need to run
|
6 |
-
|
7 |
-
```bash
|
8 |
-
brew install poppler
|
9 |
-
```
|
10 |
-
|
11 |
-
For Debian/Ubuntu, you need to run
|
12 |
-
|
13 |
-
```bash
|
14 |
-
sudo apt-get install -y poppler-utils
|
15 |
-
```
|
16 |
-
|
17 |
-
## Install the dependencies
|
18 |
-
|
19 |
-
Then, you can install the dependencies using uv in the virtual environment `.venv` using
|
20 |
-
|
21 |
-
```bash
|
22 |
-
git clone https://github.com/soumik12345/medrag-multi-modal
|
23 |
-
cd medrag-multi-modal
|
24 |
-
pip install -U pip uv
|
25 |
-
uv sync
|
26 |
-
```
|
27 |
-
|
28 |
-
After this, you need to activate the virtual environment using
|
29 |
-
|
30 |
-
```bash
|
31 |
-
source .venv/bin/activate
|
32 |
-
```
|
33 |
-
|
34 |
-
## [Optional] Install Flash Attention
|
35 |
-
|
36 |
-
In the activated virtual environment, you can optionally install Flash Attention (required for ColPali) using
|
37 |
-
|
38 |
-
```bash
|
39 |
-
uv pip install flash-attn --no-build-isolation
|
40 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/installation/install.md
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
# Installation
|
2 |
-
|
3 |
-
You just need to clone the repository and run the install.sh script
|
4 |
-
|
5 |
-
```bash
|
6 |
-
git clone https://github.com/soumik12345/medrag-multi-modal
|
7 |
-
cd medrag-multi-modal
|
8 |
-
sh install.sh
|
9 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/retreival/bm25s.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# BM25-Sparse Retrieval
|
2 |
-
|
3 |
-
::: medrag_multi_modal.retrieval.bm25s_retrieval
|
|
|
|
|
|
|
|
docs/retreival/colpali.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# ColPali Retrieval
|
2 |
-
|
3 |
-
::: medrag_multi_modal.retrieval.colpali_retrieval
|
|
|
|
|
|
|
|
docs/retreival/contriever.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Contriever Retrieval
|
2 |
-
|
3 |
-
::: medrag_multi_modal.retrieval.contriever_retrieval
|
|
|
|
|
|
|
|
docs/retreival/medcpt.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# MedCPT Retrieval
|
2 |
-
|
3 |
-
::: medrag_multi_modal.retrieval.medcpt_retrieval
|
|
|
|
|
|
|
|
docs/retreival/nv_embed_2.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# NV-Embed-v2 Retrieval
|
2 |
-
|
3 |
-
::: medrag_multi_modal.retrieval.nv_embed_2
|
|
|
|
|
|
|
|
install.sh
DELETED
@@ -1,30 +0,0 @@
|
|
1 |
-
#!/bin/bash
|
2 |
-
|
3 |
-
OS_TYPE=$(uname -s)
|
4 |
-
|
5 |
-
if [ "$OS_TYPE" = "Darwin" ]; then
|
6 |
-
echo "Detected macOS."
|
7 |
-
brew install poppler
|
8 |
-
elif [ "$OS_TYPE" = "Linux" ]; then
|
9 |
-
if [ -f /etc/os-release ]; then
|
10 |
-
. /etc/os-release
|
11 |
-
if [ "$ID" = "ubuntu" ] || [ "$ID" = "debian" ]; then
|
12 |
-
echo "Detected Ubuntu/Debian."
|
13 |
-
sudo apt-get update
|
14 |
-
sudo apt-get install -y poppler-utils
|
15 |
-
else
|
16 |
-
echo "Unsupported Linux distribution: $ID"
|
17 |
-
exit 1
|
18 |
-
fi
|
19 |
-
else
|
20 |
-
echo "Cannot detect Linux distribution."
|
21 |
-
exit 1
|
22 |
-
fi
|
23 |
-
else
|
24 |
-
echo "Unsupported OS: $OS_TYPE"
|
25 |
-
exit 1
|
26 |
-
fi
|
27 |
-
|
28 |
-
git clone https://github.com/soumik12345/medrag-multi-modal
|
29 |
-
cd medrag-multi-modal
|
30 |
-
pip install -U .[core]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
medrag_multi_modal/assistant/figure_annotation.py
CHANGED
@@ -5,19 +5,10 @@ from typing import Optional, Union
|
|
5 |
import cv2
|
6 |
import weave
|
7 |
from PIL import Image
|
8 |
-
from pydantic import BaseModel
|
9 |
|
10 |
-
from
|
11 |
-
from .
|
12 |
-
|
13 |
-
|
14 |
-
class FigureAnnotation(BaseModel):
|
15 |
-
figure_id: str
|
16 |
-
figure_description: str
|
17 |
-
|
18 |
-
|
19 |
-
class FigureAnnotations(BaseModel):
|
20 |
-
annotations: list[FigureAnnotation]
|
21 |
|
22 |
|
23 |
class FigureAnnotatorFromPageImage(weave.Model):
|
@@ -108,7 +99,7 @@ Here are some clues you need to follow:
|
|
108 |
)
|
109 |
|
110 |
@weave.op()
|
111 |
-
def predict(self, page_idx: int) -> dict[int, list[
|
112 |
"""
|
113 |
Predicts figure annotations for a specific page in a document.
|
114 |
|
|
|
5 |
import cv2
|
6 |
import weave
|
7 |
from PIL import Image
|
|
|
8 |
|
9 |
+
from medrag_multi_modal.assistant.llm_client import LLMClient
|
10 |
+
from medrag_multi_modal.assistant.schema import FigureAnnotations
|
11 |
+
from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
class FigureAnnotatorFromPageImage(weave.Model):
|
|
|
99 |
)
|
100 |
|
101 |
@weave.op()
|
102 |
+
def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]:
|
103 |
"""
|
104 |
Predicts figure annotations for a specific page in a document.
|
105 |
|
medrag_multi_modal/assistant/llm_client.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
from enum import Enum
|
3 |
from typing import Any, Optional, Union
|
@@ -93,6 +94,7 @@ class LLMClient(weave.Model):
|
|
93 |
schema: Optional[Any] = None,
|
94 |
) -> Union[str, Any]:
|
95 |
import google.generativeai as genai
|
|
|
96 |
|
97 |
system_prompt = (
|
98 |
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
@@ -100,18 +102,25 @@ class LLMClient(weave.Model):
|
|
100 |
user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
|
101 |
|
102 |
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
|
103 |
-
model = genai.GenerativeModel(self.model_name)
|
104 |
generation_config = (
|
105 |
None
|
106 |
if schema is None
|
107 |
else genai.GenerationConfig(
|
108 |
-
response_mime_type="application/json", response_schema=
|
109 |
)
|
110 |
)
|
111 |
response = model.generate_content(
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
)
|
114 |
-
return response.text if schema is None else response
|
115 |
|
116 |
@weave.op()
|
117 |
def execute_mistral_sdk(
|
@@ -146,14 +155,13 @@ class LLMClient(weave.Model):
|
|
146 |
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
|
147 |
client = instructor.from_mistral(client) if schema is not None else client
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
else client.messages.create(
|
153 |
-
response_model=schema, messages=messages, temperature=0
|
154 |
)
|
155 |
-
|
156 |
-
|
|
|
157 |
|
158 |
@weave.op()
|
159 |
def execute_openai_sdk(
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
from enum import Enum
|
4 |
from typing import Any, Optional, Union
|
|
|
94 |
schema: Optional[Any] = None,
|
95 |
) -> Union[str, Any]:
|
96 |
import google.generativeai as genai
|
97 |
+
from google.generativeai.types import HarmBlockThreshold, HarmCategory
|
98 |
|
99 |
system_prompt = (
|
100 |
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
|
|
102 |
user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
|
103 |
|
104 |
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
|
105 |
+
model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt)
|
106 |
generation_config = (
|
107 |
None
|
108 |
if schema is None
|
109 |
else genai.GenerationConfig(
|
110 |
+
response_mime_type="application/json", response_schema=schema
|
111 |
)
|
112 |
)
|
113 |
response = model.generate_content(
|
114 |
+
user_prompt,
|
115 |
+
generation_config=generation_config,
|
116 |
+
# This is necessary in order to answer questions about anatomy, sexual diseases,
|
117 |
+
# medical devices, medicines, etc.
|
118 |
+
safety_settings={
|
119 |
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
120 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
121 |
+
},
|
122 |
)
|
123 |
+
return response.text if schema is None else json.loads(response.text)
|
124 |
|
125 |
@weave.op()
|
126 |
def execute_mistral_sdk(
|
|
|
155 |
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
|
156 |
client = instructor.from_mistral(client) if schema is not None else client
|
157 |
|
158 |
+
if schema is None:
|
159 |
+
raise NotImplementedError(
|
160 |
+
"Mistral does not support structured output using a schema"
|
|
|
|
|
161 |
)
|
162 |
+
else:
|
163 |
+
response = client.chat.complete(model=self.model_name, messages=messages)
|
164 |
+
return response.choices[0].message.content
|
165 |
|
166 |
@weave.op()
|
167 |
def execute_openai_sdk(
|
medrag_multi_modal/assistant/medqa_assistant.py
CHANGED
@@ -1,8 +1,16 @@
|
|
|
|
|
|
1 |
import weave
|
2 |
|
3 |
-
from
|
4 |
-
from .
|
5 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
class MedQAAssistant(weave.Model):
|
@@ -47,39 +55,68 @@ class MedQAAssistant(weave.Model):
|
|
47 |
llm_client (LLMClient): The language model client used to generate responses.
|
48 |
retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
|
49 |
figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
|
50 |
-
|
|
|
51 |
retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
|
52 |
"""
|
53 |
|
54 |
llm_client: LLMClient
|
55 |
retriever: weave.Model
|
56 |
-
figure_annotator: FigureAnnotatorFromPageImage
|
57 |
-
|
|
|
|
|
58 |
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
59 |
|
60 |
@weave.op()
|
61 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
"""
|
63 |
Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
|
64 |
from a medical document and using a language model to generate the final response.
|
65 |
|
66 |
This function performs the following steps:
|
67 |
-
1. Retrieves relevant text chunks from the medical document based on the query
|
|
|
68 |
2. Extracts the text and page indices from the retrieved chunks.
|
69 |
3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
|
70 |
-
4. Constructs a system prompt and user prompt combining the query, retrieved text chunks,
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
Args:
|
75 |
query (str): The medical query to be answered.
|
|
|
|
|
76 |
|
77 |
Returns:
|
78 |
-
|
79 |
"""
|
80 |
-
retrieved_chunks = self.
|
81 |
-
|
82 |
-
)
|
83 |
|
84 |
retrieved_chunk_texts = []
|
85 |
page_indices = set()
|
@@ -88,21 +125,50 @@ class MedQAAssistant(weave.Model):
|
|
88 |
page_indices.add(int(chunk["page_idx"]))
|
89 |
|
90 |
figure_descriptions = []
|
91 |
-
|
92 |
-
|
93 |
-
page_idx
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
You are an expert in medical science. You are given a
|
|
|
|
|
|
|
|
|
101 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
response = self.llm_client.predict(
|
103 |
system_prompt=system_prompt,
|
104 |
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
|
|
|
105 |
)
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
import weave
|
4 |
|
5 |
+
from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
|
6 |
+
from medrag_multi_modal.assistant.llm_client import LLMClient
|
7 |
+
from medrag_multi_modal.assistant.schema import (
|
8 |
+
MedQACitation,
|
9 |
+
MedQAMCQResponse,
|
10 |
+
MedQAResponse,
|
11 |
+
)
|
12 |
+
from medrag_multi_modal.retrieval.common import SimilarityMetric
|
13 |
+
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
|
14 |
|
15 |
|
16 |
class MedQAAssistant(weave.Model):
|
|
|
55 |
llm_client (LLMClient): The language model client used to generate responses.
|
56 |
retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
|
57 |
figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
|
58 |
+
top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query.
|
59 |
+
top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options.
|
60 |
retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
|
61 |
"""
|
62 |
|
63 |
llm_client: LLMClient
|
64 |
retriever: weave.Model
|
65 |
+
figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
|
66 |
+
top_k_chunks_for_query: int = 2
|
67 |
+
top_k_chunks_for_options: int = 2
|
68 |
+
rely_only_on_context: bool = True
|
69 |
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
70 |
|
71 |
@weave.op()
|
72 |
+
def retrieve_chunks_for_query(self, query: str) -> list[dict]:
|
73 |
+
retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
|
74 |
+
if not isinstance(self.retriever, BM25sRetriever):
|
75 |
+
retriever_kwargs["metric"] = self.retrieval_similarity_metric
|
76 |
+
return self.retriever.predict(query, **retriever_kwargs)
|
77 |
+
|
78 |
+
@weave.op()
|
79 |
+
def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
|
80 |
+
retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
|
81 |
+
if not isinstance(self.retriever, BM25sRetriever):
|
82 |
+
retriever_kwargs["metric"] = self.retrieval_similarity_metric
|
83 |
+
retrieved_chunks = []
|
84 |
+
for option in options:
|
85 |
+
retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
|
86 |
+
return retrieved_chunks
|
87 |
+
|
88 |
+
@weave.op()
|
89 |
+
def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
|
90 |
"""
|
91 |
Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
|
92 |
from a medical document and using a language model to generate the final response.
|
93 |
|
94 |
This function performs the following steps:
|
95 |
+
1. Retrieves relevant text chunks from the medical document based on the query and any provided options
|
96 |
+
using the retriever model.
|
97 |
2. Extracts the text and page indices from the retrieved chunks.
|
98 |
3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
|
99 |
+
4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
|
100 |
+
and figure descriptions.
|
101 |
+
5. Uses the language model client to generate a response based on the constructed prompts, either choosing
|
102 |
+
from provided options or generating a free-form response.
|
103 |
+
6. Returns the generated response, which includes the answer and explanation if options were provided.
|
104 |
+
|
105 |
+
The function can operate in two modes:
|
106 |
+
- Multiple choice: When options are provided, it selects the best answer from the options and explains the choice
|
107 |
+
- Free response: When no options are provided, it generates a comprehensive response based on the context
|
108 |
|
109 |
Args:
|
110 |
query (str): The medical query to be answered.
|
111 |
+
options (Optional[list[str]]): The list of options to choose from.
|
112 |
+
rely_only_on_context (bool): Whether to rely only on the context provided or not during response generation.
|
113 |
|
114 |
Returns:
|
115 |
+
MedQAResponse: The generated response to the query, including source information.
|
116 |
"""
|
117 |
+
retrieved_chunks = self.retrieve_chunks_for_query(query)
|
118 |
+
options = options or []
|
119 |
+
retrieved_chunks += self.retrieve_chunks_for_options(options)
|
120 |
|
121 |
retrieved_chunk_texts = []
|
122 |
page_indices = set()
|
|
|
125 |
page_indices.add(int(chunk["page_idx"]))
|
126 |
|
127 |
figure_descriptions = []
|
128 |
+
if self.figure_annotator is not None:
|
129 |
+
for page_idx in page_indices:
|
130 |
+
figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
|
131 |
+
page_idx
|
132 |
+
]
|
133 |
+
figure_descriptions += [
|
134 |
+
item["figure_description"] for item in figure_annotations
|
135 |
+
]
|
136 |
+
|
137 |
+
system_prompt = """You are an expert in medical science. You are given a question
|
138 |
+
and a list of excerpts from various medical documents.
|
139 |
+
"""
|
140 |
+
query = f"""# Question
|
141 |
+
{query}
|
142 |
"""
|
143 |
+
|
144 |
+
if len(options) > 0:
|
145 |
+
system_prompt += """\nYou are also given a list of options to choose your answer from.
|
146 |
+
You are supposed to choose the best possible option based on the context provided. You should also
|
147 |
+
explain your answer to justify why you chose that option.
|
148 |
+
"""
|
149 |
+
query += "## Options\n"
|
150 |
+
for option in options:
|
151 |
+
query += f"- {option}\n"
|
152 |
+
else:
|
153 |
+
system_prompt += "\nYou are supposed to answer the question based on the context provided."
|
154 |
+
|
155 |
+
if self.rely_only_on_context:
|
156 |
+
system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
|
157 |
+
You are not allowed to use any external knowledge to answer the question.
|
158 |
+
"""
|
159 |
+
|
160 |
response = self.llm_client.predict(
|
161 |
system_prompt=system_prompt,
|
162 |
user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
|
163 |
+
schema=MedQAMCQResponse if len(options) > 0 else None,
|
164 |
)
|
165 |
+
|
166 |
+
# TODO: Add figure citations
|
167 |
+
# TODO: Add source document name from retrieved chunks as citations
|
168 |
+
citations = []
|
169 |
+
for page_idx in page_indices:
|
170 |
+
citations.append(
|
171 |
+
MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
|
172 |
+
)
|
173 |
+
|
174 |
+
return MedQAResponse(response=response, citations=citations)
|
medrag_multi_modal/assistant/schema.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class FigureAnnotation(BaseModel):
|
7 |
+
figure_id: str
|
8 |
+
figure_description: str
|
9 |
+
|
10 |
+
|
11 |
+
class FigureAnnotations(BaseModel):
|
12 |
+
annotations: list[FigureAnnotation]
|
13 |
+
|
14 |
+
|
15 |
+
class MedQAMCQResponse(BaseModel):
|
16 |
+
answer: str
|
17 |
+
explanation: str
|
18 |
+
|
19 |
+
|
20 |
+
class MedQACitation(BaseModel):
|
21 |
+
page_number: int
|
22 |
+
document_name: str
|
23 |
+
|
24 |
+
|
25 |
+
class MedQAResponse(BaseModel):
|
26 |
+
response: Union[str, MedQAMCQResponse]
|
27 |
+
citations: list[MedQACitation]
|
medrag_multi_modal/cli.py
CHANGED
@@ -1,16 +1,67 @@
|
|
1 |
import argparse
|
|
|
2 |
import subprocess
|
3 |
import sys
|
4 |
|
5 |
|
6 |
def main():
|
7 |
parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
|
8 |
-
parser.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
args = parser.parse_args()
|
10 |
|
11 |
if args.command == "run":
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
if __name__ == "__main__":
|
|
|
1 |
import argparse
|
2 |
+
import os
|
3 |
import subprocess
|
4 |
import sys
|
5 |
|
6 |
|
7 |
def main():
|
8 |
parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
|
9 |
+
subparsers = parser.add_subparsers(dest="command", required=True)
|
10 |
+
|
11 |
+
# Run subcommand
|
12 |
+
run_parser = subparsers.add_parser("run", help="Run the Streamlit application")
|
13 |
+
run_parser.add_argument(
|
14 |
+
"--port", type=int, default=8501, help="Port to run Streamlit on"
|
15 |
+
)
|
16 |
+
|
17 |
+
# Evaluate subcommand
|
18 |
+
eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests")
|
19 |
+
eval_parser.add_argument(
|
20 |
+
"--test-file",
|
21 |
+
default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"),
|
22 |
+
help="Path to test file",
|
23 |
+
)
|
24 |
+
eval_parser.add_argument(
|
25 |
+
"--test-case",
|
26 |
+
type=str,
|
27 |
+
help="Only run tests which match the given substring expression",
|
28 |
+
)
|
29 |
+
eval_parser.add_argument(
|
30 |
+
"--model-name",
|
31 |
+
type=str,
|
32 |
+
default="gemini-1.5-flash",
|
33 |
+
help="Model name to use for evaluation",
|
34 |
+
)
|
35 |
+
|
36 |
args = parser.parse_args()
|
37 |
|
38 |
if args.command == "run":
|
39 |
+
subprocess.run(
|
40 |
+
[
|
41 |
+
sys.executable,
|
42 |
+
"-m",
|
43 |
+
"streamlit",
|
44 |
+
"run",
|
45 |
+
"app.py",
|
46 |
+
"--server.port",
|
47 |
+
str(args.port),
|
48 |
+
]
|
49 |
+
)
|
50 |
+
|
51 |
+
elif args.command == "evaluate":
|
52 |
+
test_file = (
|
53 |
+
args.test_file + "::" + args.test_case if args.test_case else args.test_file
|
54 |
+
)
|
55 |
+
cmd = [
|
56 |
+
sys.executable,
|
57 |
+
"-m",
|
58 |
+
"pytest",
|
59 |
+
"-s",
|
60 |
+
test_file,
|
61 |
+
"-v",
|
62 |
+
f"--model-name={args.model_name}",
|
63 |
+
]
|
64 |
+
subprocess.run(cmd)
|
65 |
|
66 |
|
67 |
if __name__ == "__main__":
|
medrag_multi_modal/document_loader/image_loader/base_img_loader.py
CHANGED
@@ -1,11 +1,21 @@
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
from abc import abstractmethod
|
|
|
4 |
from typing import Dict, List, Optional
|
5 |
|
|
|
6 |
import jsonlines
|
7 |
import rich
|
8 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
11 |
BaseTextLoader,
|
@@ -36,14 +46,72 @@ class BaseImageLoader(BaseTextLoader):
|
|
36 |
"""
|
37 |
pass
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
async def load_data(
|
40 |
self,
|
41 |
start_page: Optional[int] = None,
|
42 |
end_page: Optional[int] = None,
|
43 |
-
|
|
|
44 |
image_save_dir: str = "./images",
|
45 |
exclude_file_extensions: list[str] = [],
|
46 |
-
cleanup: bool = False,
|
47 |
**kwargs,
|
48 |
) -> List[Dict[str, str]]:
|
49 |
"""
|
@@ -65,21 +133,15 @@ class BaseImageLoader(BaseTextLoader):
|
|
65 |
Args:
|
66 |
start_page (Optional[int]): The starting page index (0-based) to process.
|
67 |
end_page (Optional[int]): The ending page index (0-based) to process.
|
68 |
-
|
|
|
69 |
image_save_dir (str): The directory to save the extracted images.
|
70 |
exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
|
71 |
-
cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact.
|
72 |
**kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
|
73 |
|
74 |
Returns:
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
- "page_idx": (int) the index of the page.
|
79 |
-
- "document_name": (str) the name of the document.
|
80 |
-
- "file_path": (str) the local file path where the PDF is stored.
|
81 |
-
- "file_url": (str) the URL of the PDF file.
|
82 |
-
- "image_file_path" or "image_file_paths": (str) the local file path where the image/images are stored.
|
83 |
Raises:
|
84 |
ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
|
85 |
"""
|
@@ -111,19 +173,8 @@ class BaseImageLoader(BaseTextLoader):
|
|
111 |
if file.endswith(tuple(exclude_file_extensions)):
|
112 |
os.remove(os.path.join(image_save_dir, file))
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
)
|
120 |
-
artifact.add_dir(local_path=image_save_dir)
|
121 |
-
artifact.save()
|
122 |
-
rich.print("Artifact saved and uploaded to wandb!")
|
123 |
-
|
124 |
-
if cleanup:
|
125 |
-
for file in os.listdir(image_save_dir):
|
126 |
-
file_path = os.path.join(image_save_dir, file)
|
127 |
-
if os.path.isfile(file_path):
|
128 |
-
os.remove(file_path)
|
129 |
-
return pages
|
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
from abc import abstractmethod
|
4 |
+
from glob import glob
|
5 |
from typing import Dict, List, Optional
|
6 |
|
7 |
+
import huggingface_hub
|
8 |
import jsonlines
|
9 |
import rich
|
10 |
+
from datasets import (
|
11 |
+
Dataset,
|
12 |
+
Features,
|
13 |
+
Image,
|
14 |
+
Sequence,
|
15 |
+
Value,
|
16 |
+
concatenate_datasets,
|
17 |
+
load_dataset,
|
18 |
+
)
|
19 |
|
20 |
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
21 |
BaseTextLoader,
|
|
|
46 |
"""
|
47 |
pass
|
48 |
|
49 |
+
def save_as_dataset(
|
50 |
+
self,
|
51 |
+
start_page: int,
|
52 |
+
end_page: int,
|
53 |
+
image_save_dir: str,
|
54 |
+
dataset_repo_id: Optional[str] = None,
|
55 |
+
overwrite_dataset: bool = False,
|
56 |
+
):
|
57 |
+
features = Features(
|
58 |
+
{
|
59 |
+
"page_image": Image(decode=True),
|
60 |
+
"page_figure_images": Sequence(Image(decode=True)),
|
61 |
+
"document_name": Value(dtype="string"),
|
62 |
+
"page_idx": Value(dtype="int32"),
|
63 |
+
}
|
64 |
+
)
|
65 |
+
|
66 |
+
all_examples = []
|
67 |
+
for page_idx in range(start_page, end_page):
|
68 |
+
page_image_file_paths = glob(
|
69 |
+
os.path.join(image_save_dir, f"page{page_idx}*.png")
|
70 |
+
)
|
71 |
+
if len(page_image_file_paths) > 0:
|
72 |
+
page_image_path = page_image_file_paths[0]
|
73 |
+
figure_image_paths = [
|
74 |
+
image_file_path
|
75 |
+
for image_file_path in glob(
|
76 |
+
os.path.join(image_save_dir, f"page{page_idx}*_fig*.png")
|
77 |
+
)
|
78 |
+
]
|
79 |
+
|
80 |
+
example = {
|
81 |
+
"page_image": page_image_path,
|
82 |
+
"page_figure_images": figure_image_paths,
|
83 |
+
"document_name": self.document_name,
|
84 |
+
"page_idx": page_idx,
|
85 |
+
}
|
86 |
+
all_examples.append(example)
|
87 |
+
|
88 |
+
dataset = Dataset.from_list(all_examples, features=features)
|
89 |
+
|
90 |
+
if dataset_repo_id:
|
91 |
+
if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
|
92 |
+
if not overwrite_dataset:
|
93 |
+
dataset = concatenate_datasets(
|
94 |
+
[dataset, load_dataset(dataset_repo_id)["corpus"]]
|
95 |
+
)
|
96 |
+
|
97 |
+
dataset.push_to_hub(dataset_repo_id, split="corpus")
|
98 |
+
|
99 |
+
return dataset
|
100 |
+
|
101 |
+
def cleanup_image_dir(self, image_save_dir: str = "./images"):
|
102 |
+
for file in os.listdir(image_save_dir):
|
103 |
+
file_path = os.path.join(image_save_dir, file)
|
104 |
+
if os.path.isfile(file_path):
|
105 |
+
os.remove(file_path)
|
106 |
+
|
107 |
async def load_data(
|
108 |
self,
|
109 |
start_page: Optional[int] = None,
|
110 |
end_page: Optional[int] = None,
|
111 |
+
dataset_repo_id: Optional[str] = None,
|
112 |
+
overwrite_dataset: bool = False,
|
113 |
image_save_dir: str = "./images",
|
114 |
exclude_file_extensions: list[str] = [],
|
|
|
115 |
**kwargs,
|
116 |
) -> List[Dict[str, str]]:
|
117 |
"""
|
|
|
133 |
Args:
|
134 |
start_page (Optional[int]): The starting page index (0-based) to process.
|
135 |
end_page (Optional[int]): The ending page index (0-based) to process.
|
136 |
+
dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
|
137 |
+
overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
|
138 |
image_save_dir (str): The directory to save the extracted images.
|
139 |
exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
|
|
|
140 |
**kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
|
141 |
|
142 |
Returns:
|
143 |
+
Dataset: A HuggingFace dataset containing the processed pages.
|
144 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
Raises:
|
146 |
ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
|
147 |
"""
|
|
|
173 |
if file.endswith(tuple(exclude_file_extensions)):
|
174 |
os.remove(os.path.join(image_save_dir, file))
|
175 |
|
176 |
+
dataset = self.save_as_dataset(
|
177 |
+
start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset
|
178 |
+
)
|
179 |
+
|
180 |
+
return dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py
CHANGED
@@ -3,9 +3,12 @@ import os
|
|
3 |
from typing import Any, Dict
|
4 |
|
5 |
import fitz
|
|
|
6 |
from PIL import Image, ImageOps, UnidentifiedImageError
|
7 |
|
8 |
-
from .base_img_loader import
|
|
|
|
|
9 |
|
10 |
|
11 |
class FitzPILImageLoader(BaseImageLoader):
|
@@ -20,27 +23,16 @@ class FitzPILImageLoader(BaseImageLoader):
|
|
20 |
```python
|
21 |
import asyncio
|
22 |
|
23 |
-
import weave
|
24 |
-
|
25 |
-
import wandb
|
26 |
from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
31 |
loader = FitzPILImageLoader(
|
32 |
-
url=
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
-
asyncio.run(
|
37 |
-
loader.load_data(
|
38 |
-
start_page=32,
|
39 |
-
end_page=37,
|
40 |
-
wandb_artifact_name="grays-anatomy-images-fitzpil",
|
41 |
-
cleanup=False,
|
42 |
-
)
|
43 |
-
)
|
44 |
```
|
45 |
|
46 |
Args:
|
@@ -118,6 +110,14 @@ class FitzPILImageLoader(BaseImageLoader):
|
|
118 |
|
119 |
pdf_document.close()
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
return {
|
122 |
"page_idx": page_idx,
|
123 |
"document_name": self.document_name,
|
|
|
3 |
from typing import Any, Dict
|
4 |
|
5 |
import fitz
|
6 |
+
from pdf2image.pdf2image import convert_from_path
|
7 |
from PIL import Image, ImageOps, UnidentifiedImageError
|
8 |
|
9 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
10 |
+
BaseImageLoader,
|
11 |
+
)
|
12 |
|
13 |
|
14 |
class FitzPILImageLoader(BaseImageLoader):
|
|
|
23 |
```python
|
24 |
import asyncio
|
25 |
|
|
|
|
|
|
|
26 |
from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
|
27 |
|
28 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
29 |
+
|
|
|
30 |
loader = FitzPILImageLoader(
|
31 |
+
url=URL,
|
32 |
document_name="Gray's Anatomy",
|
33 |
document_file_path="grays_anatomy.pdf",
|
34 |
)
|
35 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
```
|
37 |
|
38 |
Args:
|
|
|
110 |
|
111 |
pdf_document.close()
|
112 |
|
113 |
+
page_image = convert_from_path(
|
114 |
+
self.document_file_path,
|
115 |
+
first_page=page_idx + 1,
|
116 |
+
last_page=page_idx + 1,
|
117 |
+
**kwargs,
|
118 |
+
)[0]
|
119 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
120 |
+
|
121 |
return {
|
122 |
"page_idx": page_idx,
|
123 |
"document_name": self.document_name,
|
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py
CHANGED
@@ -5,7 +5,9 @@ from marker.convert import convert_single_pdf
|
|
5 |
from marker.models import load_all_models
|
6 |
from pdf2image.pdf2image import convert_from_path
|
7 |
|
8 |
-
from .base_img_loader import
|
|
|
|
|
9 |
|
10 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
11 |
|
@@ -22,27 +24,16 @@ class MarkerImageLoader(BaseImageLoader):
|
|
22 |
```python
|
23 |
import asyncio
|
24 |
|
25 |
-
import weave
|
26 |
-
|
27 |
-
import wandb
|
28 |
from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
33 |
loader = MarkerImageLoader(
|
34 |
-
url=
|
35 |
document_name="Gray's Anatomy",
|
36 |
document_file_path="grays_anatomy.pdf",
|
37 |
)
|
38 |
-
asyncio.run(
|
39 |
-
loader.load_data(
|
40 |
-
start_page=31,
|
41 |
-
end_page=36,
|
42 |
-
wandb_artifact_name="grays-anatomy-images-marker",
|
43 |
-
cleanup=False,
|
44 |
-
)
|
45 |
-
)
|
46 |
```
|
47 |
|
48 |
Args:
|
@@ -84,7 +75,7 @@ class MarkerImageLoader(BaseImageLoader):
|
|
84 |
- "file_url": (str) the URL of the PDF file.
|
85 |
- "image_file_path": (str) the local file path where the image is stored.
|
86 |
"""
|
87 |
-
_, images,
|
88 |
self.document_file_path,
|
89 |
self.model_lst,
|
90 |
max_pages=1,
|
@@ -101,14 +92,13 @@ class MarkerImageLoader(BaseImageLoader):
|
|
101 |
image.save(image_file_path, "png")
|
102 |
image_file_paths.append(image_file_path)
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
112 |
|
113 |
return {
|
114 |
"page_idx": page_idx,
|
@@ -116,7 +106,6 @@ class MarkerImageLoader(BaseImageLoader):
|
|
116 |
"file_path": self.document_file_path,
|
117 |
"file_url": self.url,
|
118 |
"image_file_paths": os.path.join(image_save_dir, "*.png"),
|
119 |
-
"meta": out_meta,
|
120 |
}
|
121 |
|
122 |
def load_data(
|
|
|
5 |
from marker.models import load_all_models
|
6 |
from pdf2image.pdf2image import convert_from_path
|
7 |
|
8 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
9 |
+
BaseImageLoader,
|
10 |
+
)
|
11 |
|
12 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
13 |
|
|
|
24 |
```python
|
25 |
import asyncio
|
26 |
|
|
|
|
|
|
|
27 |
from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
|
28 |
|
29 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
+
|
|
|
31 |
loader = MarkerImageLoader(
|
32 |
+
url=URL,
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
```
|
38 |
|
39 |
Args:
|
|
|
75 |
- "file_url": (str) the URL of the PDF file.
|
76 |
- "image_file_path": (str) the local file path where the image is stored.
|
77 |
"""
|
78 |
+
_, images, _ = convert_single_pdf(
|
79 |
self.document_file_path,
|
80 |
self.model_lst,
|
81 |
max_pages=1,
|
|
|
92 |
image.save(image_file_path, "png")
|
93 |
image_file_paths.append(image_file_path)
|
94 |
|
95 |
+
page_image = convert_from_path(
|
96 |
+
self.document_file_path,
|
97 |
+
first_page=page_idx + 1,
|
98 |
+
last_page=page_idx + 1,
|
99 |
+
**kwargs,
|
100 |
+
)[0]
|
101 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
|
|
102 |
|
103 |
return {
|
104 |
"page_idx": page_idx,
|
|
|
106 |
"file_path": self.document_file_path,
|
107 |
"file_url": self.url,
|
108 |
"image_file_paths": os.path.join(image_save_dir, "*.png"),
|
|
|
109 |
}
|
110 |
|
111 |
def load_data(
|
medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py
CHANGED
@@ -3,7 +3,9 @@ from typing import Any, Dict
|
|
3 |
|
4 |
from pdf2image.pdf2image import convert_from_path
|
5 |
|
6 |
-
from .base_img_loader import
|
|
|
|
|
7 |
|
8 |
|
9 |
class PDF2ImageLoader(BaseImageLoader):
|
@@ -19,27 +21,16 @@ class PDF2ImageLoader(BaseImageLoader):
|
|
19 |
```python
|
20 |
import asyncio
|
21 |
|
22 |
-
import weave
|
23 |
-
|
24 |
-
import wandb
|
25 |
from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
loader = PDF2ImageLoader(
|
31 |
-
url=
|
32 |
document_name="Gray's Anatomy",
|
33 |
document_file_path="grays_anatomy.pdf",
|
34 |
)
|
35 |
-
asyncio.run(
|
36 |
-
loader.load_data(
|
37 |
-
start_page=31,
|
38 |
-
end_page=36,
|
39 |
-
wandb_artifact_name="grays-anatomy-images-pdf2image",
|
40 |
-
cleanup=False,
|
41 |
-
)
|
42 |
-
)
|
43 |
```
|
44 |
|
45 |
Args:
|
|
|
3 |
|
4 |
from pdf2image.pdf2image import convert_from_path
|
5 |
|
6 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
7 |
+
BaseImageLoader,
|
8 |
+
)
|
9 |
|
10 |
|
11 |
class PDF2ImageLoader(BaseImageLoader):
|
|
|
21 |
```python
|
22 |
import asyncio
|
23 |
|
|
|
|
|
|
|
24 |
from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
|
25 |
|
26 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
27 |
+
|
|
|
28 |
loader = PDF2ImageLoader(
|
29 |
+
url=URL,
|
30 |
document_name="Gray's Anatomy",
|
31 |
document_file_path="grays_anatomy.pdf",
|
32 |
)
|
33 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
```
|
35 |
|
36 |
Args:
|
medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py
CHANGED
@@ -2,8 +2,11 @@ import os
|
|
2 |
from typing import Any, Dict
|
3 |
|
4 |
import pdfplumber
|
|
|
5 |
|
6 |
-
from .base_img_loader import
|
|
|
|
|
7 |
|
8 |
|
9 |
class PDFPlumberImageLoader(BaseImageLoader):
|
@@ -18,27 +21,16 @@ class PDFPlumberImageLoader(BaseImageLoader):
|
|
18 |
```python
|
19 |
import asyncio
|
20 |
|
21 |
-
import weave
|
22 |
-
|
23 |
-
import wandb
|
24 |
from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
29 |
loader = PDFPlumberImageLoader(
|
30 |
-
url=
|
31 |
document_name="Gray's Anatomy",
|
32 |
document_file_path="grays_anatomy.pdf",
|
33 |
)
|
34 |
-
asyncio.run(
|
35 |
-
loader.load_data(
|
36 |
-
start_page=32,
|
37 |
-
end_page=37,
|
38 |
-
wandb_artifact_name="grays-anatomy-images-pdfplumber",
|
39 |
-
cleanup=False,
|
40 |
-
)
|
41 |
-
)
|
42 |
```
|
43 |
|
44 |
Args:
|
@@ -92,6 +84,14 @@ class PDFPlumberImageLoader(BaseImageLoader):
|
|
92 |
extracted_image.save(image_file_path, "png")
|
93 |
image_file_paths.append(image_file_path)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
return {
|
96 |
"page_idx": page_idx,
|
97 |
"document_name": self.document_name,
|
|
|
2 |
from typing import Any, Dict
|
3 |
|
4 |
import pdfplumber
|
5 |
+
from pdf2image.pdf2image import convert_from_path
|
6 |
|
7 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
8 |
+
BaseImageLoader,
|
9 |
+
)
|
10 |
|
11 |
|
12 |
class PDFPlumberImageLoader(BaseImageLoader):
|
|
|
21 |
```python
|
22 |
import asyncio
|
23 |
|
|
|
|
|
|
|
24 |
from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
|
25 |
|
26 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
27 |
+
|
|
|
28 |
loader = PDFPlumberImageLoader(
|
29 |
+
url=URL,
|
30 |
document_name="Gray's Anatomy",
|
31 |
document_file_path="grays_anatomy.pdf",
|
32 |
)
|
33 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
```
|
35 |
|
36 |
Args:
|
|
|
84 |
extracted_image.save(image_file_path, "png")
|
85 |
image_file_paths.append(image_file_path)
|
86 |
|
87 |
+
page_image = convert_from_path(
|
88 |
+
self.document_file_path,
|
89 |
+
first_page=page_idx + 1,
|
90 |
+
last_page=page_idx + 1,
|
91 |
+
**kwargs,
|
92 |
+
)[0]
|
93 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
94 |
+
|
95 |
return {
|
96 |
"page_idx": page_idx,
|
97 |
"document_name": self.document_name,
|
medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py
CHANGED
@@ -3,9 +3,12 @@ import os
|
|
3 |
from typing import Any, Dict
|
4 |
|
5 |
import fitz
|
|
|
6 |
from PIL import Image
|
7 |
|
8 |
-
from .base_img_loader import
|
|
|
|
|
9 |
|
10 |
|
11 |
class PyMuPDFImageLoader(BaseImageLoader):
|
@@ -20,27 +23,16 @@ class PyMuPDFImageLoader(BaseImageLoader):
|
|
20 |
```python
|
21 |
import asyncio
|
22 |
|
23 |
-
import weave
|
24 |
-
|
25 |
-
import wandb
|
26 |
from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
31 |
loader = PyMuPDFImageLoader(
|
32 |
-
url=
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
-
asyncio.run(
|
37 |
-
loader.load_data(
|
38 |
-
start_page=32,
|
39 |
-
end_page=37,
|
40 |
-
wandb_artifact_name="grays-anatomy-images-pymupdf",
|
41 |
-
cleanup=False,
|
42 |
-
)
|
43 |
-
)
|
44 |
```
|
45 |
|
46 |
Args:
|
@@ -115,6 +107,14 @@ class PyMuPDFImageLoader(BaseImageLoader):
|
|
115 |
|
116 |
pdf_document.close()
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
return {
|
119 |
"page_idx": page_idx,
|
120 |
"document_name": self.document_name,
|
|
|
3 |
from typing import Any, Dict
|
4 |
|
5 |
import fitz
|
6 |
+
from pdf2image.pdf2image import convert_from_path
|
7 |
from PIL import Image
|
8 |
|
9 |
+
from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
|
10 |
+
BaseImageLoader,
|
11 |
+
)
|
12 |
|
13 |
|
14 |
class PyMuPDFImageLoader(BaseImageLoader):
|
|
|
23 |
```python
|
24 |
import asyncio
|
25 |
|
|
|
|
|
|
|
26 |
from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
|
27 |
|
28 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
29 |
+
|
|
|
30 |
loader = PyMuPDFImageLoader(
|
31 |
+
url=URL,
|
32 |
document_name="Gray's Anatomy",
|
33 |
document_file_path="grays_anatomy.pdf",
|
34 |
)
|
35 |
+
dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
```
|
37 |
|
38 |
Args:
|
|
|
107 |
|
108 |
pdf_document.close()
|
109 |
|
110 |
+
page_image = convert_from_path(
|
111 |
+
self.document_file_path,
|
112 |
+
first_page=page_idx + 1,
|
113 |
+
last_page=page_idx + 1,
|
114 |
+
**kwargs,
|
115 |
+
)[0]
|
116 |
+
page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
|
117 |
+
|
118 |
return {
|
119 |
"page_idx": page_idx,
|
120 |
"document_name": self.document_name,
|
medrag_multi_modal/document_loader/text_loader/base_text_loader.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
from abc import ABC, abstractmethod
|
4 |
-
from typing import
|
5 |
|
|
|
6 |
import PyPDF2
|
7 |
-
import
|
8 |
-
import weave
|
9 |
from firerequests import FireRequests
|
|
|
10 |
|
11 |
|
12 |
class BaseTextLoader(ABC):
|
@@ -22,14 +23,22 @@ class BaseTextLoader(ABC):
|
|
22 |
url (str): The URL of the PDF file to download if not present locally.
|
23 |
document_name (str): The name of the document for metadata purposes.
|
24 |
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
|
|
25 |
"""
|
26 |
|
27 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
self.url = url
|
29 |
self.document_name = document_name
|
30 |
self.document_file_path = document_file_path
|
|
|
31 |
if not os.path.exists(self.document_file_path):
|
32 |
-
FireRequests().download(url,
|
33 |
with open(self.document_file_path, "rb") as file:
|
34 |
pdf_reader = PyPDF2.PdfReader(file)
|
35 |
self.page_count = len(pdf_reader.pages)
|
@@ -85,9 +94,11 @@ class BaseTextLoader(ABC):
|
|
85 |
self,
|
86 |
start_page: Optional[int] = None,
|
87 |
end_page: Optional[int] = None,
|
88 |
-
|
|
|
|
|
89 |
**kwargs,
|
90 |
-
) ->
|
91 |
"""
|
92 |
Asynchronously loads text from a PDF file specified by a URL or local file path.
|
93 |
The overrided processing abstract method then processes the text into markdown format,
|
@@ -102,23 +113,26 @@ class BaseTextLoader(ABC):
|
|
102 |
each page, extract the text from the PDF, and convert it to markdown.
|
103 |
It processes pages concurrently using `asyncio` for efficiency.
|
104 |
|
105 |
-
If a
|
106 |
|
107 |
Args:
|
108 |
start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
|
109 |
end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
|
110 |
-
|
|
|
|
|
111 |
**kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
|
112 |
|
113 |
Returns:
|
114 |
-
|
115 |
-
Each
|
116 |
|
117 |
- "text": (str) the processed page data in markdown format.
|
118 |
- "page_idx": (int) the index of the page.
|
119 |
- "document_name": (str) the name of the document.
|
120 |
- "file_path": (str) the local file path where the PDF is stored.
|
121 |
- "file_url": (str) the URL of the PDF file.
|
|
|
122 |
|
123 |
Raises:
|
124 |
ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
|
@@ -127,21 +141,45 @@ class BaseTextLoader(ABC):
|
|
127 |
pages = []
|
128 |
processed_pages_counter: int = 1
|
129 |
total_pages = end_page - start_page
|
|
|
130 |
|
131 |
async def process_page(page_idx):
|
132 |
nonlocal processed_pages_counter
|
133 |
page_data = await self.extract_page_data(page_idx, **kwargs)
|
134 |
page_data["loader_name"] = self.__class__.__name__
|
|
|
|
|
|
|
135 |
pages.append(page_data)
|
136 |
-
|
137 |
-
|
|
|
|
|
138 |
)
|
139 |
processed_pages_counter += 1
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
from abc import ABC, abstractmethod
|
4 |
+
from typing import Any, Dict, Optional
|
5 |
|
6 |
+
import huggingface_hub
|
7 |
import PyPDF2
|
8 |
+
from datasets import Dataset, concatenate_datasets, load_dataset
|
|
|
9 |
from firerequests import FireRequests
|
10 |
+
from rich.progress import Progress
|
11 |
|
12 |
|
13 |
class BaseTextLoader(ABC):
|
|
|
23 |
url (str): The URL of the PDF file to download if not present locally.
|
24 |
document_name (str): The name of the document for metadata purposes.
|
25 |
document_file_path (str): The local file path where the PDF is stored or will be downloaded.
|
26 |
+
metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset.
|
27 |
"""
|
28 |
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
url: str,
|
32 |
+
document_name: str,
|
33 |
+
document_file_path: str,
|
34 |
+
metadata: Optional[dict[str, Any]] = None,
|
35 |
+
):
|
36 |
self.url = url
|
37 |
self.document_name = document_name
|
38 |
self.document_file_path = document_file_path
|
39 |
+
self.metadata = metadata or {}
|
40 |
if not os.path.exists(self.document_file_path):
|
41 |
+
FireRequests().download(url, filenames=self.document_file_path)
|
42 |
with open(self.document_file_path, "rb") as file:
|
43 |
pdf_reader = PyPDF2.PdfReader(file)
|
44 |
self.page_count = len(pdf_reader.pages)
|
|
|
94 |
self,
|
95 |
start_page: Optional[int] = None,
|
96 |
end_page: Optional[int] = None,
|
97 |
+
exclude_pages: Optional[list[int]] = None,
|
98 |
+
dataset_repo_id: Optional[str] = None,
|
99 |
+
overwrite_dataset: bool = False,
|
100 |
**kwargs,
|
101 |
+
) -> Dataset:
|
102 |
"""
|
103 |
Asynchronously loads text from a PDF file specified by a URL or local file path.
|
104 |
The overrided processing abstract method then processes the text into markdown format,
|
|
|
113 |
each page, extract the text from the PDF, and convert it to markdown.
|
114 |
It processes pages concurrently using `asyncio` for efficiency.
|
115 |
|
116 |
+
If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
|
117 |
|
118 |
Args:
|
119 |
start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
|
120 |
end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
|
121 |
+
exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing.
|
122 |
+
dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
|
123 |
+
overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
|
124 |
**kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
|
125 |
|
126 |
Returns:
|
127 |
+
Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages.
|
128 |
+
Each entry in the dataset will have the following keys and values:
|
129 |
|
130 |
- "text": (str) the processed page data in markdown format.
|
131 |
- "page_idx": (int) the index of the page.
|
132 |
- "document_name": (str) the name of the document.
|
133 |
- "file_path": (str) the local file path where the PDF is stored.
|
134 |
- "file_url": (str) the URL of the PDF file.
|
135 |
+
- "loader_name": (str) the name of the loader class used to process the page.
|
136 |
|
137 |
Raises:
|
138 |
ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
|
|
|
141 |
pages = []
|
142 |
processed_pages_counter: int = 1
|
143 |
total_pages = end_page - start_page
|
144 |
+
exclude_pages = exclude_pages or []
|
145 |
|
146 |
async def process_page(page_idx):
|
147 |
nonlocal processed_pages_counter
|
148 |
page_data = await self.extract_page_data(page_idx, **kwargs)
|
149 |
page_data["loader_name"] = self.__class__.__name__
|
150 |
+
for key, value in self.metadata.items():
|
151 |
+
if key not in page_data:
|
152 |
+
page_data[key] = value
|
153 |
pages.append(page_data)
|
154 |
+
progress.update(
|
155 |
+
task_id,
|
156 |
+
advance=1,
|
157 |
+
description=f"Loading page {page_idx} using {self.__class__.__name__}",
|
158 |
)
|
159 |
processed_pages_counter += 1
|
160 |
|
161 |
+
progress = Progress()
|
162 |
+
with progress:
|
163 |
+
task_id = progress.add_task("Starting...", total=total_pages)
|
164 |
+
tasks = [
|
165 |
+
process_page(page_idx)
|
166 |
+
for page_idx in range(start_page, end_page + 1)
|
167 |
+
if page_idx not in exclude_pages
|
168 |
+
]
|
169 |
+
for task in asyncio.as_completed(tasks):
|
170 |
+
await task
|
171 |
+
|
172 |
+
pages.sort(key=lambda x: x["page_idx"])
|
173 |
+
|
174 |
+
dataset = Dataset.from_list(pages)
|
175 |
+
if dataset_repo_id:
|
176 |
+
if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
|
177 |
+
print("Dataset already exists")
|
178 |
+
if not overwrite_dataset:
|
179 |
+
print("Not overwriting dataset")
|
180 |
+
dataset = concatenate_datasets(
|
181 |
+
[dataset, load_dataset(dataset_repo_id, split="corpus")]
|
182 |
+
)
|
183 |
+
dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False)
|
184 |
+
|
185 |
+
return dataset
|
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py
CHANGED
@@ -4,7 +4,9 @@ from typing import Dict
|
|
4 |
from marker.convert import convert_single_pdf
|
5 |
from marker.models import load_all_models
|
6 |
|
7 |
-
from .base_text_loader import
|
|
|
|
|
8 |
|
9 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
10 |
|
@@ -26,24 +28,16 @@ class MarkerTextLoader(BaseTextLoader):
|
|
26 |
```python
|
27 |
import asyncio
|
28 |
|
29 |
-
import
|
30 |
|
31 |
-
|
32 |
|
33 |
-
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
34 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
35 |
loader = MarkerTextLoader(
|
36 |
-
url=
|
37 |
document_name="Gray's Anatomy",
|
38 |
document_file_path="grays_anatomy.pdf",
|
39 |
)
|
40 |
-
asyncio.run(
|
41 |
-
loader.load_data(
|
42 |
-
start_page=31,
|
43 |
-
end_page=36,
|
44 |
-
weave_dataset_name="grays-anatomy-text",
|
45 |
-
)
|
46 |
-
)
|
47 |
```
|
48 |
|
49 |
Args:
|
@@ -76,7 +70,7 @@ class MarkerTextLoader(BaseTextLoader):
|
|
76 |
"""
|
77 |
model_lst = load_all_models()
|
78 |
|
79 |
-
text, _,
|
80 |
self.document_file_path,
|
81 |
model_lst,
|
82 |
max_pages=1,
|
@@ -92,5 +86,4 @@ class MarkerTextLoader(BaseTextLoader):
|
|
92 |
"document_name": self.document_name,
|
93 |
"file_path": self.document_file_path,
|
94 |
"file_url": self.url,
|
95 |
-
"meta": out_meta,
|
96 |
}
|
|
|
4 |
from marker.convert import convert_single_pdf
|
5 |
from marker.models import load_all_models
|
6 |
|
7 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
8 |
+
BaseTextLoader,
|
9 |
+
)
|
10 |
|
11 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
12 |
|
|
|
28 |
```python
|
29 |
import asyncio
|
30 |
|
31 |
+
from medrag_multi_modal.document_loader import MarkerTextLoader
|
32 |
|
33 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
34 |
|
|
|
|
|
35 |
loader = MarkerTextLoader(
|
36 |
+
url=URL,
|
37 |
document_name="Gray's Anatomy",
|
38 |
document_file_path="grays_anatomy.pdf",
|
39 |
)
|
40 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
```
|
42 |
|
43 |
Args:
|
|
|
70 |
"""
|
71 |
model_lst = load_all_models()
|
72 |
|
73 |
+
text, _, _ = convert_single_pdf(
|
74 |
self.document_file_path,
|
75 |
model_lst,
|
76 |
max_pages=1,
|
|
|
86 |
"document_name": self.document_name,
|
87 |
"file_path": self.document_file_path,
|
88 |
"file_url": self.url,
|
|
|
89 |
}
|
medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py
CHANGED
@@ -2,7 +2,9 @@ from typing import Dict
|
|
2 |
|
3 |
import pdfplumber
|
4 |
|
5 |
-
from .base_text_loader import
|
|
|
|
|
6 |
|
7 |
|
8 |
class PDFPlumberTextLoader(BaseTextLoader):
|
@@ -22,24 +24,16 @@ class PDFPlumberTextLoader(BaseTextLoader):
|
|
22 |
```python
|
23 |
import asyncio
|
24 |
|
25 |
-
import
|
26 |
|
27 |
-
|
28 |
|
29 |
-
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
30 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
31 |
loader = PDFPlumberTextLoader(
|
32 |
-
url=
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
-
asyncio.run(
|
37 |
-
loader.load_data(
|
38 |
-
start_page=31,
|
39 |
-
end_page=36,
|
40 |
-
weave_dataset_name="grays-anatomy-text",
|
41 |
-
)
|
42 |
-
)
|
43 |
```
|
44 |
|
45 |
Args:
|
|
|
2 |
|
3 |
import pdfplumber
|
4 |
|
5 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
6 |
+
BaseTextLoader,
|
7 |
+
)
|
8 |
|
9 |
|
10 |
class PDFPlumberTextLoader(BaseTextLoader):
|
|
|
24 |
```python
|
25 |
import asyncio
|
26 |
|
27 |
+
from medrag_multi_modal.document_loader import PDFPlumberTextLoader
|
28 |
|
29 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
|
|
|
|
|
31 |
loader = PDFPlumberTextLoader(
|
32 |
+
url=URL,
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
```
|
38 |
|
39 |
Args:
|
medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py
CHANGED
@@ -2,7 +2,9 @@ from typing import Dict
|
|
2 |
|
3 |
import pymupdf4llm
|
4 |
|
5 |
-
from .base_text_loader import
|
|
|
|
|
6 |
|
7 |
|
8 |
class PyMuPDF4LLMTextLoader(BaseTextLoader):
|
@@ -20,26 +22,16 @@ class PyMuPDF4LLMTextLoader(BaseTextLoader):
|
|
20 |
```python
|
21 |
import asyncio
|
22 |
|
23 |
-
import
|
24 |
|
25 |
-
|
26 |
-
PyMuPDF4LLMTextLoader
|
27 |
-
)
|
28 |
|
29 |
-
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
30 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
31 |
loader = PyMuPDF4LLMTextLoader(
|
32 |
-
url=
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
-
asyncio.run(
|
37 |
-
loader.load_data(
|
38 |
-
start_page=31,
|
39 |
-
end_page=36,
|
40 |
-
weave_dataset_name="grays-anatomy-text",
|
41 |
-
)
|
42 |
-
)
|
43 |
```
|
44 |
|
45 |
Args:
|
|
|
2 |
|
3 |
import pymupdf4llm
|
4 |
|
5 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
6 |
+
BaseTextLoader,
|
7 |
+
)
|
8 |
|
9 |
|
10 |
class PyMuPDF4LLMTextLoader(BaseTextLoader):
|
|
|
22 |
```python
|
23 |
import asyncio
|
24 |
|
25 |
+
from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader
|
26 |
|
27 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
|
|
|
|
28 |
|
|
|
|
|
29 |
loader = PyMuPDF4LLMTextLoader(
|
30 |
+
url=URL,
|
31 |
document_name="Gray's Anatomy",
|
32 |
document_file_path="grays_anatomy.pdf",
|
33 |
)
|
34 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
```
|
36 |
|
37 |
Args:
|
medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py
CHANGED
@@ -2,7 +2,9 @@ from typing import Dict
|
|
2 |
|
3 |
import PyPDF2
|
4 |
|
5 |
-
from .base_text_loader import
|
|
|
|
|
6 |
|
7 |
|
8 |
class PyPDF2TextLoader(BaseTextLoader):
|
@@ -22,24 +24,16 @@ class PyPDF2TextLoader(BaseTextLoader):
|
|
22 |
```python
|
23 |
import asyncio
|
24 |
|
25 |
-
import
|
26 |
|
27 |
-
|
28 |
|
29 |
-
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
30 |
-
url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
31 |
loader = PyPDF2TextLoader(
|
32 |
-
url=
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
-
asyncio.run(
|
37 |
-
loader.load_data(
|
38 |
-
start_page=31,
|
39 |
-
end_page=36,
|
40 |
-
weave_dataset_name="grays-anatomy-text",
|
41 |
-
)
|
42 |
-
)
|
43 |
```
|
44 |
|
45 |
Args:
|
|
|
2 |
|
3 |
import PyPDF2
|
4 |
|
5 |
+
from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
|
6 |
+
BaseTextLoader,
|
7 |
+
)
|
8 |
|
9 |
|
10 |
class PyPDF2TextLoader(BaseTextLoader):
|
|
|
24 |
```python
|
25 |
import asyncio
|
26 |
|
27 |
+
from medrag_multi_modal.document_loader import PyPDF2TextLoader
|
28 |
|
29 |
+
URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
|
30 |
|
|
|
|
|
31 |
loader = PyPDF2TextLoader(
|
32 |
+
url=URL,
|
33 |
document_name="Gray's Anatomy",
|
34 |
document_file_path="grays_anatomy.pdf",
|
35 |
)
|
36 |
+
dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
```
|
38 |
|
39 |
Args:
|
medrag_multi_modal/metrics/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .mmlu import MMLUOptionAccuracy
|
2 |
+
|
3 |
+
__all__ = ["MMLUOptionAccuracy"]
|
medrag_multi_modal/metrics/base.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import weave
|
5 |
+
|
6 |
+
|
7 |
+
class BaseAccuracyMetric(weave.Scorer):
|
8 |
+
"""
|
9 |
+
BaseAccuracyMetric is a class that extends the
|
10 |
+
[`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers)
|
11 |
+
to provide a comprehensive evaluation of accuracy metrics for a given set of score rows.
|
12 |
+
|
13 |
+
This class is designed to process a list of score rows, each containing a
|
14 |
+
'correct' key that indicates whether a particular prediction was correct.
|
15 |
+
The `summarize` method calculates various statistical measures and metrics
|
16 |
+
based on this data, including:
|
17 |
+
|
18 |
+
- True and false counts: The number of true and false predictions.
|
19 |
+
- True and false fractions: The proportion of true and false predictions.
|
20 |
+
- Standard error: The standard error of the mean for the true predictions.
|
21 |
+
- Precision: The ratio of true positive predictions to the total number of
|
22 |
+
positive predictions.
|
23 |
+
- Recall: The ratio of true positive predictions to the total number of
|
24 |
+
actual positives.
|
25 |
+
- F1 Score: The harmonic mean of precision and recall, providing a balance
|
26 |
+
between the two metrics.
|
27 |
+
|
28 |
+
The `summarize` method returns a dictionary containing these metrics,
|
29 |
+
allowing for a detailed analysis of the model's performance.
|
30 |
+
|
31 |
+
Methods:
|
32 |
+
summarize(score_rows: list) -> Optional[dict]:
|
33 |
+
Processes the input score rows to compute and return a dictionary
|
34 |
+
of accuracy metrics.
|
35 |
+
"""
|
36 |
+
@weave.op()
|
37 |
+
def summarize(self, score_rows: list) -> Optional[dict]:
|
38 |
+
"""
|
39 |
+
Summarizes the accuracy metrics from a list of score rows.
|
40 |
+
|
41 |
+
This method processes a list of score rows, each containing a 'correct' key
|
42 |
+
that indicates whether a particular prediction was correct. It calculates
|
43 |
+
various statistical measures and metrics based on this data, including:
|
44 |
+
|
45 |
+
- True and false counts: The number of true and false predictions.
|
46 |
+
- True and false fractions: The proportion of true and false predictions.
|
47 |
+
- Standard error: The standard error of the mean for the true predictions.
|
48 |
+
- Precision: The ratio of true positive predictions to the total number of
|
49 |
+
positive predictions.
|
50 |
+
- Recall: The ratio of true positive predictions to the total number of
|
51 |
+
actual positives.
|
52 |
+
- F1 Score: The harmonic mean of precision and recall, providing a balance
|
53 |
+
between the two metrics.
|
54 |
+
|
55 |
+
The method returns a dictionary containing these metrics, allowing for a
|
56 |
+
detailed analysis of the model's performance.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
score_rows (list): A list of dictionaries, each containing a 'correct'
|
60 |
+
key with a boolean value indicating the correctness of a prediction.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Optional[dict]: A dictionary containing the calculated accuracy metrics,
|
64 |
+
or None if the input list is empty.
|
65 |
+
"""
|
66 |
+
valid_data = [
|
67 |
+
x.get("correct") for x in score_rows if x.get("correct") is not None
|
68 |
+
]
|
69 |
+
count_true = list(valid_data).count(True)
|
70 |
+
int_data = [int(x) for x in valid_data]
|
71 |
+
|
72 |
+
sample_mean = np.mean(int_data) if int_data else 0
|
73 |
+
sample_variance = np.var(int_data) if int_data else 0
|
74 |
+
sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
|
75 |
+
|
76 |
+
# Calculate precision, recall, and F1 score
|
77 |
+
true_positives = count_true
|
78 |
+
false_positives = len(valid_data) - count_true
|
79 |
+
false_negatives = len(score_rows) - len(valid_data)
|
80 |
+
|
81 |
+
precision = (
|
82 |
+
true_positives / (true_positives + false_positives)
|
83 |
+
if (true_positives + false_positives) > 0
|
84 |
+
else 0
|
85 |
+
)
|
86 |
+
recall = (
|
87 |
+
true_positives / (true_positives + false_negatives)
|
88 |
+
if (true_positives + false_negatives) > 0
|
89 |
+
else 0
|
90 |
+
)
|
91 |
+
f1_score = (
|
92 |
+
(2 * precision * recall) / (precision + recall)
|
93 |
+
if (precision + recall) > 0
|
94 |
+
else 0
|
95 |
+
)
|
96 |
+
|
97 |
+
return {
|
98 |
+
"correct": {
|
99 |
+
"true_count": count_true,
|
100 |
+
"false_count": len(score_rows) - count_true,
|
101 |
+
"true_fraction": float(sample_mean),
|
102 |
+
"false_fraction": 1.0 - float(sample_mean),
|
103 |
+
"stderr": float(sample_error),
|
104 |
+
"precision": precision,
|
105 |
+
"recall": recall,
|
106 |
+
"f1_score": f1_score,
|
107 |
+
}
|
108 |
+
}
|
medrag_multi_modal/metrics/mmlu.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import weave
|
2 |
+
|
3 |
+
from medrag_multi_modal.assistant.schema import MedQAResponse
|
4 |
+
from medrag_multi_modal.metrics.base import BaseAccuracyMetric
|
5 |
+
|
6 |
+
|
7 |
+
class MMLUOptionAccuracy(BaseAccuracyMetric):
|
8 |
+
"""
|
9 |
+
MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`.
|
10 |
+
|
11 |
+
This class is designed to evaluate the accuracy of a multiple-choice question
|
12 |
+
response by comparing the provided answer with the correct answer from the
|
13 |
+
given options. It uses the MedQAResponse schema to extract the response
|
14 |
+
and checks if it matches the correct answer.
|
15 |
+
|
16 |
+
Methods:
|
17 |
+
--------
|
18 |
+
score(output: MedQAResponse, options: list[str], answer: str) -> dict:
|
19 |
+
Compares the provided answer with the correct answer and returns a
|
20 |
+
dictionary indicating whether the answer is correct.
|
21 |
+
"""
|
22 |
+
@weave.op()
|
23 |
+
def score(self, output: MedQAResponse, options: list[str], answer: str):
|
24 |
+
return {"correct": options[answer] == output.response.answer}
|
medrag_multi_modal/retrieval/__init__.py
CHANGED
@@ -1,15 +1,3 @@
|
|
1 |
-
from .bm25s_retrieval import BM25sRetriever
|
2 |
from .colpali_retrieval import CalPaliRetriever
|
3 |
-
from .common import SimilarityMetric
|
4 |
-
from .contriever_retrieval import ContrieverRetriever
|
5 |
-
from .medcpt_retrieval import MedCPTRetriever
|
6 |
-
from .nv_embed_2 import NVEmbed2Retriever
|
7 |
|
8 |
-
__all__ = [
|
9 |
-
"CalPaliRetriever",
|
10 |
-
"BM25sRetriever",
|
11 |
-
"ContrieverRetriever",
|
12 |
-
"SimilarityMetric",
|
13 |
-
"MedCPTRetriever",
|
14 |
-
"NVEmbed2Retriever",
|
15 |
-
]
|
|
|
|
|
1 |
from .colpali_retrieval import CalPaliRetriever
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
__all__ = ["CalPaliRetriever"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
medrag_multi_modal/retrieval/colpali_retrieval.py
CHANGED
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
|
9 |
import wandb
|
10 |
from PIL import Image
|
11 |
|
12 |
-
from
|
13 |
|
14 |
|
15 |
class CalPaliRetriever(weave.Model):
|
|
|
9 |
import wandb
|
10 |
from PIL import Image
|
11 |
|
12 |
+
from medrag_multi_modal.utils import get_wandb_artifact
|
13 |
|
14 |
|
15 |
class CalPaliRetriever(weave.Model):
|
medrag_multi_modal/retrieval/common.py
CHANGED
@@ -1,10 +1,5 @@
|
|
1 |
from enum import Enum
|
2 |
|
3 |
-
import safetensors
|
4 |
-
import safetensors.torch
|
5 |
-
import torch
|
6 |
-
import wandb
|
7 |
-
|
8 |
|
9 |
class SimilarityMetric(Enum):
|
10 |
COSINE = "cosine"
|
@@ -24,21 +19,3 @@ def argsort_scores(scores: list[float], descending: bool = False):
|
|
24 |
list(enumerate(scores)), key=lambda x: x[1], reverse=descending
|
25 |
)
|
26 |
]
|
27 |
-
|
28 |
-
|
29 |
-
def save_vector_index(
|
30 |
-
vector_index: torch.Tensor,
|
31 |
-
type: str,
|
32 |
-
index_name: str,
|
33 |
-
metadata: dict,
|
34 |
-
filename: str = "vector_index.safetensors",
|
35 |
-
):
|
36 |
-
safetensors.torch.save_file({"vector_index": vector_index.cpu()}, filename)
|
37 |
-
if wandb.run:
|
38 |
-
artifact = wandb.Artifact(
|
39 |
-
name=index_name,
|
40 |
-
type=type,
|
41 |
-
metadata=metadata,
|
42 |
-
)
|
43 |
-
artifact.add_file(filename)
|
44 |
-
artifact.save()
|
|
|
1 |
from enum import Enum
|
2 |
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
class SimilarityMetric(Enum):
|
5 |
COSINE = "cosine"
|
|
|
19 |
list(enumerate(scores)), key=lambda x: x[1], reverse=descending
|
20 |
)
|
21 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
medrag_multi_modal/retrieval/text_retrieval/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bm25s_retrieval import BM25sRetriever
|
2 |
+
from .contriever_retrieval import ContrieverRetriever
|
3 |
+
from .medcpt_retrieval import MedCPTRetriever
|
4 |
+
from .nv_embed_2 import NVEmbed2Retriever
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
"BM25sRetriever",
|
8 |
+
"ContrieverRetriever",
|
9 |
+
"MedCPTRetriever",
|
10 |
+
"NVEmbed2Retriever",
|
11 |
+
]
|
medrag_multi_modal/retrieval/{bm25s_retrieval.py → text_retrieval/bm25s_retrieval.py}
RENAMED
@@ -1,12 +1,17 @@
|
|
|
|
1 |
import os
|
2 |
-
|
3 |
-
from typing import Optional
|
4 |
|
5 |
import bm25s
|
6 |
-
import
|
7 |
import weave
|
|
|
|
|
8 |
from Stemmer import Stemmer
|
9 |
|
|
|
|
|
10 |
LANGUAGE_DICT = {
|
11 |
"english": "en",
|
12 |
"french": "fr",
|
@@ -26,49 +31,60 @@ class BM25sRetriever(weave.Model):
|
|
26 |
a new instance is created.
|
27 |
"""
|
28 |
|
29 |
-
language: str
|
30 |
-
use_stemmer: bool
|
31 |
-
_retriever: Optional[
|
32 |
|
33 |
def __init__(
|
34 |
self,
|
35 |
language: str = "english",
|
36 |
use_stemmer: bool = True,
|
37 |
-
retriever: Optional[
|
38 |
):
|
39 |
super().__init__(language=language, use_stemmer=use_stemmer)
|
40 |
-
self._retriever = retriever or
|
41 |
|
42 |
-
def index(
|
|
|
|
|
|
|
|
|
|
|
43 |
"""
|
44 |
Indexes a dataset of text chunks using the BM25 algorithm.
|
45 |
|
46 |
-
This
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
51 |
!!! example "Example Usage"
|
52 |
```python
|
53 |
import weave
|
54 |
from dotenv import load_dotenv
|
55 |
|
56 |
-
import
|
57 |
-
from medrag_multi_modal.retrieval import BM25sRetriever
|
58 |
|
59 |
load_dotenv()
|
60 |
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
61 |
-
wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index")
|
62 |
retriever = BM25sRetriever()
|
63 |
-
retriever.index(
|
|
|
|
|
|
|
64 |
```
|
65 |
|
66 |
Args:
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
"""
|
71 |
-
chunk_dataset =
|
|
|
|
|
|
|
|
|
72 |
corpus = [row["text"] for row in chunk_dataset]
|
73 |
corpus_tokens = bm25s.tokenize(
|
74 |
corpus,
|
@@ -76,28 +92,40 @@ class BM25sRetriever(weave.Model):
|
|
76 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
77 |
)
|
78 |
self._retriever.index(corpus_tokens)
|
79 |
-
if
|
|
|
|
|
80 |
self._retriever.save(
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
82 |
)
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
type="bm25s-index",
|
87 |
-
metadata={
|
88 |
"language": self.language,
|
89 |
"use_stemmer": self.use_stemmer,
|
90 |
},
|
|
|
|
|
91 |
)
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
@classmethod
|
96 |
-
def
|
97 |
"""
|
98 |
-
Creates an instance of the class from a
|
99 |
|
100 |
-
This class method retrieves a BM25 index artifact from
|
101 |
downloads the artifact, and loads the BM25 retriever with the index and its
|
102 |
associated corpus. The method also extracts metadata from the artifact to
|
103 |
initialize the class instance with the appropriate language and stemming
|
@@ -108,41 +136,26 @@ class BM25sRetriever(weave.Model):
|
|
108 |
import weave
|
109 |
from dotenv import load_dotenv
|
110 |
|
111 |
-
from medrag_multi_modal.retrieval import BM25sRetriever
|
112 |
|
113 |
load_dotenv()
|
114 |
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
115 |
-
retriever = BM25sRetriever
|
116 |
-
|
117 |
-
)
|
118 |
```
|
119 |
|
120 |
Args:
|
121 |
-
|
122 |
-
containing the BM25 index.
|
123 |
|
124 |
Returns:
|
125 |
An instance of the class initialized with the BM25 retriever and metadata
|
126 |
from the artifact.
|
127 |
"""
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
)
|
132 |
-
|
133 |
-
else:
|
134 |
-
api = wandb.Api()
|
135 |
-
artifact = api.artifact(index_artifact_address)
|
136 |
-
artifact_dir = artifact.download()
|
137 |
-
retriever = bm25s.BM25.load(
|
138 |
-
glob(os.path.join(artifact_dir, "*"))[0], load_corpus=True
|
139 |
-
)
|
140 |
-
metadata = artifact.metadata
|
141 |
-
return cls(
|
142 |
-
language=metadata["language"],
|
143 |
-
use_stemmer=metadata["use_stemmer"],
|
144 |
-
retriever=retriever,
|
145 |
-
)
|
146 |
|
147 |
@weave.op()
|
148 |
def retrieve(self, query: str, top_k: int = 2):
|
@@ -155,6 +168,20 @@ class BM25sRetriever(weave.Model):
|
|
155 |
The results are returned as a list of dictionaries, each containing a chunk and
|
156 |
its corresponding relevance score.
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
Args:
|
159 |
query (str): The input query string to search for relevant chunks.
|
160 |
top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
|
@@ -192,13 +219,12 @@ class BM25sRetriever(weave.Model):
|
|
192 |
import weave
|
193 |
from dotenv import load_dotenv
|
194 |
|
195 |
-
from medrag_multi_modal.retrieval import BM25sRetriever
|
196 |
|
197 |
load_dotenv()
|
198 |
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
199 |
-
retriever = BM25sRetriever
|
200 |
-
|
201 |
-
)
|
202 |
retrieved_chunks = retriever.predict(query="What are Ribosomes?")
|
203 |
```
|
204 |
|
|
|
1 |
+
import json
|
2 |
import os
|
3 |
+
import shutil
|
4 |
+
from typing import Optional, Union
|
5 |
|
6 |
import bm25s
|
7 |
+
import huggingface_hub
|
8 |
import weave
|
9 |
+
from bm25s import BM25
|
10 |
+
from datasets import Dataset, load_dataset
|
11 |
from Stemmer import Stemmer
|
12 |
|
13 |
+
from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface
|
14 |
+
|
15 |
LANGUAGE_DICT = {
|
16 |
"english": "en",
|
17 |
"french": "fr",
|
|
|
31 |
a new instance is created.
|
32 |
"""
|
33 |
|
34 |
+
language: Optional[str]
|
35 |
+
use_stemmer: bool = True
|
36 |
+
_retriever: Optional[BM25]
|
37 |
|
38 |
def __init__(
|
39 |
self,
|
40 |
language: str = "english",
|
41 |
use_stemmer: bool = True,
|
42 |
+
retriever: Optional[BM25] = None,
|
43 |
):
|
44 |
super().__init__(language=language, use_stemmer=use_stemmer)
|
45 |
+
self._retriever = retriever or BM25()
|
46 |
|
47 |
+
def index(
|
48 |
+
self,
|
49 |
+
chunk_dataset: Union[Dataset, str],
|
50 |
+
index_repo_id: Optional[str] = None,
|
51 |
+
cleanup: bool = True,
|
52 |
+
):
|
53 |
"""
|
54 |
Indexes a dataset of text chunks using the BM25 algorithm.
|
55 |
|
56 |
+
This method retrieves a dataset of text chunks from a specified source, tokenizes
|
57 |
+
the text using the BM25 tokenizer with optional stemming, and indexes the tokenized
|
58 |
+
text using the BM25 retriever. If an `index_repo_id` is provided, the index is saved
|
59 |
+
to disk and optionally logged as a Huggingface artifact.
|
60 |
|
61 |
!!! example "Example Usage"
|
62 |
```python
|
63 |
import weave
|
64 |
from dotenv import load_dotenv
|
65 |
|
66 |
+
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
|
|
|
67 |
|
68 |
load_dotenv()
|
69 |
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
|
|
70 |
retriever = BM25sRetriever()
|
71 |
+
retriever.index(
|
72 |
+
chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
|
73 |
+
index_repo_id="geekyrakshit/grays-anatomy-index",
|
74 |
+
)
|
75 |
```
|
76 |
|
77 |
Args:
|
78 |
+
chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
|
79 |
+
dataset repository name or a dataset object can be provided.
|
80 |
+
index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
|
81 |
+
cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
|
82 |
"""
|
83 |
+
chunk_dataset = (
|
84 |
+
load_dataset(chunk_dataset, split="chunks")
|
85 |
+
if isinstance(chunk_dataset, str)
|
86 |
+
else chunk_dataset
|
87 |
+
)
|
88 |
corpus = [row["text"] for row in chunk_dataset]
|
89 |
corpus_tokens = bm25s.tokenize(
|
90 |
corpus,
|
|
|
92 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
93 |
)
|
94 |
self._retriever.index(corpus_tokens)
|
95 |
+
if index_repo_id:
|
96 |
+
os.makedirs(".huggingface", exist_ok=True)
|
97 |
+
index_save_dir = os.path.join(".huggingface", index_repo_id.split("/")[-1])
|
98 |
self._retriever.save(
|
99 |
+
index_save_dir, corpus=[dict(row) for row in chunk_dataset]
|
100 |
+
)
|
101 |
+
commit_type = (
|
102 |
+
"update"
|
103 |
+
if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
|
104 |
+
else "add"
|
105 |
)
|
106 |
+
with open(os.path.join(index_save_dir, "config.json"), "w") as config_file:
|
107 |
+
json.dump(
|
108 |
+
{
|
|
|
|
|
109 |
"language": self.language,
|
110 |
"use_stemmer": self.use_stemmer,
|
111 |
},
|
112 |
+
config_file,
|
113 |
+
indent=4,
|
114 |
)
|
115 |
+
save_to_huggingface(
|
116 |
+
index_repo_id,
|
117 |
+
index_save_dir,
|
118 |
+
commit_message=f"{commit_type}: BM25s index",
|
119 |
+
)
|
120 |
+
if cleanup:
|
121 |
+
shutil.rmtree(index_save_dir)
|
122 |
|
123 |
@classmethod
|
124 |
+
def from_index(cls, index_repo_id: str):
|
125 |
"""
|
126 |
+
Creates an instance of the class from a Huggingface repository.
|
127 |
|
128 |
+
This class method retrieves a BM25 index artifact from a Huggingface repository,
|
129 |
downloads the artifact, and loads the BM25 retriever with the index and its
|
130 |
associated corpus. The method also extracts metadata from the artifact to
|
131 |
initialize the class instance with the appropriate language and stemming
|
|
|
136 |
import weave
|
137 |
from dotenv import load_dotenv
|
138 |
|
139 |
+
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
|
140 |
|
141 |
load_dotenv()
|
142 |
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
143 |
+
retriever = BM25sRetriever()
|
144 |
+
retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
|
|
|
145 |
```
|
146 |
|
147 |
Args:
|
148 |
+
index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
|
|
|
149 |
|
150 |
Returns:
|
151 |
An instance of the class initialized with the BM25 retriever and metadata
|
152 |
from the artifact.
|
153 |
"""
|
154 |
+
index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
|
155 |
+
retriever = bm25s.BM25.load(index_dir, load_corpus=True)
|
156 |
+
with open(os.path.join(index_dir, "config.json"), "r") as config_file:
|
157 |
+
config = json.load(config_file)
|
158 |
+
return cls(retriever=retriever, **config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
@weave.op()
|
161 |
def retrieve(self, query: str, top_k: int = 2):
|
|
|
168 |
The results are returned as a list of dictionaries, each containing a chunk and
|
169 |
its corresponding relevance score.
|
170 |
|
171 |
+
!!! example "Example Usage"
|
172 |
+
```python
|
173 |
+
import weave
|
174 |
+
from dotenv import load_dotenv
|
175 |
+
|
176 |
+
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
|
177 |
+
|
178 |
+
load_dotenv()
|
179 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
180 |
+
retriever = BM25sRetriever()
|
181 |
+
retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
|
182 |
+
retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
|
183 |
+
```
|
184 |
+
|
185 |
Args:
|
186 |
query (str): The input query string to search for relevant chunks.
|
187 |
top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
|
|
|
219 |
import weave
|
220 |
from dotenv import load_dotenv
|
221 |
|
222 |
+
from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
|
223 |
|
224 |
load_dotenv()
|
225 |
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
226 |
+
retriever = BM25sRetriever()
|
227 |
+
retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
|
|
|
228 |
retrieved_chunks = retriever.predict(query="What are Ribosomes?")
|
229 |
```
|
230 |
|