对话翻译App
20.8MB · 2026-04-06
在构建任何记忆策略之前,先定义统一接口与 Agent 调度流程,这样策略可以“即插即用”。
import abc
# 记忆策略的统一接口(抽象基类)
class BaseMemoryStrategy(abc.ABC):
    """Common interface every memory strategy implements, so strategies are plug-and-play."""

    @abc.abstractmethod
    def add_message(self, user_input: str, ai_response: str):
        """Persist one user/assistant exchange into memory."""
        ...

    @abc.abstractmethod
    def get_context(self, query: str) -> str:
        """Build the context string relevant to the current query."""
        ...

    @abc.abstractmethod
    def clear(self):
        """Erase everything this strategy has stored."""
        ...
class AIAgent:
    """Generic agent loop: fetch memory -> build prompt -> call LLM -> update memory."""

    def __init__(self, memory_strategy: BaseMemoryStrategy, system_prompt: str = "You are a helpful AI assistant."):
        # Any object implementing BaseMemoryStrategy can be plugged in here.
        self.memory = memory_strategy
        self.system_prompt = system_prompt

    def chat(self, user_input: str) -> str:
        """Run one conversational turn and return the assistant's reply."""
        # 1) Pull relevant context from the memory strategy.
        context = self.memory.get_context(query=user_input)
        # 2) Assemble the full user prompt.
        #    Fixed: the section separators had lost their backslashes and read
        #    as a literal 'n'; restored the intended '\n' newlines.
        full_user_prompt = (
            f"### MEMORY CONTEXT\n{context}\n\n"
            f"### CURRENT REQUEST\n{user_input}"
        )
        # 3) Call the LLM (generate_text is a project-level helper, defined elsewhere).
        ai_response = generate_text(self.system_prompt, full_user_prompt)
        # 4) Write the turn back into memory.
        self.memory.add_message(user_input, ai_response)
        return ai_response
特点:把所有对话完整保存并拼接。优点是“记得全”,缺点是上下文无限增长。
class SequentialMemory(BaseMemoryStrategy):
    """Keeps the full conversation verbatim; remembers everything, but context grows without bound."""

    def __init__(self):
        # Flat list of {"role", "content"} dicts, oldest first.
        self.history = []

    def add_message(self, user_input: str, ai_response: str):
        """Append one user/assistant exchange to the transcript."""
        self.history.append({"role": "user", "content": user_input})
        self.history.append({"role": "assistant", "content": ai_response})

    def get_context(self, query: str) -> str:
        """Return the whole transcript, one 'Role: content' line per message.

        Fixed: the join separator was a literal 'n' (lost backslash);
        restored the intended newline separator.
        """
        return "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in self.history)

    def clear(self):
        """Drop all stored messages."""
        self.history = []
特点:只保留最近 N 轮对话,成本稳定但会遗忘。
from collections import deque
class SlidingWindowMemory(BaseMemoryStrategy):
    """Keeps only the most recent N turns: constant cost, but older turns are forgotten."""

    def __init__(self, window_size: int = 4):
        # deque with maxlen silently discards the oldest turn on overflow.
        self.history = deque(maxlen=window_size)

    def add_message(self, user_input: str, ai_response: str):
        """Store one full turn (user + assistant pair) as a single window entry."""
        self.history.append([
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": ai_response}
        ])

    def get_context(self, query: str) -> str:
        """Flatten the window into 'Role: content' lines.

        Fixed: the join separator was a literal 'n' (lost backslash);
        restored the intended newline separator.
        """
        lines = []
        for turn in self.history:
            for msg in turn:
                lines.append(f"{msg['role'].capitalize()}: {msg['content']}")
        return "\n".join(lines)

    def clear(self):
        """Drop all stored turns."""
        self.history.clear()
