【多模態&RAG】多模態RAG ColPali實踐 原創
關于??【RAG&多模態】多模態RAG-ColPali:使用視覺語言模型實現高效的文檔檢索??前面已經介紹了(供參考),這次來看看ColPali實踐。
所需權重:
- 多模態問答模型:Qwen2-VL-72B-Instruct,https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct
- 基于 PaliGemma-3B 和 ColBERT 策略的視覺檢索器:
- ColPali(LoRA):https://huggingface.co/vidore/colpali
- ColPali(基座):https://huggingface.co/vidore/colpaligemma-3b-mix-448-base
多模態檢索問答實踐
- lora的adapter_config.json字段base_model_name_or_path修改地址:ColPali(基座)存儲路徑
- qwen_vl_utils下載地址:https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils/src/qwen_vl_utils
- byaldi安裝方式:https://github.com/AnswerDotAI/byaldi
- 完整代碼
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from pdf2image import convert_from_path
class DocumentQA:
def __init__(self, rag_model_name: str, vlm_model_name: str, device: str = 'cuda', system_prompt: str = None):
self.rag_engine = RAGMultiModalModel.from_pretrained(rag_model_name)
self.vlm = Qwen2VLForConditionalGeneration.from_pretrained(
vlm_model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map=device
)
self.processor = AutoProcessor.from_pretrained(vlm_model_name, trust_remote_code=True)
self.device = device
if system_prompt is None:
self.system_prompt = (
"你是一位專精于計算機科學和機器學習的AI研究助理。"
"你的任務是分析學術論文,尤其是關于文檔檢索和多模態模型的研究。"
"請仔細分析提供的圖像和文本,提供深入的見解和解釋。"
)
else:
self.system_prompt = system_prompt
def index_document(self, pdf_path: str, index_name: str = 'index', overwrite: bool = True):
self.pdf_path = pdf_path
self.rag_engine.index(
input_path=pdf_path,
index_name=index_name,
store_collection_with_index=False,
overwrite=overwrite
)
self.images = convert_from_path(pdf_path)
def query(self, text_query: str, k: int = 3) -> str:
results = self.rag_engine.search(text_query, k=k)
print("搜索結果:", results)
if not results:
print("未找到相關查詢結果。")
return None
try:
page_num = results[0]["page_num"]
image_index = page_num - 1
image = self.images[image_index]
except (KeyError, IndexError) as e:
print("獲取頁面圖像時出錯:", e)
return None
messages = [
{
"role": "system",
"content": self.system_prompt
},
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": text_query},
],
}
]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
# 準備模型輸入
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.device)
generated_ids = self.vlm.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0]
if __name__ == "__main__":
# 初始化 DocumentQA 實例
document_qa = DocumentQA(
rag_model_name="./colpali",
vlm_model_name="./Qwen2-VL-7B-Instruct",
device='cuda'
)
# 索引 PDF 文檔
document_qa.index_document("test.pdf")
# 定義查詢
text_query = (
"文中模型在哪個數據集上相比其他模型有最大的優勢?"
"該優勢的改進幅度是多少?"
)
# 執行查詢并打印答案
answer = document_qa.query(text_query)
print("答案:", answer)
本文轉載自公眾號大模型自然語言處理 作者:余俊暉
?著作權歸作者所有,如需轉載,請注明出處,否則將追究法律責任
贊
收藏
回復
分享
微博
QQ
微信
舉報

回復
相關推薦