Feat/retrieval sys improvements (#30)
* Add the bm25 library;

* Update the embedding model URL;

* Add the token-based chunking strategy;

* Add context to the chunks;

* Add hybrid retrieval and re-ranker methods to improve retrieval performance (sketched below);

* Increase the timeout period;
ranjan-stha authored and thenav56 committed Nov 11, 2024
1 parent af6c07c commit 774f295
Showing 9 changed files with 891 additions and 620 deletions.
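
The hybrid retrieval and re-ranking code itself sits in files that are not expanded on this page, so the following is a minimal sketch of how BM25 scores, dense-vector scores, and a cross-encoder re-ranker are commonly combined. The `hybrid_search` helper, the blending weight `alpha`, and the `cross-encoder/ms-marco-MiniLM-L-6-v2` model name are illustrative assumptions, not the commit's actual implementation.

```python
# Sketch only: blends lexical (BM25) and dense (cosine) relevance, then re-ranks
# a small candidate pool with a cross-encoder. Names marked "assumed" are not
# taken from this commit.
from typing import List, Tuple

import numpy as np
from rank_bm25 import BM25Okapi                  # the bm25 library added by this commit
from sentence_transformers import CrossEncoder   # assumed re-ranker backend


def hybrid_search(
    query: str,
    query_vector: np.ndarray,
    docs: List[str],
    doc_vectors: np.ndarray,
    alpha: float = 0.5,
    top_k: int = 5,
) -> List[Tuple[str, float]]:
    """Return the top_k documents after hybrid scoring and re-ranking."""
    # Lexical scores over whitespace-tokenized documents.
    bm25 = BM25Okapi([d.lower().split() for d in docs])
    lexical = bm25.get_scores(query.lower().split())

    # Dense scores as cosine similarity against pre-computed chunk embeddings.
    dense = doc_vectors @ query_vector / (
        np.linalg.norm(doc_vectors, axis=1) * np.linalg.norm(query_vector) + 1e-9
    )

    # Min-max normalize each signal so the two scales can be blended.
    def norm(x: np.ndarray) -> np.ndarray:
        return (x - x.min()) / (x.max() - x.min() + 1e-9)

    blended = alpha * norm(lexical) + (1 - alpha) * norm(dense)
    candidates = np.argsort(blended)[::-1][: top_k * 3]

    # Re-rank only the candidate pool to keep latency manageable (model name assumed).
    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = reranker.predict([(query, docs[i]) for i in candidates])
    ranked = sorted(zip(candidates, scores), key=lambda p: p[1], reverse=True)[:top_k]
    return [(docs[i], float(s)) for i, s in ranked]
```

Blending normalized lexical and dense scores keeps exact keyword matches competitive with purely semantic matches, and the cross-encoder pass re-orders only a handful of candidates rather than the whole collection.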
42 changes: 18 additions & 24 deletions chatbotcore/contextual_chunks.py
@@ -1,18 +1,19 @@
import logging
from enum import Enum
from typing import List, Any
from dataclasses import dataclass, field
from django.conf import settings
from enum import Enum
from typing import Any, List

from langchain_community.llms.ollama import Ollama
from django.conf import settings
from langchain.schema import Document
from langchain_community.llms.ollama import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from chatbotcore.utils import LLMType

logger = logging.getLogger(__name__)


@dataclass
class OpenAIHandler:
"""LLM handler using OpenAI for RAG"""
@@ -26,6 +27,7 @@ def __post_init__(self):
except Exception as e:
raise Exception(f"OpenAI LLM model is not successfully loaded. {str(e)}")


@dataclass
class OllamaHandler:
"""LLM Handler using Ollama"""
@@ -36,17 +38,16 @@ class OllamaHandler:
def __post_init__(self):
try:
self.llm = Ollama(
model=settings.LLM_MODEL_NAME,
base_url=settings.LLM_OLLAMA_BASE_URL,
temperature=self.temperature
model=settings.LLM_MODEL_NAME, base_url=settings.LLM_OLLAMA_BASE_URL, temperature=self.temperature
)
except Exception as e:
raise Exception(f"Ollama LLM model is not successfully loaded. {str(e)}")


@dataclass
class ContextualChunking:
""" Context retrieval for the chunk documents """
"""Context retrieval for the chunk documents"""

model: Any = field(init=False)
model_type: Enum = LLMType.OLLAMA

@@ -60,9 +61,9 @@ def __post_init__(self):
raise ValueError("Wront LLM Type")

def get_prompt(self):
""" Creates a prompt """
prompt = """
You are an AI assistant specializing in Human Resources data processing in a company.
"""Creates a prompt"""
prompt = """
You are an AI assistant who can generate a short context of the chunk text from the document.
Here is the document:
<document>
{document}
@@ -73,30 +74,23 @@ def get_prompt(self):
{chunk}
</chunk>
Please give a short succinct context using maximum 20 words to situate this chunk within the overall document\n
Please give a short succinct context (within 30 tokens) to situate this chunk within the overall document\n
for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else.
Context:
"""
return prompt

def _generate_context(self, document: str, chunk: str):
""" Generates contextualized document chunk response """
"""Generates contextualized document chunk response"""
prompt_template = ChatPromptTemplate.from_messages([("system", self.get_prompt())])
messages = prompt_template.format_messages(
document=document,
chunk=chunk
)
messages = prompt_template.format_messages(document=document, chunk=chunk)
response = self.model.llm.invoke(messages)
return response

def generate_contextualized_chunks(self, document: str, chunks: List[Document]):
""" Generates contextualized document chunks """
"""Generates contextualized document chunks"""
contextualized_chunks = []
for chunk in chunks:
context = self._generate_context(document, chunk.page_content)
contextualized_content = f"{context}\n\n\n{chunk.page_content}"
contextualized_chunks.append(
Document(page_content=contextualized_content, metadata=chunk.metadata)
)
contextualized_content = f"""{context.strip()}. {chunk.page_content.strip()}"""
contextualized_chunks.append(Document(page_content=contextualized_content, metadata=chunk.metadata))
return contextualized_chunks
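
For orientation, a hedged usage sketch of the contextualizer above: it asks the configured LLM for a short context per chunk and prepends it to the chunk text before indexing. The sample document and chunk are invented for illustration; the settings it depends on (LLM_MODEL_NAME, LLM_OLLAMA_BASE_URL) are the ones referenced in the diff.

```python
# Usage sketch, assuming Django settings are configured for the Ollama handler.
from langchain.schema import Document

from chatbotcore.contextual_chunks import ContextualChunking

contextualizer = ContextualChunking()  # defaults to the Ollama-backed handler
document = "Full policy handbook text ..."  # illustrative input
chunks = [Document(page_content="Employees accrue 1.5 leave days per month.")]

enriched = contextualizer.generate_contextualized_chunks(document=document, chunks=chunks)
# Each enriched chunk now reads roughly: "<generated context>. Employees accrue ..."
print(enriched[0].page_content)
```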
2 changes: 1 addition & 1 deletion chatbotcore/custom_embeddings.py
@@ -27,7 +27,7 @@ def __post_init__(self):
if not (self.url and self.model_name and self.base_url):
raise Exception("Url or base_url or both are not provided.")

def embed_query(self, text: str, timeout: int = 30) -> List[float]:
def embed_query(self, text: str, timeout: int = 240) -> List[float]:
"""
Sends the request to Embedding module to
embed the query to the vector representation
64 changes: 63 additions & 1 deletion chatbotcore/database.py
@@ -1,10 +1,11 @@
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any
from typing import Any, List

import qdrant_client.http.models as q_models
from django.conf import settings
from langchain.schema import Document
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import (
@@ -58,6 +59,29 @@ def store_data(self, data: list) -> None:
response = self.db_client.upsert(collection_name=self.collection_name, points=point_vectors)
return response

def retrieve_vectors(self, points: List[str]) -> List[List[float]]:
"""Retrieve vectors"""
retrieved_data = self.db_client.retrieve(
collection_name=self.collection_name, ids=points, with_vectors=True, with_payload=False
)
return [v.vector for v in retrieved_data]

def search_vectors_by_id(self, uuid_to_search: str):
"""
Search data vectors by id
"""
filter_condition = Filter(must=[FieldCondition(key="_id", match=MatchValue(value=uuid_to_search))])
results = self.db_client.search(
collection_name=self.collection_name,
query_vector=[0.0] * settings.EMBEDDING_MODEL_VECTOR_SIZE,
query_filter=filter_condition,
limit=1,
)
if results:
vector = results[0].vector
return vector
return None

def data_search(
self, collection_names: list, query_vector: list, top_n_retrieval: int = 5, score_threshold: float = 0.7
):
@@ -83,3 +107,41 @@ def delete_data_by_src_uuid(self, collection_name: str, key: str, value: Any) ->
result = self.db_client.delete(collection_name=collection_name, points_selector=points_selector)

return result.status == q_models.UpdateStatus.COMPLETED

def convert_record_to_document(self, records):
"""
Converts Record type to Document type
"""
documents = []
for record in records:
page_content = record.payload.get("page_content", "") # Adjust this to match your payload structure
if page_content:
page_content = page_content.replace("\n", "").strip()
metadata = {k: v for k, v in record.payload.items() if k != "text"} # All other metadata
metadata["_id"] = record.id
# Create a LangChain Document
doc = Document(
page_content=page_content,
metadata=metadata,
)
documents.append(doc)
return documents

def load_all_documents(self):
"""Load all the documents"""
all_docs = []
offset = 0
limit = 10_000

while True:
response = self.db_client.scroll(
collection_name=self.collection_name, offset=offset, limit=limit, with_payload=True, with_vectors=False
)
documents = self.convert_record_to_document(response[0])
offset = response[-1]

all_docs.extend(documents)

if len(documents) < limit:
break
return all_docs
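
A plausible consumer of load_all_documents() is an in-memory BM25 index for the lexical half of the hybrid retriever; the retriever wiring is not shown in this excerpt, so the snippet below is a guess at that usage. The QdrantDatabase class name and the example query are assumptions.

```python
# Sketch: build a BM25 index over all stored chunk payloads.
from rank_bm25 import BM25Okapi

from chatbotcore.database import QdrantDatabase  # class name assumed, not shown in this excerpt

db = QdrantDatabase()
docs = db.load_all_documents()  # LangChain Documents rebuilt from Qdrant payloads
bm25 = BM25Okapi([d.page_content.lower().split() for d in docs])
scores = bm25.get_scores("annual leave policy".lower().split())
```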
59 changes: 44 additions & 15 deletions chatbotcore/doc_loaders.py
@@ -1,33 +1,67 @@
import json
from dataclasses import dataclass, field
from typing import List

import requests
from django.conf import settings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader

from chatbotcore.contextual_chunks import ContextualChunking


@dataclass(kw_only=True)
class DocumentLoader:
"""
Base Class for Document Loaders
"""

chunk_size: int = 100
chunk_overlap: int = 20
chunk_overlap: int = 40
context_retrieval: ContextualChunking = field(init=False)

def __post_init__(self):
self.context_retrieval = ContextualChunking()

def _get_split_documents(self, documents: List[Document]):
def _get_split_documents_with_recursive_char(self, documents: List[Document], multiplier: int = 3):
"""
Splits documents into multiple chunks
Splits documents into multiple chunks using Recursive Character splitter
"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, length_function=len
chunk_size=self.chunk_size * multiplier, chunk_overlap=self.chunk_overlap * multiplier, length_function=len
)
return splitter.split_documents(documents=documents)

def langchain_document_to_dict(self, doc: Document):
"""
Converts langchain Document to dictionary
"""
return {"page_content": doc.page_content, "metadata": doc.metadata}

def dict_to_langchain_document(self, doc: dict):
"""
Converts dictionary to Langchain document
"""
return Document(page_content=doc["page_content"], metadata=doc["metadata"])

return splitter.split_documents(documents)
def _get_split_documents_using_token_based(self, documents: List[Document], timeout: int = 60):
"""
Splits documents into multiple chunks using Sentence Transformer
token-based splitting.
"""
url = f"{settings.EMBEDDING_MODEL_URL}/split_docs_based_on_tokens"
documents_dict = [self.langchain_document_to_dict(d) for d in documents]
payload = {
"model": settings.EMBEDDING_MODEL_NAME,
"documents": json.dumps(documents_dict),
"chunk_size": self.chunk_size,
"chunk_overlap": self.chunk_overlap,
}
headers = {"Content-Type": "application/json"}
response = requests.post(url=url, headers=headers, json=payload, timeout=timeout)
data = response.json()
return [self.dict_to_langchain_document(d) for d in data]


@dataclass
@@ -43,11 +77,9 @@ def create_document_chunks(self):
Creates multiple documents from the input texts
"""
documents = [Document(page_content=self.text)]
doc_chunks = self._get_split_documents(documents=documents)
contextualized_chunks = self.context_retrieval.generate_contextualized_chunks(
document=self.text,
chunks=doc_chunks
)
# doc_chunks = self._get_split_documents_using_token_based(documents=documents)
doc_chunks = self._get_split_documents_with_recursive_char(documents=documents)
contextualized_chunks = self.context_retrieval.generate_contextualized_chunks(document=self.text, chunks=doc_chunks)
return contextualized_chunks


@@ -65,9 +97,6 @@ def create_document_chunks(self):
"""
loader = WebBaseLoader(web_path=self.url)
docs = loader.load()
doc_chunks = self._get_split_documents(documents=docs)
contextualized_chunks = self.context_retrieval.generate_contextualized_chunks(
document=docs,
chunks=doc_chunks
)
doc_chunks = self._get_split_documents_using_token_based(documents=docs)
contextualized_chunks = self.context_retrieval.generate_contextualized_chunks(document=docs, chunks=doc_chunks)
return contextualized_chunks
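
The token-based splitter above delegates to a separate embedding service (EMBEDDING_MODEL_URL), which is not part of this diff. As a rough idea of what a compatible /split_docs_based_on_tokens handler could do server-side, here is a sketch built on LangChain's sentence-transformers token splitter; the function name, parameter mapping, and response shape are assumptions inferred from the client code.

```python
# Server-side sketch only; the real embedding service is a separate codebase.
from typing import List

from langchain.schema import Document
from langchain.text_splitter import SentenceTransformersTokenTextSplitter


def split_docs_based_on_tokens(
    documents: List[Document], model_name: str, chunk_size: int, chunk_overlap: int
) -> List[dict]:
    splitter = SentenceTransformersTokenTextSplitter(
        model_name=model_name,
        tokens_per_chunk=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = splitter.split_documents(documents)
    # Mirror the dict shape the client's dict_to_langchain_document() expects.
    return [{"page_content": c.page_content, "metadata": c.metadata} for c in chunks]
```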