状态维护#
在我们目前的示例中,一直通过自定义事件的属性在步骤间传递数据。这种方式虽然强大,但也存在局限性。例如,若要在非直接连接的步骤间传递数据,就必须通过所有中间步骤来传递,这会导致代码可读性和可维护性降低。
为避免这个问题,工作流中的每个步骤都可以访问 Context
对象。使用时只需在步骤中声明一个 Context
类型的参数即可。具体实现如下:
首先需要导入新的 Context
类型:
from llama_index.core.workflow import (
StartEvent,
StopEvent,
Workflow,
step,
Event,
Context,
)
现在我们定义一个 start
事件,它会检查上下文是否已加载数据。若未加载,则返回触发 setup
的 SetupEvent
,由 setup
加载数据后循环回到 start
:
class SetupEvent(Event):
query: str
class StepTwoEvent(Event):
query: str
class StatefulFlow(Workflow):
@step
async def start(
self, ctx: Context, ev: StartEvent
) -> SetupEvent | StepTwoEvent:
db = await ctx.store.get("some_database", default=None)
if db is None:
print("Need to load data")
return SetupEvent(query=ev.query)
# do something with the query
return StepTwoEvent(query=ev.query)
@step
async def setup(self, ctx: Context, ev: SetupEvent) -> StartEvent:
# load data
await ctx.store.set("some_database", [1, 2, 3])
return StartEvent(query=ev.query)
在 step_two
中,我们可以直接从上下文中访问数据而无需显式传递。这在生成式AI应用中特别适用于加载索引等大型数据操作:
@step
async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
# do something with the data
print("Data is ", await ctx.store.get("some_database"))
return StopEvent(result=await ctx.store.get("some_database"))
w = StatefulFlow(timeout=10, verbose=False)
result = await w.run(query="Some query")
print(result)
添加类型化状态#
通常,我们会预设某种结构作为工作流的状态。最佳实践是使用 Pydantic
模型来定义状态,这样可以:
注意: 应使用所有字段都有默认值的Pydantic模型,这样Context
对象才能用默认值自动初始化状态。
以下示例展示如何结合工作流和Pydantic来利用这些特性:
from pydantic import BaseModel, Field, field_validator, field_serializer
from typing import Union
from llama_index.core.workflow import (
Context,
Workflow,
StartEvent,
StopEvent,
step,
)
# 这是一个我们想在状态中使用的随机对象
class MyRandomObject:
def __init__(self, name: str = "default"):
self.name = name
# 这是我们的状态模型
# 注意:所有字段必须有默认值
class MyState(BaseModel):
model_config = {"arbitrary_types_allowed": True}
my_obj: MyRandomObject = Field(default_factory=MyRandomObject)
some_key: str = Field(default="some_value")
# 以下为可选内容,但可用于控制状态的序列化!
@field_serializer("my_obj", when_used="always")
def serialize_my_obj(self, my_obj: MyRandomObject) -> str:
return my_obj.name
@field_validator("my_obj", mode="before")
@classmethod
def deserialize_my_obj(
cls, v: Union[str, MyRandomObject]
) -> MyRandomObject:
if isinstance(v, MyRandomObject):
return v
if isinstance(v, str):
return MyRandomObject(v)
raise ValueError(f"Invalid type for my_obj: {type(v)}")
class MyStatefulFlow(Workflow):
@step
async def start(self, ctx: Context[MyState], ev: StartEvent) -> StopEvent:
# 直接返回MyState
state = await ctx.store.get_state()
state.my_obj.name = "new_name"
await ctx.store.set_state(state)
# 也可直接访问字段
name = await ctx.store.get("my_obj.name")
await ctx.store.set("my_obj.name", "newer_name")
return StopEvent(result="Done!")
w = MyStatefulFlow(timeout=10, verbose=False)
ctx = Context(w)
result = await w.run(ctx=ctx)
state = await ctx.store.get_state()
print(state)
接下来我们将学习如何流式传输事件来自进行中的工作流。