# Notebook kernels already run an asyncio event loop; nest_asyncio.apply()
# patches it so code that starts/awaits its own loop can still run in a cell.
import nest_asyncio
nest_asyncio.apply()
# Install LlamaIndex plus the Tavily research tool integration used for web search.
%pip install -U llama-index llama-index-tools-tavily-research
import os

# API keys. Values already exported in the environment win, so a real key is
# never clobbered by the placeholders below; replace the placeholders (or
# export OPENAI_API_KEY / TAVILY_API_KEY) before running the workflow.
os.environ.setdefault("OPENAI_API_KEY", "sk-proj-...")
tavily_ai_api_key = os.environ.get(
    "TAVILY_API_KEY", "<Your Tavily AI API Key>"
)
# Download the Llama 2 paper (arXiv 2307.09288) into ./data; the workflow run
# below loads every file in that folder with SimpleDirectoryReader("./data").
!mkdir -p 'data/'
!wget 'https://arxiv.org/pdf/2307.09288.pdf' -O 'data/llama2.pdf'
由于工作流默认采用异步优先（async-first）的设计，这些代码可以在笔记本环境中直接运行。若要在独立的 Python 脚本中运行，且当前没有已在运行的异步事件循环，则应使用 asyncio.run() 创建并启动事件循环，例如：
import asyncio


async def main():
    """Entry point for running the async workflow code outside a notebook.

    Replace the body with the awaits you would otherwise run in notebook
    cells, e.g. ``await workflow.run(...)``.
    """
    ...  # <async code> goes here


if __name__ == "__main__":
    # Only valid when no event loop is already running in this thread.
    asyncio.run(main())
工作流设计
修正型 RAG（Corrective RAG）包含以下步骤：
- 数据摄取 —— 将文档加载至向量索引（Tavily AI 工具则在随后的检索准备步骤中配置）。该步骤独立运行：接收启动事件（StartEvent），并返回携带索引的停止事件（StopEvent）。
- 检索 —— 根据查询获取最相关的节点。
- 相关性评估 —— 使用大语言模型（LLM）逐一判断检索到的节点内容是否与查询相关。
- 相关性提取 —— 提取被 LLM 判定为相关的节点文本。
- 查询转换与 Tavily 搜索 —— 若有节点被判定为不相关，则使用 LLM 改写查询以适配网络搜索，再通过 Tavily 根据改写后的查询在网络上搜索相关答案。
- 响应生成 —— 根据相关节点文本和 Tavily 搜索结果构建摘要索引（SummaryIndex），并利用该索引针对原始查询生成结果。
需要以下事件：
- PrepEvent —— 表示索引和其他对象已准备就绪的事件。
- RetrieveEvent —— 包含检索节点相关信息的事件。
- RelevanceEvalEvent —— 包含相关性评估结果列表的事件。
- TextExtractEvent —— 包含从相关节点提取的文本拼接字符串的事件。
- QueryEvent —— 同时包含相关文本和搜索文本的事件。
from llama_index.core.workflow import Event
from llama_index.core.schema import NodeWithScore
class PrepEvent(Event):
    """Signal that the query state has been stashed in the context.

    Emitted by ``prepare_for_retrieval`` once the LLM pipelines, index and
    Tavily tool are stored; it carries no payload of its own.
    """
class RetrieveEvent(Event):
    """Carry the nodes the retriever returned for the current query."""

    # Scored nodes, passed through exactly as the retriever returned them.
    retrieved_nodes: list[NodeWithScore]
class RelevanceEvalEvent(Event):
    """Carry the LLM's relevance verdict for each retrieved node."""

    # One verdict string per retrieved node, aligned by index with the
    # retrieved nodes; downstream steps treat "yes" as relevant.
    relevant_results: list[str]
class TextExtractEvent(Event):
    """Carry the text of every node judged relevant, joined by newlines."""

    relevant_text: str  # "" when no node was judged relevant
class QueryEvent(Event):
    """Carry the inputs for answer synthesis: local text plus web-search text."""

    relevant_text: str  # newline-joined text of the relevant retrieved nodes
    search_text: str  # newline-joined Tavily results ("" when no search ran)
