-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph.py
More file actions
228 lines (177 loc) · 10.1 KB
/
graph.py
File metadata and controls
228 lines (177 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
LangGraph workflow for 飞享IM Q&A service.
Flow:
user question
│
▼
[classify] ── off-topic ──► [reject] → Claude Sonnet 4.6 直接回答通用问题(流式输出)
│
on-topic
│
▼
[retrieve] → BM25 检索 top-k 文档块
│
▼
[grade_docs] → 并发逐块评分,过滤无关内容
│
├─ 有相关文档 ──► [generate] → Claude Sonnet 4.6 基于知识库回答(流式输出)
│
└─ 无相关文档 ──► [fallback] → 提示知识库不足,建议访问官网
"""
from __future__ import annotations
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated, Literal
import operator
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from config import ANTHROPIC_API_KEY, CLAUDE_MODEL, CLASSIFY_MODEL, GENERATE_MODEL
# ─── State ────────────────────────────────────────────────────────────────────
class QAState(BaseModel):
    """Shared state flowing between the LangGraph nodes for one Q&A turn."""
    # The user's current question (required).
    question: str
    # Prior conversation turns: [{"role": "user"|"assistant", "content": "..."}]
    history: list[dict] = Field(default_factory=list)
    # Raw chunk texts returned by the retriever, before relevance grading.
    retrieved_docs: list[str] = Field(default_factory=list)
    # Subset of retrieved_docs judged relevant by the grading LLM.
    relevant_docs: list[str] = Field(default_factory=list)
    # Final answer text produced by generate/fallback/reject.
    answer: str = ""
    # Routing decision made by the classify node.
    route: Literal["on_topic", "off_topic"] = "on_topic"
def _history_to_messages(history: list[dict]) -> list:
    """Map chat-history dicts onto LangChain message objects.

    Entries with role "user" become HumanMessage; any other role is
    treated as an assistant turn and becomes AIMessage.
    """
    return [
        HumanMessage(content=turn["content"])
        if turn["role"] == "user"
        else AIMessage(content=turn["content"])
        for turn in history
    ]
# ─── LLM ──────────────────────────────────────────────────────────────────────
def _llm(streaming: bool = False, model: str = CLAUDE_MODEL, thinking: bool = False, max_tokens: int = 2048) -> ChatAnthropic:
    """Build a ChatAnthropic client with the project's shared defaults.

    Args:
        streaming: enable token streaming on the client.
        model: Anthropic model id (defaults to CLAUDE_MODEL from config).
        thinking: when True, attach a thinking payload.
            NOTE(review): {"type": "adaptive"} does not match the documented
            Anthropic extended-thinking payload ({"type": "enabled",
            "budget_tokens": N}) — confirm the SDK/model in use accepts it.
        max_tokens: completion token cap.
    """
    params: dict = {
        "model": model,
        "anthropic_api_key": ANTHROPIC_API_KEY,
        "streaming": streaming,
        "max_tokens": max_tokens,
        "max_retries": 3,
    }
    if thinking:
        params["thinking"] = {"type": "adaptive"}
    return ChatAnthropic(**params)
# ─── Node: classify ───────────────────────────────────────────────────────────
CLASSIFY_SYSTEM = """你是飞享IM智能助手的问题分类器。
判断用户问题是否与飞享IM相关(功能、部署、技术、价格、使用方法等)。
只回答 "on_topic" 或 "off_topic",不要解释。"""


def classify(state: QAState, retriever) -> dict:
    """Classify the question as "on_topic" (about 飞享IM) or "off_topic"."""
    started = time.perf_counter()
    # Carry the last 2 turns of history so follow-ups such as
    # "what about its price?" are classified with context.
    prompt = [SystemMessage(content=CLASSIFY_SYSTEM)]
    prompt += _history_to_messages(state.history[-4:])
    prompt += [HumanMessage(content=state.question)]
    reply = _llm(model=CLASSIFY_MODEL).invoke(prompt)
    route = "on_topic" if "on_topic" in reply.content else "off_topic"
    print(f"[timing] classify: {time.perf_counter() - started:.2f}s → {route}")
    return {"route": route}
# ─── Node: retrieve ───────────────────────────────────────────────────────────
def retrieve(state: QAState, retriever) -> dict:
    """Run the retriever on the question and collect the chunk texts."""
    started = time.perf_counter()
    texts = [doc.page_content for doc in retriever.invoke(state.question)]
    print(f"[timing] retrieve: {time.perf_counter() - started:.2f}s → {len(texts)} docs")
    return {"retrieved_docs": texts}
# ─── Node: grade_docs ─────────────────────────────────────────────────────────
GRADE_SYSTEM = """你是文档相关性评分器。
给定用户问题和一段文档,判断文档是否包含回答问题所需的相关信息。
只回答 "relevant" 或 "irrelevant"。"""


def grade_docs(state: QAState, retriever) -> dict:
    """Grade each retrieved chunk for relevance, concurrently.

    Every chunk gets its own LLM call (one thread per chunk); chunks judged
    "irrelevant" are dropped. Returns {"relevant_docs": [...]} preserving
    the original retrieval order.
    """
    t0 = time.perf_counter()

    # Guard: ThreadPoolExecutor(max_workers=0) raises ValueError, so
    # short-circuit when retrieval produced no chunks at all.
    if not state.retrieved_docs:
        print(f"[timing] grade_docs total: {time.perf_counter() - t0:.2f}s → 0/0 relevant")
        return {"relevant_docs": []}

    def _grade_one(i: int, doc: str) -> tuple[int, str, str]:
        t1 = time.perf_counter()
        llm = _llm(thinking=True)
        resp = llm.invoke([
            SystemMessage(content=GRADE_SYSTEM),
            HumanMessage(content=f"问题:{state.question}\n\n文档:{doc}"),
        ])
        text = resp.content.lower()
        # BUG FIX: test "irrelevant" first — "relevant" is a substring of
        # "irrelevant", so the old `"relevant" in text` check classified
        # every "irrelevant" reply as relevant, defeating the filter.
        if "irrelevant" in text:
            verdict = "irrelevant"
        elif "relevant" in text:
            verdict = "relevant"
        else:
            # Unrecognized reply: err on the side of dropping the chunk,
            # matching the original's default.
            verdict = "irrelevant"
        print(f"[timing] grade doc[{i}]: {time.perf_counter() - t1:.2f}s → {verdict}")
        return (i, doc, verdict)

    # Collect results back into retrieval order regardless of completion order.
    results: list[tuple[int, str, str]] = [None] * len(state.retrieved_docs)
    with ThreadPoolExecutor(max_workers=len(state.retrieved_docs)) as pool:
        futures = {pool.submit(_grade_one, i, doc): i for i, doc in enumerate(state.retrieved_docs)}
        for future in as_completed(futures):
            i, doc, verdict = future.result()
            results[i] = (i, doc, verdict)

    relevant = [doc for _, doc, verdict in results if verdict == "relevant"]
    print(f"[timing] grade_docs total: {time.perf_counter() - t0:.2f}s → {len(relevant)}/{len(state.retrieved_docs)} relevant")
    return {"relevant_docs": relevant}
# ─── Node: generate ───────────────────────────────────────────────────────────
GENERATE_SYSTEM = """你是飞享IM的专业客服助手,熟悉飞享IM的所有功能和使用方法。
请根据提供的参考资料,用中文准确、友好地回答用户问题。
- 回答控制在300字以内,简洁明了,条理清晰
- 如果参考资料不足以完整回答,请说明哪些信息你不确定
- 不要编造飞享IM没有的功能"""


def generate(state: QAState, retriever) -> dict:
    """Answer the question from the graded knowledge-base chunks (streamed)."""
    started = time.perf_counter()
    context = "\n\n---\n\n".join(state.relevant_docs)
    prompt = [SystemMessage(content=GENERATE_SYSTEM)]
    prompt.extend(_history_to_messages(state.history))
    prompt.append(HumanMessage(content=f"参考资料:\n{context}\n\n用户问题:{state.question}"))
    chain = _llm(streaming=True, model=GENERATE_MODEL, thinking=False, max_tokens=600) | StrOutputParser()
    answer = chain.invoke(prompt)
    print(f"[timing] generate: {time.perf_counter() - started:.2f}s")
    return {"answer": answer}
# ─── Node: fallback ───────────────────────────────────────────────────────────
def fallback(state: QAState, retriever) -> dict:
    """Static reply used when grading left no relevant knowledge-base chunks."""
    suggestions = [
        "抱歉,我目前的知识库中没有找到与您问题直接相关的信息。",
        "建议您:",
        "1. 访问飞享IM官网 https://fsharechat.cn 获取最新资料",
        "2. 尝试换一种方式描述您的问题",
        "3. 联系飞享IM官方客服获取专业帮助",
    ]
    return {"answer": "\n".join(suggestions)}
# ─── Node: reject (general Q&A for off-topic questions) ──────────────────────
GENERAL_SYSTEM = """你是一个知识渊博的AI助手,能够回答各种问题。
请用中文友好、准确地回答用户问题。
回答应简洁完整,除非用户明确要求详细展开,否则控制在5000字以内。"""


def reject(state: QAState, retriever) -> dict:
    """Answer an off-topic question as a general-purpose assistant (streamed)."""
    started = time.perf_counter()
    prompt = [
        SystemMessage(content=GENERAL_SYSTEM),
        *_history_to_messages(state.history),
        HumanMessage(content=state.question),
    ]
    chain = _llm(streaming=True, model=GENERATE_MODEL, thinking=False, max_tokens=8192) | StrOutputParser()
    answer = chain.invoke(prompt)
    print(f"[timing] general_answer: {time.perf_counter() - started:.2f}s")
    return {"answer": answer}
# ─── Routing ──────────────────────────────────────────────────────────────────
def route_after_classify(state: QAState) -> str:
    """Route on-topic questions to retrieval; everything else to reject."""
    if state.route == "on_topic":
        return "retrieve"
    return "reject"
def route_after_grade(state: QAState) -> str:
    """Generate when any relevant docs survived grading; otherwise fall back."""
    if state.relevant_docs:
        return "generate"
    return "fallback"
# ─── Graph builder ────────────────────────────────────────────────────────────
def build_graph(retriever):
    """Build and compile the LangGraph workflow.

    The retriever is bound into each node through a keyword-default closure,
    so every node keeps the plain `state -> dict` signature LangGraph expects.
    """
    node_fns = {
        "classify": classify,
        "retrieve": retrieve,
        "grade_docs": grade_docs,
        "generate": generate,
        "fallback": fallback,
        "reject": reject,
    }
    builder = StateGraph(QAState)
    for node_name, fn in node_fns.items():
        # Bind fn as a default argument to dodge the late-binding closure pitfall.
        builder.add_node(node_name, lambda state, _fn=fn: _fn(state, retriever))
    builder.add_edge(START, "classify")
    builder.add_conditional_edges("classify", route_after_classify)
    builder.add_edge("retrieve", "grade_docs")
    builder.add_conditional_edges("grade_docs", route_after_grade)
    builder.add_edge("generate", END)
    builder.add_edge("fallback", END)
    builder.add_edge("reject", END)
    return builder.compile()