Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions chat/retrievers/rules_bot_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langchain.schema import BaseRetriever
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever


class RulesBotRetriever(BaseRetriever):
Expand All @@ -15,18 +15,20 @@ class RulesBotRetriever(BaseRetriever):
index: FAISS
search_kwargs: dict

def get_relevant_documents(self, question):
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
):
# TODO: Debug relevancy score and figure out if we should filter at a threshold
docs_with_score = self.index.similarity_search_with_relevance_scores(
question, **self.search_kwargs
query, **self.search_kwargs
)

docs = []
for doc, score in docs_with_score:
doc.metadata["relevancy_score"] = score
docs.append(doc)

if self._is_setup_question(question):
if self._is_setup_question(query):
# If the setup page is not already in the results, add it
if not any(doc.metadata.get("setup_page") for doc in docs):
setup_documents = self.index.similarity_search(
Expand All @@ -39,11 +41,6 @@ def get_relevant_documents(self, question):

return docs

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
):
return self.get_relevant_documents(query)

def _is_setup_question(self, question):
question = question.lower()
return (
Expand Down
12 changes: 7 additions & 5 deletions chat/services/streaming_question_answering_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from enum import Enum, auto

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.schema.messages import AIMessage, HumanMessage
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_classic.chains.history_aware_retriever import (
create_history_aware_retriever,
)
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI

Expand Down
24 changes: 11 additions & 13 deletions chat/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from django.contrib.auth.models import User
from django.test import TestCase
from django.urls import reverse
from langchain.schema.messages import AIMessage, HumanMessage
from langchain_community.embeddings.fake import DeterministicFakeEmbedding
from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document
from langchain_core.messages import AIMessage, HumanMessage

from chat.models import ChatSession
from chat.retrievers.rules_bot_retriever import RulesBotRetriever
Expand Down Expand Up @@ -203,9 +203,7 @@ def test_happy_path(self):
embedding=DeterministicFakeEmbedding(size=1536),
)

docs = RulesBotRetriever(
index=index, search_kwargs={"k": 3}
).get_relevant_documents("clue")
docs = RulesBotRetriever(index=index, search_kwargs={"k": 3}).invoke("clue")

self.assertEqual(len(docs), 3)
self.assertEqual(docs[0].metadata["page"], 44)
Expand All @@ -218,9 +216,9 @@ def test_setup_question_special_case(self):
embedding=DeterministicFakeEmbedding(size=1536),
)

docs = RulesBotRetriever(
index=index, search_kwargs={"k": 3}
).get_relevant_documents("how many pieces do you start with?")
docs = RulesBotRetriever(index=index, search_kwargs={"k": 3}).invoke(
"how many pieces do you start with?"
)

self.assertEqual(len(docs), 3) # Still only returns 3 results
self.assertEqual(
Expand All @@ -234,9 +232,9 @@ def test_setup_question_no_special_case(self):
embedding=DeterministicFakeEmbedding(size=1536),
)

docs = RulesBotRetriever(
index=index, search_kwargs={"k": 3}
).get_relevant_documents("instructions")
docs = RulesBotRetriever(index=index, search_kwargs={"k": 3}).invoke(
"instructions"
)

self.assertEqual(len(docs), 3)
self.assertEqual(
Expand All @@ -256,9 +254,9 @@ def test_setup_question_special_case_no_setup_page(self):
embedding=DeterministicFakeEmbedding(size=1536),
)

docs = RulesBotRetriever(
index=index, search_kwargs={"k": 3}
).get_relevant_documents("how many pieces do you start with?")
docs = RulesBotRetriever(index=index, search_kwargs={"k": 3}).invoke(
"how many pieces do you start with?"
)

self.assertEqual(len(docs), 3) # Return 3 results
# None are a setup page because there are no setup pages
Expand Down
4 changes: 2 additions & 2 deletions games/loaders/pdf_loader_and_summarizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

from rulesbot.settings import DEFAULT_CHATGPT_MODEL