以下是修正型 RAG 工作流的代码:
from llama_index.core.workflow import (
Workflow,
step,
Context,
StartEvent,
StopEvent,
)
from llama_index.core import (
VectorStoreIndex,
Document,
PromptTemplate,
SummaryIndex,
)
from llama_index.core.query_pipeline import QueryPipeline
from llama_index.llms.openai import OpenAI
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.base.base_retriever import BaseRetriever
# Grading prompt for the relevance-evaluation step. The LLM receives one
# retrieved node ({context_str}) and the user's question ({query_str}) and is
# asked for a binary 'yes' / 'no' answer, which the workflow then parses as
# the node's relevance verdict. The text is intentionally lenient: it should
# only filter out clearly irrelevant retrievals.
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.
Retrieved Document:
-------------------
{context_str}
User Question:
--------------
{query_str}
Evaluation Criteria:
- Consider whether the document contains keywords or topics related to the user's question.
- The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.
Decision:
- Assign a binary score to indicate the document's relevance.
- Use 'yes' if the document is relevant to the question, or 'no' if it is not.
Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)
# Query-rewriting prompt used before the Tavily web search: the LLM receives
# the original question ({query_str}) and should reply with only the rewritten
# query, which is then sent to Tavily as-is.
# NOTE: this is a normal (non-raw) string, so each "\n" escape below inserts an
# extra newline on top of the literal line breaks.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
Analyze the given input to grasp the core semantic intent or meaning. \n
Original Query:
\n ------- \n
{query_str}
\n ------- \n
Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
Respond with the optimized query only:"""
)
class CorrectiveRAGWorkflow(Workflow):
    """Corrective RAG: retrieve, grade relevance, fall back to web search.

    The workflow is driven by two kinds of ``run()`` calls:

    * ``run(documents=...)`` triggers only :meth:`ingest` and returns a
      ``VectorStoreIndex`` built from the documents.
    * ``run(query_str=..., index=..., tavily_ai_apikey=...,
      retriever_kwargs=...)`` runs retrieve -> grade -> extract ->
      (rewrite + Tavily search) -> answer, and returns the query engine's
      response for ``query_str``.

    NOTE(review): the retriever / LLM / query-engine calls below are
    synchronous and block the event loop while they run.
    """

    @staticmethod
    def _normalize_relevance(answer: str) -> str:
        """Reduce a free-form LLM grade to exactly ``"yes"`` or ``"no"``.

        The grading prompt asks for a bare 'yes'/'no', but models often add
        punctuation, quotes or an explanation ("Yes.", "'no'", "Binary score:
        yes"). The first standalone yes/no word wins. A reply with neither
        word counts as "no", so an unparseable grade triggers the web-search
        fallback instead of silently dropping the node.
        """
        import re

        match = re.search(r"\b(yes|no)\b", answer.lower())
        return match.group(1) if match else "no"

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Build a vector index when the run was started with ``documents``.

        Returns ``None`` for query runs so that only
        :meth:`prepare_for_retrieval` reacts to their ``StartEvent``.
        """
        documents: list[Document] | None = ev.get("documents")
        if documents is None:
            return None
        index = VectorStoreIndex.from_documents(documents)
        return StopEvent(result=index)

    @step
    async def prepare_for_retrieval(
        self, ctx: Context, ev: StartEvent
    ) -> PrepEvent | None:
        """Stash the query, index, LLM pipelines and Tavily tool in the context.

        Only reacts to runs started with ``query_str``; ingestion-only runs
        get ``None`` here.
        """
        query_str: str | None = ev.get("query_str")
        if query_str is None:
            return None
        # `or {}` also covers an explicit retriever_kwargs=None, which would
        # otherwise fail on `**retriever_kwargs` in the retrieve step.
        retriever_kwargs: dict = ev.get("retriever_kwargs") or {}
        tavily_ai_apikey: str | None = ev.get("tavily_ai_apikey")
        index = ev.get("index")

        llm = OpenAI(model="gpt-4")
        await ctx.store.set(
            "relevancy_pipeline",
            QueryPipeline(chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]),
        )
        await ctx.store.set(
            "transform_query_pipeline",
            QueryPipeline(chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]),
        )
        await ctx.store.set("llm", llm)
        await ctx.store.set("index", index)
        await ctx.store.set(
            "tavily_tool", TavilyToolSpec(api_key=tavily_ai_apikey)
        )
        await ctx.store.set("query_str", query_str)
        await ctx.store.set("retriever_kwargs", retriever_kwargs)
        return PrepEvent()

    @step
    async def retrieve(
        self, ctx: Context, ev: PrepEvent
    ) -> RetrieveEvent | None:
        """Retrieve candidate nodes for the stored query from the index.

        Raises:
            ValueError: if either the index or the Tavily tool is missing.
        """
        query_str = await ctx.store.get("query_str")
        retriever_kwargs = await ctx.store.get("retriever_kwargs")
        if query_str is None:
            return None

        index = await ctx.store.get("index", default=None)
        tavily_tool = await ctx.store.get("tavily_tool", default=None)
        # Both are required (the index here, the Tavily tool in the search
        # step), so fail fast if *either* one is missing.
        if index is None or tavily_tool is None:
            raise ValueError(
                "Index and tavily tool must be constructed. Build the index "
                "with a 'documents' run first, then pass 'index' and "
                "'tavily_ai_apikey' params."
            )

        retriever: BaseRetriever = index.as_retriever(**retriever_kwargs)
        result = retriever.retrieve(query_str)
        await ctx.store.set("retrieved_nodes", result)
        return RetrieveEvent(retrieved_nodes=result)

    @step
    async def eval_relevance(
        self, ctx: Context, ev: RetrieveEvent
    ) -> RelevanceEvalEvent:
        """Grade every retrieved node against the query with the LLM.

        Each verdict is normalized to "yes"/"no" and stored in the context
        (for the search step) as well as sent on the event.
        """
        query_str = await ctx.store.get("query_str")
        # Same pipeline for every node, so fetch it once.
        relevancy_pipeline = await ctx.store.get("relevancy_pipeline")
        relevancy_results = []
        for node in ev.retrieved_nodes:
            relevancy = relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            relevancy_results.append(
                self._normalize_relevance(relevancy.message.content or "")
            )
        await ctx.store.set("relevancy_results", relevancy_results)
        return RelevanceEvalEvent(relevant_results=relevancy_results)

    @step
    async def extract_relevant_texts(
        self, ctx: Context, ev: RelevanceEvalEvent
    ) -> TextExtractEvent:
        """Concatenate (newline-joined) the text of nodes graded "yes"."""
        retrieved_nodes = await ctx.store.get("retrieved_nodes")
        # Verdicts are aligned index-for-index with the retrieved nodes.
        relevant_texts = [
            node.text
            for node, verdict in zip(retrieved_nodes, ev.relevant_results)
            if verdict == "yes"
        ]
        return TextExtractEvent(relevant_text="\n".join(relevant_texts))

    @step
    async def transform_query_pipeline(
        self, ctx: Context, ev: TextExtractEvent
    ) -> QueryEvent:
        """Rewrite the query and search Tavily when local context is lacking.

        A web search runs when any node was graded irrelevant or nothing was
        retrieved at all; otherwise ``search_text`` is empty.
        """
        relevancy_results = await ctx.store.get("relevancy_results")
        query_str = await ctx.store.get("query_str")

        if not relevancy_results or "no" in relevancy_results:
            qp = await ctx.store.get("transform_query_pipeline")
            rewritten = (qp.run(query_str=query_str).message.content or "").strip()
            # Fall back to the original query if the rewrite came back empty.
            transformed_query_str = rewritten or query_str
            tavily_tool = await ctx.store.get("tavily_tool")
            search_results = tavily_tool.search(
                transformed_query_str, max_results=5
            )
            search_text = "\n".join(result.text for result in search_results)
        else:
            search_text = ""
        return QueryEvent(relevant_text=ev.relevant_text, search_text=search_text)

    @step
    async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent:
        """Answer the original query over the relevant + searched text."""
        query_str = await ctx.store.get("query_str")
        # A single document combining local and web text feeds a SummaryIndex,
        # whose query engine produces the final answer.
        documents = [Document(text=ev.relevant_text + "\n" + ev.search_text)]
        index = SummaryIndex.from_documents(documents)
        result = index.as_query_engine().query(query_str)
        return StopEvent(result=result)
运行工作流
from llama_index.core import SimpleDirectoryReader
documents = SimpleDirectoryReader("./data").load_data()
workflow = CorrectiveRAGWorkflow()
index = await workflow.run(documents=documents)
from IPython.display import Markdown, display
response = await workflow.run(
query_str="How was Llama2 pretrained?",
index=index,
tavily_ai_apikey=tavily_ai_api_key,
)
display(Markdown(str(response)))
Llama 2 was pretrained using an optimized auto-regressive transformer with several modifications to enhance performance. These modifications included more robust data cleaning, updated data mixes, training on 40% more total tokens, doubling the context length, and using grouped-query attention (GQA) to improve inference scalability for larger models.
response = await workflow.run(
query_str="What is the functionality of latest ChatGPT memory."
)
display(Markdown(str(response)))
The functionality of the latest ChatGPT memory is to autonomously remember information it deems relevant from conversations. This feature aims to save users from having to repeat information and make future conversations more helpful. Users have control over the chatbot's memory, being able to access and manage these memories as needed.