特点:对话到阈值后,让 LLM 生成摘要并合并。适合长对话。
class SummarizationMemory(BaseMemoryStrategy):
    """Buffers recent turns and periodically folds them into a running LLM-generated summary."""

    def __init__(self, summary_threshold: int = 4):
        # Condensed history produced by the LLM so far.
        self.running_summary = ""
        # Raw messages not yet folded into the summary.
        self.buffer = []
        # Number of buffered messages (not turns) that triggers consolidation.
        self.summary_threshold = summary_threshold

    def add_message(self, user_input: str, ai_response: str):
        """Buffer one exchange; consolidate once the buffer reaches the threshold."""
        self.buffer.append({"role": "user", "content": user_input})
        self.buffer.append({"role": "assistant", "content": ai_response})
        if len(self.buffer) >= self.summary_threshold:
            self._consolidate_memory()

    def _consolidate_memory(self):
        """Merge the buffered turns into the running summary via the LLM."""
        # Fixed: newline separators had lost their backslashes (literal 'n');
        # restored the intended '\n' in the join and the prompt.
        buffer_text = "\n".join(
            f"{m['role'].capitalize()}: {m['content']}" for m in self.buffer
        )
        prompt = (
            "You are a summarization expert.\n"
            f"### Previous Summary:\n{self.running_summary}\n\n"
            f"### New Conversation:\n{buffer_text}\n\n"
            "### Updated Summary:"
        )
        # generate_text is a project-level helper, defined elsewhere.
        self.running_summary = generate_text("You are a summarization engine.", prompt)
        self.buffer = []

    def get_context(self, query: str) -> str:
        """Return the running summary plus any not-yet-summarized recent messages."""
        buffer_text = "\n".join(
            f"{m['role'].capitalize()}: {m['content']}" for m in self.buffer
        )
        return f"### Summary:\n{self.running_summary}\n\n### Recent:\n{buffer_text}"

    def clear(self):
        """Reset both the summary and the buffer."""
        self.running_summary = ""
        self.buffer = []
特点:用向量检索长程相关信息,最常用的“长期记忆”方案。
import numpy as np
import faiss
class RetrievalMemory(BaseMemoryStrategy):
    """Long-term memory backed by a FAISS flat-L2 vector index (nearest-neighbour retrieval)."""

    def __init__(self, k: int = 2, embedding_dim: int | None = None):
        # Number of documents to retrieve per query.
        self.k = k
        self.embedding_dim = embedding_dim
        # Raw text documents, position-aligned with the vectors in the index.
        self.documents = []
        # Defer index creation until the first embedding reveals its dimension.
        self.index = faiss.IndexFlatL2(embedding_dim) if embedding_dim else None

    def _ensure_index(self, embedding: list):
        """Lazily create the index from the first embedding; validate later ones."""
        if self.embedding_dim is None:
            self.embedding_dim = len(embedding)
            self.index = faiss.IndexFlatL2(self.embedding_dim)
        elif len(embedding) != self.embedding_dim:
            raise ValueError(f"Embedding dim {len(embedding)} != index dim {self.embedding_dim}")

    def add_message(self, user_input: str, ai_response: str):
        """Embed and index the user and assistant sides of the turn as separate documents."""
        docs = [f"User said: {user_input}", f"AI responded: {ai_response}"]
        for doc in docs:
            # get_embedding is a project-level helper, defined elsewhere.
            emb = get_embedding(doc)
            if emb:
                self._ensure_index(emb)
                self.documents.append(doc)
                self.index.add(np.array([emb], dtype="float32"))

    def get_context(self, query: str) -> str:
        """Return the k stored documents nearest to the query embedding."""
        if self.index is None or self.index.ntotal == 0:
            return "No information in memory yet."
        q = get_embedding(query)
        if not q:
            return "Could not process query for retrieval."
        if len(q) != self.embedding_dim:
            return "Query embedding dimension mismatch with index."
        D, I = self.index.search(np.array([q], dtype="float32"), self.k)
        # FAISS pads missing results with index -1; skip those slots.
        retrieved = [self.documents[i] for i in I[0] if i != -1]
        # Fixed: separators had lost their backslashes (literal 'n');
        # restored the intended newlines.
        return "### Retrieved:\n" + "\n---\n".join(retrieved)

    def clear(self):
        """Drop stored documents and reset the index (if it was ever created)."""
        self.documents = []
        if self.index is not None:
            self.index.reset()