Expand Down
3 changes: 2 additions & 1 deletion games/services/document_ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def _valid_pdf(filename):
Check that the file is a valid PDF file
"""
try:
PdfReader(filename)
with open(filename, "rb") as pdf_file:
PdfReader(pdf_file)
return True
except:
return False
2 changes: 1 addition & 1 deletion games/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_ingest_setup_pages(self):
"games.loaders.pdf_loader_and_summarizer.ChatOpenAI"
) as chat_opem_ai_mock:
llm = mock.Mock(spec=ChatOpenAI)
llm.predict.return_value = "some summarized text"
llm.invoke.return_value = mock.Mock(content="some summarized text")
chat_opem_ai_mock.return_value = llm

ingest_document(document)
Expand Down
86 changes: 80 additions & 6 deletions games/vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import importlib
import sys
from functools import lru_cache

from django.conf import settings
Expand All @@ -8,6 +10,17 @@

EMBEDDING_LENGTH = 1536

LEGACY_LANGCHAIN_MODULE_ALIASES = [
    # Each entry is a (legacy module path, replacement module path) pair.
    #
    # Compatibility aliases for FAISS blobs serialized before the LangChain 1.x
    # migration. The pairs are consumed by
    # _register_legacy_langchain_module_aliases, which registers the replacement
    # module in sys.modules under the legacy name before an index is deserialized.
    #
    # Once all persisted indexes have been rebuilt on 1.x, this alias map can be
    # removed.
    ("langchain.docstore", "langchain_classic.docstore"),
    ("langchain.docstore.base", "langchain_classic.docstore.base"),
    ("langchain.docstore.document", "langchain_classic.docstore.document"),
    ("langchain.docstore.in_memory", "langchain_classic.docstore.in_memory"),
    ("langchain.schema", "langchain_classic.schema"),
    ("langchain.schema.document", "langchain_classic.schema.document"),
]

# Embedding used when none is supplied: a deterministic fake while
# settings.TESTING is set, otherwise the OpenAI embeddings client. The fake is
# sized from EMBEDDING_LENGTH so the dimension is defined in exactly one place.
DEFAULT_EMBEDDING = (
    DeterministicFakeEmbedding(size=EMBEDDING_LENGTH)
    if settings.TESTING
    else OpenAIEmbeddings()
)
Expand Down Expand Up @@ -64,17 +77,78 @@ def _try_load_index(self):
If the index exists, load it. Otherwise return None
"""
try:
self.game.faiss_file.open()
return FAISS.deserialize_from_bytes(
self.game.faiss_file.read(),
self.embedding,
allow_dangerous_deserialization=True,
)
self.game.faiss_file.open("rb")
try:
data = self.game.faiss_file.read()
finally:
self.game.faiss_file.close()

self._register_legacy_langchain_module_aliases()

try:
return FAISS.deserialize_from_bytes(
data,
self.embedding,
allow_dangerous_deserialization=True,
)
except ModuleNotFoundError as e:
if self._register_legacy_langchain_module_alias(e.name):
return FAISS.deserialize_from_bytes(
data,
self.embedding,
allow_dangerous_deserialization=True,
)
raise
except ValueError as e:
if "attribute has no file associated with it." in str(e):
return None
raise e

@classmethod
def _register_legacy_langchain_module_aliases(cls):
    """
    Register every alias listed in LEGACY_LANGCHAIN_MODULE_ALIASES.

    Each (legacy name, replacement) pair is handed to
    _register_legacy_langchain_module_alias, with the replacement passed as
    the preferred target. Return values are ignored.
    """
    register = cls._register_legacy_langchain_module_alias
    for legacy_name, replacement in LEGACY_LANGCHAIN_MODULE_ALIASES:
        register(legacy_name, preferred_target=replacement)

@staticmethod
def _register_legacy_langchain_module_alias(module_name, preferred_target=None):
    """
    Make a legacy ``langchain.*`` module path resolvable via sys.modules.

    The first importable candidate module is registered in ``sys.modules``
    under ``module_name``. Candidates are tried in priority order:

    1. ``preferred_target``, when given;
    2. for ``langchain.docstore*`` names, the ``langchain_classic.docstore``
       and then ``langchain_community.docstore`` equivalents;
    3. the same path with ``langchain.`` swapped for ``langchain_classic.``
       (this also covers ``langchain.schema*`` names).

    Args:
        module_name: Legacy dotted module path, e.g. the ``name`` of a
            ModuleNotFoundError. Falsy values and names that do not start
            with ``langchain.`` are ignored.
        preferred_target: Optional dotted module path to try first.

    Returns:
        True if ``module_name`` is resolvable afterwards (it was already
        loaded, or an alias was registered); False otherwise.
    """
    if not module_name or not module_name.startswith("langchain."):
        return False

    # Never clobber a module that is already loaded under this name: either
    # the real module is importable or an alias was registered earlier.
    if sys.modules.get(module_name) is not None:
        return True

    candidates = []

    if preferred_target is not None:
        candidates.append(preferred_target)

    if module_name.startswith("langchain.docstore"):
        for package in ("langchain_classic", "langchain_community"):
            candidates.append(
                module_name.replace(
                    "langchain.docstore", f"{package}.docstore", 1
                )
            )

    candidates.append(module_name.replace("langchain.", "langchain_classic.", 1))

    # dict.fromkeys drops repeated candidates while keeping priority order, so
    # a path that already failed to import is not attempted a second time.
    for candidate in dict.fromkeys(candidates):
        try:
            replacement = importlib.import_module(candidate)
        except ModuleNotFoundError:
            continue
        sys.modules[module_name] = replacement
        return True

    return False

def _persist_index(self):
"""
Persist the index to storage
Expand Down
Loading
Loading