Reflection Workflow for Structured Outputs¶
This notebook walks you through setting up a Workflow
that provides reliable structured outputs through retries and reflection on mistakes.
This notebook works best with an open-source LLM, so we will use Ollama
. If you don't already have Ollama running, visit https://ollama.com to get started and download the model you want to use. (In this case, we ran ollama pull llama3
before running this notebook, which is the model the code below requests.)
In [ ]:
!pip install -U llama-index llama-index-llms-ollama
Since workflows are async-first, this all runs fine in a notebook. If you were running in your own code, you would want to use asyncio.run()
to start an async event loop if one isn't already running:

async def main():
    <async code>

if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
In [ ]:
from llama_index.core.workflow import Event


class ExtractionDone(Event):
    output: str
    passage: str


class ValidationErrorEvent(Event):
    error: str
    wrong_output: str
    passage: str
What to extract¶
In order to tell the model what we want to extract, we first define a pydantic model describing the expected output.
In [ ]:
from pydantic import BaseModel


class Car(BaseModel):
    brand: str
    model: str
    power: int


class CarCollection(BaseModel):
    cars: list[Car]
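As a quick, optional sanity check (not part of the workflow itself), you can print the JSON schema that the extraction prompt will later embed, and round-trip a hand-written payload through the same validator the workflow uses:

# The schema string that the extract step injects into its prompt
print(CarCollection.schema_json())

# The validator that the validate step will call on the LLM output
sample = '{"cars": [{"brand": "Fiat", "model": "Panda", "power": 45}]}'
print(CarCollection.model_validate_json(sample))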
In [ ]:
import json

from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.llms.ollama import Ollama

EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------
Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}
"""

REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------
This caused the JSON decode error: {error}
Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        current_retries = await ctx.store.get("retries", default=0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        else:
            await ctx.store.set("retries", current_retries + 1)

        if isinstance(ev, StartEvent):
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        elif isinstance(ev, ValidationErrorEvent):
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        llm = Ollama(
            model="llama3",
            request_timeout=30,
            # Manually set the context window to limit memory usage
            context_window=8000,
        )
        prompt = EXTRACTION_PROMPT.format(
            passage=passage, schema=CarCollection.schema_json()
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)

        return ExtractionDone(output=str(output), passage=passage)

    @step
    async def validate(
        self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        try:
            CarCollection.model_validate_json(ev.output)
        except Exception as e:
            print("Validation failed, retrying...")
            return ValidationErrorEvent(
                error=str(e), wrong_output=ev.output, passage=ev.passage
            )

        return StopEvent(result=ev.output)
That's it! Let's break down the workflow we just wrote.

- There is one entry point, extract (the step that accepts a StartEvent)
- When extract is done, it emits an ExtractionDone event
- validate then runs and confirms the extraction:
  - If the output is valid, it emits a StopEvent and halts the workflow
  - If not, it returns a ValidationErrorEvent with information about the error
- Any ValidationErrorEvent emitted loops back and makes extract run again!
- This continues until the structured output passes validation; see the optional visualization sketch below
Run the workflow!¶
NOTE: with loops, we need to be careful about runtime. Here, we set a timeout of 120 seconds.
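If the loop does not converge within that budget, run() raises an exception instead of returning a result. The sketch below is not from the original notebook and catches Exception because the specific timeout error class can vary across llama-index versions:

# Guarded run: fall back to an error message if the retry loop times out
try:
    guarded = await ReflectionWorkflow(timeout=120, verbose=False).run(
        passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
    )
except Exception as exc:
    guarded = f"Workflow did not complete: {exc}"

The cell below runs the workflow as-is, without the extra guard.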
In [ ]:
w = ReflectionWorkflow(timeout=120, verbose=True)

# Run the workflow
ret = await w.run(
    passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
)
Running step extract
Step extract produced event ExtractionDone
Running step validate
Validation failed, retrying...
Step validate produced event ValidationErrorEvent
Running step extract
Step extract produced event ExtractionDone
Running step validate
Step validate produced event StopEvent
In [ ]:
print(ret)
{ "cars": [ { "brand": "Fiat", "model": "Panda", "power": 45 }, { "brand": "Honda", "model": "Civic", "power": 330 } ] }