Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions chat/services/agentic_streaming_question_answering_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import numbers
from enum import Enum, auto

from langchain.agents import create_agent
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from chat.retrievers.rules_bot_retriever import RulesBotRetriever
from rulesbot.settings import DEFAULT_CHATGPT_MODEL

# System prompt for the question-answering agent. The %%GAME%% placeholder is
# substituted with the chat session's game name in
# _build_question_answering_agent before the prompt is handed to the agent.
prompt_template = """Please use the available tools to provide a clear and accurate answer to questions regarding the rules of %%GAME%%.
Always use the rulebook search tool before answering rule questions.
Explain your answer in detail using the rulebook information you found.
If the question is not a question but a greeting or a thank you, kindly respond with a greeting or a thank you.
If the question is claiming that the answer is wrong, kindly respond with an apology.
Ignore any variant or optional rules unless specifically instructed not to.
"""


class QueueSignals(Enum):
    """Terminal markers that ask_question puts on the response queue."""

    # Put once the answer and its source documents have been stored.
    job_done = auto()
    # Put when answering fails, right before the exception is re-raised.
    error = auto()


def ask_question(question, chat_session, response_queue):
    """
    Ask a question in the chat session and add response to the chat session.

    The question is stored as a "human" message, the agent's answer is
    streamed through ``response_queue`` while it is generated, and the final
    answer is stored as an "ai" message with one source-document row per
    retrieved document.

    A terminal signal is put on the queue in every outcome:
    ``QueueSignals.job_done`` after everything was stored, or
    ``QueueSignals.error`` if any step failed.

    Args:
        question: The user's question text.
        chat_session: The chat session to read history from and write to.
        response_queue: Queue receiving streamed answer text and the
            terminal signal.

    Returns:
        The same ``response_queue`` that was passed in.

    Raises:
        Exception: Whatever a failing step raised, re-raised after the error
            signal has been queued.
    """
    try:
        chat_session.message_set.create(message=question, message_type="human")
        result = _query_agentic_stream(question, chat_session, response_queue)

        ai_message = chat_session.message_set.create(
            message=result["answer"], message_type="ai"
        )
        for source_document in result["context"]:
            ai_message.sourcedocument_set.create(
                document_id=source_document.metadata["document_id"],
                page_number=source_document.metadata["page"] + 1,  # 0-indexed
            )
    except Exception:
        # Persistence is covered as well as the agent call: a reader that
        # waits on the queue for a terminal signal would otherwise block
        # forever when storing a message or source row fails.
        response_queue.put(QueueSignals.error)
        raise

    response_queue.put(QueueSignals.job_done)
    return response_queue


def _query_agentic_stream(question, chat_session, response_queue):
    """Answer one question with the agent and report which documents it used.

    Builds the rulebook search tool and the agent for this chat session,
    streams the answer through ``response_queue``, and returns a dict with:

    * ``"answer"``: the final answer text.
    * ``"context"``: the documents returned by the search tool during the
      run, de-duplicated.
    """
    # The search tool appends to this list each time the agent calls it.
    collected_documents = []
    search_tool = _build_rulebook_search_tool(chat_session, collected_documents)
    qa_agent = _build_question_answering_agent(chat_session, [search_tool])

    final_answer = _stream_agent_answer(
        agent=qa_agent,
        question=question,
        chat_session=chat_session,
        response_queue=response_queue,
    )
    unique_documents = _deduplicate_documents(collected_documents)

    return {"answer": final_answer, "context": unique_documents}


def _build_question_answering_agent(chat_session, tools):
    """Create the question-answering agent for the session's game.

    The system prompt is ``prompt_template`` with its ``%%GAME%%`` placeholder
    replaced by the game's name; the chat model is configured for streaming.

    Args:
        chat_session: Chat session whose ``game.name`` personalizes the prompt.
        tools: Tools the agent is allowed to call.

    Returns:
        The agent returned by ``create_agent``.
    """
    system_prompt = prompt_template.replace("%%GAME%%", chat_session.game.name)
    streaming_model = ChatOpenAI(
        model=DEFAULT_CHATGPT_MODEL,
        temperature=0.1,
        streaming=True,
    )
    agent = create_agent(
        model=streaming_model,
        tools=tools,
        system_prompt=system_prompt,
        name="rulesbot_question_answering_agent",
    )
    return agent


def _build_rulebook_search_tool(chat_session, retrieved_documents):
    """Create the ``rulebook_search`` tool for the session's game.

    Args:
        chat_session: Chat session whose game's vector store index is searched.
        retrieved_documents: List that is extended in place with the
            documents of every search, so the caller can see what was
            retrieved after the agent has run.

    Returns:
        The ``rulebook_search`` tool.
    """
    game_index = chat_session.game.vector_store.index
    rulebook_retriever = RulesBotRetriever(
        index=game_index,
        search_kwargs={"k": 3},
    )

    @tool("rulebook_search")
    def rulebook_search(query: str) -> str:
        """Search the current game's rulebook and return relevant passages with page references."""
        matches = rulebook_retriever.invoke(query)
        # Record the hits for the caller before formatting them for the model.
        retrieved_documents.extend(matches)
        return _format_rulebook_context(matches)

    return rulebook_search


