diff --git a/chat/services/agentic_streaming_question_answering_service.py b/chat/services/agentic_streaming_question_answering_service.py new file mode 100644 index 0000000..5e47291 --- /dev/null +++ b/chat/services/agentic_streaming_question_answering_service.py @@ -0,0 +1,204 @@ +from enum import Enum, auto + +from langchain.agents import create_agent +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage +from langchain_core.tools import tool +from langchain_openai import ChatOpenAI + +from chat.retrievers.rules_bot_retriever import RulesBotRetriever +from rulesbot.settings import DEFAULT_CHATGPT_MODEL + +prompt_template = """Please use the available tools to provide a clear and accurate answer to questions regarding the rules of %%GAME%%. +Always use the rulebook search tool before answering rule questions. +Explain your answer in detail using the rulebook information you found. +If the question is not a question but a greeting or a thank you, kindly respond with a greeting or a thank you. +If the question is claiming that the answer is wrong, kindly respond with an apology. +Ignore any variant or optional rules unless specifically instructed not to. +""" + + +class QueueSignals(Enum): + job_done = auto() + error = auto() + + +def ask_question(question, chat_session, response_queue): + """ + Ask a question in the chat session and add response to the chat session. + """ + chat_session.message_set.create(message=question, message_type="human") + + try: + result = _query_agentic_stream(question, chat_session, response_queue) + except Exception: + response_queue.put(QueueSignals.error) + raise + + answer = result["answer"] + ai_message = chat_session.message_set.create(message=answer, message_type="ai") + for source_document in result["context"]: + ai_message.sourcedocument_set.create( + document_id=source_document.metadata["document_id"], + page_number=source_document.metadata["page"] + 1, # 0-indexed + ) + + response_queue.put(QueueSignals.job_done) + return response_queue + + +def _query_agentic_stream(question, chat_session, response_queue): + retrieved_documents = [] + rulebook_search_tool = _build_rulebook_search_tool( + chat_session, retrieved_documents + ) + agent = _build_question_answering_agent(chat_session, [rulebook_search_tool]) + + answer = _stream_agent_answer( + agent=agent, + question=question, + chat_session=chat_session, + response_queue=response_queue, + ) + + return { + "answer": answer, + "context": _deduplicate_documents(retrieved_documents), + } + + +def _build_question_answering_agent(chat_session, tools): + personalized_prompt_template = prompt_template.replace( + "%%GAME%%", chat_session.game.name + ) + + model = ChatOpenAI( + model=DEFAULT_CHATGPT_MODEL, + temperature=0.1, + streaming=True, + ) + + return create_agent( + model=model, + tools=tools, + system_prompt=personalized_prompt_template, + name="rulesbot_question_answering_agent", + ) + + +def _build_rulebook_search_tool(chat_session, retrieved_documents): + retriever = RulesBotRetriever( + index=chat_session.game.vector_store.index, + search_kwargs={"k": 3}, + ) + + @tool("rulebook_search") + def rulebook_search(query: str) -> str: + """Search the current game's rulebook and return relevant passages with page references.""" + docs = retriever.invoke(query) + retrieved_documents.extend(docs) + return _format_rulebook_context(docs) + + return rulebook_search + + +def _format_rulebook_context(documents): + if len(documents) == 0: + return "No relevant rulebook passages were found." + + passages = [] + for i, document in enumerate(documents, start=1): + page_number = document.metadata.get("page") + page_display = page_number + 1 if isinstance(page_number, int) else "unknown" + document_id = document.metadata.get("document_id", "unknown") + relevancy_score = document.metadata.get("relevancy_score") + relevancy_text = ( + f"{relevancy_score:.3f}" + if isinstance(relevancy_score, float) + else "unknown" + ) + passages.append( + "\n".join( + [ + f"Passage {i}", + f"Document ID: {document_id}", + f"Page: {page_display}", + f"Relevancy: {relevancy_text}", + document.page_content, + ] + ) + ) + + return "\n\n".join(passages) + + +def _stream_agent_answer(agent, question, chat_session, response_queue): + answer_chunks = [] + answer = None + + chat_history = _get_chat_history(chat_session) + if ( + not chat_history + or not isinstance(chat_history[-1], HumanMessage) + or chat_history[-1].content != question + ): + chat_history.append(HumanMessage(content=question)) + + agent_input = {"messages": chat_history} + + for chunk in agent.stream( + agent_input, + stream_mode=["messages", "updates"], + version="v2", + ): + if chunk.get("type") == "messages": + token, _metadata = chunk["data"] + if isinstance(token, AIMessageChunk) and token.text: + response_queue.put(token.text) + answer_chunks.append(token.text) + + if chunk.get("type") == "updates": + for source, update in chunk["data"].items(): + if source != "model": + continue + if not isinstance(update, dict): + continue + + messages = update.get("messages", []) + if not messages: + continue + + latest_message = messages[-1] + if isinstance(latest_message, AIMessage) and latest_message.text: + answer = latest_message.text + + if answer is None: + answer = "".join(answer_chunks) + + return answer + + +def _deduplicate_documents(documents): + unique_documents = [] + seen = set() + for document in documents: + key = (document.metadata.get("document_id"), document.metadata.get("page")) + if key in seen: + continue + seen.add(key) + unique_documents.append(document) + + return unique_documents + + +def _get_chat_history(chat_session): + chat_history = [] + latest_messages = chat_session.message_set.order_by("-created_at")[:12] + latest_messages_in_cronological_order = reversed(latest_messages) + for message in latest_messages_in_cronological_order: + if message.message_type == "human": + chat_history.append(HumanMessage(content=message.message)) + elif message.message_type == "ai": + chat_history.append(AIMessage(content=message.message)) + elif message.message_type == "system": + continue + return chat_history if chat_history else [] diff --git a/chat/tests.py b/chat/tests.py index f91a1ae..933e4b8 100644 --- a/chat/tests.py +++ b/chat/tests.py @@ -8,10 +8,11 @@ from langchain_community.llms.fake import FakeListLLM, FakeStreamingListLLM from langchain_community.vectorstores import FAISS from langchain_core.documents.base import Document -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from chat.models import ChatSession from chat.retrievers.rules_bot_retriever import RulesBotRetriever +from chat.services import agentic_streaming_question_answering_service from chat.services.streaming_question_answering_service import ( QueueSignals, _get_chat_history, @@ -263,3 +264,128 @@ def test_setup_question_special_case_no_setup_page(self): self.assertEqual( [doc.metadata.get("setup_page") for doc in docs], [None, None, None] ) + + +class AgenticStreamingQuestionAnsweringServiceTests(TestCase): + def test_ask_question_persists_messages_and_sources(self): + game = Game.objects.create(name="Test Game") + document = game.document_set.create(display_name="Rulebook", url="some-url") + chat_session = ChatSession.objects.create(game=game) + test_queue = SimpleQueue() + + source_document = Document( + page_content="some content", + metadata={"document_id": document.id, "page": 42}, + ) + + with mock.patch( + "chat.services.agentic_streaming_question_answering_service._query_agentic_stream" + ) as mock_query: + mock_query.return_value = { + "answer": "some answer to the question", + "context": [source_document], + } + + agentic_streaming_question_answering_service.ask_question( + "What is the meaning of life?", chat_session, test_queue + ) + + self.assertEqual(chat_session.message_set.count(), 2) + self.assertEqual(chat_session.message_set.first().message_type, "human") + self.assertEqual(chat_session.message_set.last().message_type, "ai") + self.assertEqual( + chat_session.message_set.last().message, "some answer to the question" + ) + ai_message = chat_session.message_set.last() + self.assertEqual(len(ai_message.sourcedocument_set.all()), 1) + self.assertEqual(ai_message.sourcedocument_set.first().document, document) + self.assertEqual(ai_message.sourcedocument_set.first().page_number, 43) + self.assertEqual( + test_queue.get(), + agentic_streaming_question_answering_service.QueueSignals.job_done, + ) + + def test_ask_question_adds_error_signal_on_exception(self): + game = Game.objects.create(name="Test Game") + chat_session = ChatSession.objects.create(game=game) + test_queue = SimpleQueue() + + with mock.patch( + "chat.services.agentic_streaming_question_answering_service._query_agentic_stream", + side_effect=RuntimeError("boom"), + ): + with self.assertRaises(RuntimeError): + agentic_streaming_question_answering_service.ask_question( + "What happened?", chat_session, test_queue + ) + + self.assertEqual( + test_queue.get(), + agentic_streaming_question_answering_service.QueueSignals.error, + ) + + def test_stream_agent_answer_enqueues_text_tokens(self): + game = Game.objects.create(name="Test Game") + chat_session = ChatSession.objects.create(game=game) + test_queue = SimpleQueue() + + class FakeAgent: + def stream(self, *args, **kwargs): + return [ + { + "type": "messages", + "data": (AIMessageChunk(content="some "), {}), + }, + { + "type": "messages", + "data": (AIMessageChunk(content="answer"), {}), + }, + { + "type": "updates", + "data": { + "model": { + "messages": [AIMessage(content="some answer")], + } + }, + }, + ] + + answer = agentic_streaming_question_answering_service._stream_agent_answer( + FakeAgent(), + "What is the meaning of life?", + chat_session, + test_queue, + ) + + self.assertEqual(answer, "some answer") + self.assertEqual(test_queue.get(), "some ") + self.assertEqual(test_queue.get(), "answer") + + @prevent_warnings + def test_rulebook_search_tool_returns_context_and_tracks_documents(self): + game = Game.objects.create(name="Test Game") + document = game.document_set.create(display_name="Rulebook", url="some-url") + chat_session = ChatSession.objects.create(game=game) + + chat_session.game.vector_store.add_documents( + [ + Document( + page_content="Clue game instructions", + metadata={"page": 4}, + ) + ], + document_id=document.id, + ) + + retrieved_documents = [] + rulebook_search_tool = ( + agentic_streaming_question_answering_service._build_rulebook_search_tool( + chat_session, retrieved_documents + ) + ) + + result = rulebook_search_tool.invoke({"query": "clue"}) + + self.assertIn("Passage 1", result) + self.assertEqual(len(retrieved_documents), 1) + self.assertEqual(retrieved_documents[0].metadata["document_id"], document.id) diff --git a/chat/views.py b/chat/views.py index 37177ad..f32531a 100644 --- a/chat/views.py +++ b/chat/views.py @@ -10,7 +10,7 @@ from ads.services import serve_ad_with_impression from chat.forms import ChatForm from chat.models import ChatSession -from chat.services import streaming_question_answering_service +from chat.services import agentic_streaming_question_answering_service from games.models import Game @@ -116,15 +116,17 @@ def stream_from_queue(): while True: result = response_queue.get() if ( - result is streaming_question_answering_service.QueueSignals.job_done - or result is streaming_question_answering_service.QueueSignals.error + result + is agentic_streaming_question_answering_service.QueueSignals.job_done + or result + is agentic_streaming_question_answering_service.QueueSignals.error ): break yield result # run the async ask question function in a thread qa_thread = Thread( - target=streaming_question_answering_service.ask_question, + target=agentic_streaming_question_answering_service.ask_question, args=(question, chat_session, response_queue), ) qa_thread.start() diff --git a/rulesbot/settings.py b/rulesbot/settings.py index 60f22e5..934fccd 100644 --- a/rulesbot/settings.py +++ b/rulesbot/settings.py @@ -216,4 +216,4 @@ # CUSTOM APP SETTINGS # ChatGPT settings -DEFAULT_CHATGPT_MODEL = "gpt-4o-mini" +DEFAULT_CHATGPT_MODEL = "gpt-5.4-nano" diff --git a/tests/evaluate_rulesbot.py b/tests/evaluate_rulesbot.py index a9a78d5..8042028 100644 --- a/tests/evaluate_rulesbot.py +++ b/tests/evaluate_rulesbot.py @@ -5,6 +5,8 @@ Usage: python evaluate_rulesbot.py + python evaluate_rulesbot.py path/to/fixture.json + python evaluate_rulesbot.py --legacy The script will output a table of results to the console. @@ -35,7 +37,7 @@ """ import json import os -import sys +from argparse import ArgumentParser from multiprocessing import SimpleQueue from pathlib import Path @@ -48,6 +50,7 @@ from colorama import Fore, Style # noqa: E402 from chat.models import ChatSession # noqa: E402 +from chat.services import agentic_streaming_question_answering_service # noqa: E402 from chat.services import streaming_question_answering_service # noqa: E402 from games.models import Game # noqa: E402 from games.services import document_ingestion_service # noqa: E402 @@ -97,7 +100,11 @@ def clean_up_game(game: Game) -> None: def evaluate_question( - session: ChatSession, question: str, answer: str, description: str + session: ChatSession, + question: str, + answer: str, + description: str, + qa_service, ) -> bool: # TODO: Change to use langchain evalutor eventually. # evaluator = load_evaluator("labeled_criteria", criteria="correctness") @@ -113,7 +120,7 @@ def evaluate_question( queue = ( SimpleQueue() ) # Ignore the queue and just return the result by reading the message set - streaming_question_answering_service.ask_question(question, session, queue) + qa_service.ask_question(question, session, queue) bot_answer = session.message_set.filter(message_type="ai").last().message @@ -135,7 +142,7 @@ def evaluate_question( return eval_result["score"] -def run_fixture_file(filename: Path) -> None: +def run_fixture_file(filename: Path, qa_service) -> None: print(Fore.CYAN + f"Running fixture file: {filename}" + Style.RESET_ALL) fixture_object = parse_game_json(filename) game = setup_game(fixture_object) @@ -150,6 +157,7 @@ def run_fixture_file(filename: Path) -> None: question=question_session["question"], answer=question_session["answer"], description=question_session["description"], + qa_service=qa_service, ) total_correct += score @@ -160,7 +168,7 @@ def run_fixture_file(filename: Path) -> None: return total_correct, total_questions -def run_tests() -> None: +def run_tests(qa_service) -> None: # List all files in the fixtures/evaluate_rulesbot directory test_files = ( Path(__file__).parent.joinpath("fixtures", "evaluate_rulesbot").glob("*.json") @@ -170,7 +178,7 @@ def run_tests() -> None: total_questions = 0 for test_file in test_files: - correct, questions = run_fixture_file(test_file) + correct, questions = run_fixture_file(test_file, qa_service) total_correct += correct total_questions += questions @@ -183,11 +191,33 @@ def run_tests() -> None: if __name__ == "__main__": - # get if there is an arg for a specific test file - # if so, run that test file - # else, run all test files - if len(sys.argv) > 1: - test_file = Path(sys.argv[1]) - run_fixture_file(test_file) + parser = ArgumentParser() + parser.add_argument( + "fixture", + nargs="?", + type=Path, + help="Optional fixture path to run. If omitted, all fixtures are run.", + ) + parser.add_argument( + "--legacy", + action="store_true", + help="Use legacy non-agentic streaming question answering service.", + ) + args = parser.parse_args() + + qa_service = ( + streaming_question_answering_service + if args.legacy + else agentic_streaming_question_answering_service + ) + + print( + Fore.CYAN + + f"Using QA service: {'legacy' if args.legacy else 'agentic'}" + + Style.RESET_ALL + ) + + if args.fixture is not None: + run_fixture_file(args.fixture, qa_service) else: - run_tests() + run_tests(qa_service)