特点:让 LLM 识别“关键事实”,生成长期“记忆 token”。
class MemoryAugmentedMemory(BaseMemoryStrategy):
    """Short-term sliding window plus LLM-extracted long-lived 'memory tokens'."""

    def __init__(self, window_size: int = 2):
        self.recent_memory = SlidingWindowMemory(window_size=window_size)
        # Durable facts the LLM has flagged as worth keeping.
        self.memory_tokens = []

    def add_message(self, user_input: str, ai_response: str):
        """Record the turn, then ask the LLM whether it contains a long-term fact."""
        self.recent_memory.add_message(user_input, ai_response)
        # Fixed: newline separators had lost their backslashes (literal 'n');
        # restored the intended '\n' in the prompt.
        prompt = (
            "Analyze the following turn and extract any long-term fact.\n"
            f"User: {user_input}\nAI: {ai_response}\n"
            "If none, reply 'No important fact.'"
        )
        fact = generate_text("You are a fact-extraction expert.", prompt)
        # Keep anything that is not the explicit "no fact" sentinel.
        if "no important fact" not in fact.lower():
            self.memory_tokens.append(fact)

    def get_context(self, query: str) -> str:
        """Combine the stored memory tokens with the recent-turn window."""
        recent = self.recent_memory.get_context(query)
        tokens = "\n".join(f"- {t}" for t in self.memory_tokens)
        return f"### Memory Tokens:\n{tokens}\n\n### Recent:\n{recent}"

    def clear(self):
        """Reset both the window and the token list."""
        self.recent_memory.clear()
        self.memory_tokens = []
特点:短期用滑窗,长期用检索,触发关键词时晋升。
class HierarchicalMemory(BaseMemoryStrategy):
    """Two tiers: sliding-window working memory, plus retrieval memory for promoted turns."""

    def __init__(self, window_size: int = 2, k: int = 2, embedding_dim: int = 4096):
        self.working_memory = SlidingWindowMemory(window_size=window_size)
        self.long_term_memory = RetrievalMemory(k=k, embedding_dim=embedding_dim)
        # A user message containing any of these words is promoted to long-term memory.
        self.promotion_keywords = ["remember", "rule", "preference", "always", "never", "allergic"]

    def add_message(self, user_input: str, ai_response: str):
        """Always store in working memory; promote keyword-bearing turns to long-term."""
        self.working_memory.add_message(user_input, ai_response)
        if any(kw in user_input.lower() for kw in self.promotion_keywords):
            self.long_term_memory.add_message(user_input, ai_response)

    def get_context(self, query: str) -> str:
        """Concatenate long-term retrievals with the working-memory window.

        Fixed: section separators had lost their backslashes (literal 'n');
        restored the intended newlines.
        """
        working = self.working_memory.get_context(query)
        long_term = self.long_term_memory.get_context(query)
        return f"### Long-Term:\n{long_term}\n\n### Working:\n{working}"

    def clear(self):
        """Reset both tiers."""
        self.working_memory.clear()
        self.long_term_memory.clear()