def _format_rulebook_context(documents):
if len(documents) == 0:
return "No relevant rulebook passages were found."

passages = []
for i, document in enumerate(documents, start=1):
page_number = document.metadata.get("page")
page_display = page_number + 1 if isinstance(page_number, int) else "unknown"
document_id = document.metadata.get("document_id", "unknown")
relevancy_score = document.metadata.get("relevancy_score")
relevancy_text = (
f"{relevancy_score:.3f}"
if isinstance(relevancy_score, float)
else "unknown"
)
passages.append(
"\n".join(
[
f"Passage {i}",
f"Document ID: {document_id}",
f"Page: {page_display}",
f"Relevancy: {relevancy_text}",
document.page_content,
]
)
)

return "\n\n".join(passages)


def _stream_agent_answer(agent, question, chat_session, response_queue):
    """Stream the agent's reply into ``response_queue`` and return the answer.

    The agent is given the session's recent chat history; the question is
    appended unless the history already ends with that exact human message.
    Non-empty text of every streamed ``AIMessageChunk`` is put on the queue
    as it arrives.

    Returns:
        The text of the last non-empty ``AIMessage`` reported in a "model"
        update, or the concatenation of the streamed text when no such
        update was seen.
    """
    history = _get_chat_history(chat_session)
    ends_with_question = (
        history
        and isinstance(history[-1], HumanMessage)
        and history[-1].content == question
    )
    if not ends_with_question:
        history.append(HumanMessage(content=question))

    streamed_parts = []
    final_answer = None

    events = agent.stream(
        {"messages": history},
        stream_mode=["messages", "updates"],
        version="v2",
    )
    for event in events:
        event_type = event.get("type")

        if event_type == "messages":
            # Token-level output: forward the text to the consumer right away.
            message_chunk, _ = event["data"]
            if isinstance(message_chunk, AIMessageChunk) and message_chunk.text:
                response_queue.put(message_chunk.text)
                streamed_parts.append(message_chunk.text)

        elif event_type == "updates":
            # Only updates keyed "model" that carry messages are of interest.
            for node_name, node_update in event["data"].items():
                if node_name != "model" or not isinstance(node_update, dict):
                    continue

                node_messages = node_update.get("messages", [])
                if not node_messages:
                    continue

                candidate = node_messages[-1]
                if isinstance(candidate, AIMessage) and candidate.text:
                    final_answer = candidate.text

    if final_answer is not None:
        return final_answer
    return "".join(streamed_parts)


def _deduplicate_documents(documents):
unique_documents = []
seen = set()
for document in documents:
key = (document.metadata.get("document_id"), document.metadata.get("page"))
if key in seen:
continue
seen.add(key)
unique_documents.append(document)

return unique_documents


def _get_chat_history(chat_session):
chat_history = []
latest_messages = chat_session.message_set.order_by("-created_at")[:12]
latest_messages_in_cronological_order = reversed(latest_messages)
for message in latest_messages_in_cronological_order:
if message.message_type == "human":
chat_history.append(HumanMessage(content=message.message))
elif message.message_type == "ai":
chat_history.append(AIMessage(content=message.message))
elif message.message_type == "system":
continue
return chat_history if chat_history else []
128 changes: 127 additions & 1 deletion chat/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain_community.vectorstores import FAISS
from langchain_core.documents.base import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage

from chat.models import ChatSession
from chat.retrievers.rules_bot_retriever import RulesBotRetriever
from chat.services import agentic_streaming_question_answering_service
from chat.services.streaming_question_answering_service import (
QueueSignals,
_get_chat_history,
Expand Down Expand Up @@ -263,3 +264,128 @@ def test_setup_question_special_case_no_setup_page(self):
self.assertEqual(
[doc.metadata.get("setup_page") for doc in docs], [None, None, None]
)


class AgenticStreamingQuestionAnsweringServiceTests(TestCase):
    # Tests for chat.services.agentic_streaming_question_answering_service.
    # The agent and the model are never called for real: they are replaced by
    # a mock or a hand-written fake in every test.

    def test_ask_question_persists_messages_and_sources(self):
        # With the agent query mocked, ask_question must store the human and
        # AI messages, link the source document, and signal job_done.
        game = Game.objects.create(name="Test Game")
        document = game.document_set.create(display_name="Rulebook", url="some-url")
        chat_session = ChatSession.objects.create(game=game)
        test_queue = SimpleQueue()

        source_document = Document(
            page_content="some content",
            metadata={"document_id": document.id, "page": 42},
        )

        with mock.patch(
            "chat.services.agentic_streaming_question_answering_service._query_agentic_stream"
        ) as mock_query:
            mock_query.return_value = {
                "answer": "some answer to the question",
                "context": [source_document],
            }

            agentic_streaming_question_answering_service.ask_question(
                "What is the meaning of life?", chat_session, test_queue
            )

        self.assertEqual(chat_session.message_set.count(), 2)
        self.assertEqual(chat_session.message_set.first().message_type, "human")
        self.assertEqual(chat_session.message_set.last().message_type, "ai")
        self.assertEqual(
            chat_session.message_set.last().message, "some answer to the question"
        )
        ai_message = chat_session.message_set.last()
        self.assertEqual(len(ai_message.sourcedocument_set.all()), 1)
        self.assertEqual(ai_message.sourcedocument_set.first().document, document)
        # Metadata page 42 is 0-indexed; the stored page number is 1-based.
        self.assertEqual(ai_message.sourcedocument_set.first().page_number, 43)
        # Nothing was streamed by the mock, so the first item is the signal.
        self.assertEqual(
            test_queue.get(),
            agentic_streaming_question_answering_service.QueueSignals.job_done,
        )

    def test_ask_question_adds_error_signal_on_exception(self):
        # A failing agent query must propagate the exception and still put
        # the error signal on the queue.
        game = Game.objects.create(name="Test Game")
        chat_session = ChatSession.objects.create(game=game)
        test_queue = SimpleQueue()

        with mock.patch(
            "chat.services.agentic_streaming_question_answering_service._query_agentic_stream",
            side_effect=RuntimeError("boom"),
        ):
            with self.assertRaises(RuntimeError):
                agentic_streaming_question_answering_service.ask_question(
                    "What happened?", chat_session, test_queue
                )

        self.assertEqual(
            test_queue.get(),
            agentic_streaming_question_answering_service.QueueSignals.error,
        )

    def test_stream_agent_answer_enqueues_text_tokens(self):
        # Streamed AI chunks must be queued in order, and the returned answer
        # must come from the final "model" update.
        game = Game.objects.create(name="Test Game")
        chat_session = ChatSession.objects.create(game=game)
        test_queue = SimpleQueue()

        class FakeAgent:
            # Stand-in agent whose stream() replays a fixed event sequence:
            # two token events followed by one "model" update.
            def stream(self, *args, **kwargs):
                return [
                    {
                        "type": "messages",
                        "data": (AIMessageChunk(content="some "), {}),
                    },
                    {
                        "type": "messages",
                        "data": (AIMessageChunk(content="answer"), {}),
                    },
                    {
                        "type": "updates",
                        "data": {
                            "model": {
                                "messages": [AIMessage(content="some answer")],
                            }
                        },
                    },
                ]

        answer = agentic_streaming_question_answering_service._stream_agent_answer(
            FakeAgent(),
            "What is the meaning of life?",
            chat_session,
            test_queue,
        )

        self.assertEqual(answer, "some answer")
        self.assertEqual(test_queue.get(), "some ")
        self.assertEqual(test_queue.get(), "answer")

    @prevent_warnings
    def test_rulebook_search_tool_returns_context_and_tracks_documents(self):
        # Invoking the search tool must return formatted passages and record
        # the retrieved document in the list handed to the tool builder.
        game = Game.objects.create(name="Test Game")
        document = game.document_set.create(display_name="Rulebook", url="some-url")
        chat_session = ChatSession.objects.create(game=game)

        chat_session.game.vector_store.add_documents(
            [
                Document(
                    page_content="Clue game instructions",
                    metadata={"page": 4},
                )
            ],
            document_id=document.id,
        )

        retrieved_documents = []
        rulebook_search_tool = (
            agentic_streaming_question_answering_service._build_rulebook_search_tool(
                chat_session, retrieved_documents
            )
        )

        result = rulebook_search_tool.invoke({"query": "clue"})

        self.assertIn("Passage 1", result)
        self.assertEqual(len(retrieved_documents), 1)
        self.assertEqual(retrieved_documents[0].metadata["document_id"], document.id)
10 changes: 6 additions & 4 deletions chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ads.services import serve_ad_with_impression
from chat.forms import ChatForm
from chat.models import ChatSession
from chat.services import streaming_question_answering_service
from chat.services import agentic_streaming_question_answering_service
from games.models import Game


Expand Down Expand Up @@ -116,15 +116,17 @@ def stream_from_queue():
while True:
result = response_queue.get()
if (
result is streaming_question_answering_service.QueueSignals.job_done
or result is streaming_question_answering_service.QueueSignals.error
result
is agentic_streaming_question_answering_service.QueueSignals.job_done
or result
is agentic_streaming_question_answering_service.QueueSignals.error
):
break
yield result

# run the async ask question function in a thread
qa_thread = Thread(
target=streaming_question_answering_service.ask_question,
target=agentic_streaming_question_answering_service.ask_question,
args=(question, chat_session, response_queue),
)
qa_thread.start()
Expand Down
2 changes: 1 addition & 1 deletion rulesbot/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,4 @@
# CUSTOM APP SETTINGS

# ChatGPT settings
DEFAULT_CHATGPT_MODEL = "gpt-4o-mini"
DEFAULT_CHATGPT_MODEL = "gpt-5.4-nano"
Loading
Loading