特点:抽取三元组构建知识图谱,适合关系推理。
import networkx as nx
import re
class GraphMemory(BaseMemoryStrategy):
    """Stores conversation facts as a directed knowledge graph of Subject-Relation-Object triples."""

    def __init__(self):
        self.graph = nx.DiGraph()

    def _extract_triples(self, text: str):
        """Ask the LLM for S-R-O triples and parse them out of its reply.

        Fixed: the regex had lost its backslashes (whitespace and parenthesis
        escapes) and its quote escaping. Restored a three-group pattern that
        matches quoted tuples like ("subject", "relation", "object") — three
        groups so the caller's `for subj, rel, obj` unpacking works.
        """
        prompt = (
            "Extract Subject-Relation-Object triples as Python tuples.\n"
            f"Text:\n{text}"
        )
        response = generate_text("You are a KG extractor.", prompt)
        return re.findall(r"\(['\"](.*?)['\"],\s*['\"](.*?)['\"],\s*['\"](.*?)['\"]\)", response)

    def add_message(self, user_input: str, ai_response: str):
        """Extract triples from the turn and add them as relation-labelled edges."""
        triples = self._extract_triples(f"User: {user_input}\nAI: {ai_response}")
        for subj, rel, obj in triples:
            self.graph.add_edge(subj.strip(), obj.strip(), relation=rel.strip())

    def get_context(self, query: str) -> str:
        """Return outgoing facts for every query word that names a graph node."""
        if not self.graph.nodes:
            return "The knowledge graph is empty."
        # Naive entity linking: capitalize each query word and look for a node match.
        entities = [w.capitalize() for w in query.replace("?", "").split() if w.capitalize() in self.graph.nodes]
        if not entities:
            return "No relevant entities from your query were found in the knowledge graph."
        facts = []
        for ent in set(entities):
            for u, v, d in self.graph.out_edges(ent, data=True):
                facts.append(f"{u} --[{d['relation']}]--> {v}")
        return "### Facts Retrieved from Knowledge Graph:\n" + "\n".join(sorted(set(facts)))

    def clear(self):
        """Drop all nodes and edges."""
        self.graph.clear()
特点:把每轮对话压缩为“极简事实”,超省 token。
class CompressionMemory(BaseMemoryStrategy):
    """Compresses every turn into one essential fact — extremely token-frugal."""

    def __init__(self):
        self.compressed_facts = []

    def add_message(self, user_input: str, ai_response: str):
        """Ask the LLM to boil the turn down to a single factual statement."""
        # Fixed: newline separators had lost their backslashes (literal 'n');
        # restored the intended '\n' in the prompt.
        prompt = (
            "Compress the following into its most essential factual statement.\n"
            f"User: {user_input}\nAI: {ai_response}"
        )
        fact = generate_text("You are a data compressor.", prompt)
        self.compressed_facts.append(fact)

    def get_context(self, query: str) -> str:
        """Return the bullet list of compressed facts (the query is ignored)."""
        if not self.compressed_facts:
            return "No compressed facts in memory."
        return "### Compressed Facts:\n- " + "\n- ".join(self.compressed_facts)

    def clear(self):
        """Drop all compressed facts."""
        self.compressed_facts = []
特点:模拟“内存/硬盘”分页,按需调入旧信息。
class OSMemory(BaseMemoryStrategy):
    """OS-style paging: a small 'RAM' of recent turns, with older turns swapped to 'disk'."""

    def __init__(self, ram_size: int = 2):
        # Maximum number of turns kept in active (RAM) memory.
        self.ram_size = ram_size
        # FIFO of (turn_id, turn_text); the oldest entry is evicted first.
        self.active_memory = deque()
        # Evicted turns, keyed by turn_id.
        self.passive_memory = {}
        self.turn_count = 0

    def add_message(self, user_input: str, ai_response: str):
        """Store the turn in RAM, evicting the oldest turn to disk when RAM is full."""
        turn_id = self.turn_count
        # Fixed: the separator had lost its backslash (literal 'n');
        # restored the intended newline between the two lines.
        turn_data = f"User: {user_input}\nAI: {ai_response}"
        if len(self.active_memory) >= self.ram_size:
            lru_id, lru_data = self.active_memory.popleft()
            self.passive_memory[lru_id] = lru_data
        self.active_memory.append((turn_id, turn_data))
        self.turn_count += 1

    def get_context(self, query: str) -> str:
        """Return RAM contents plus any disk turns whose text matches the query.

        Simplified 'page fault' logic: a passive turn is paged in when it
        shares any word longer than three characters with the query.
        """
        active = "\n".join(data for _, data in self.active_memory)
        paged_in = ""
        for tid, data in self.passive_memory.items():
            if any(w in data.lower() for w in query.lower().split() if len(w) > 3):
                paged_in += f"\n(Paged in Turn {tid}): {data}"
        return f"### RAM:\n{active}\n\n### Disk:\n{paged_in}"

    def clear(self):
        """Reset RAM, disk, and the turn counter."""
        self.active_memory.clear()
        self.passive_memory = {}
        self.turn_count = 0