Skip to content

状态维护#

在我们目前的示例中,一直通过自定义事件的属性在步骤间传递数据。这种方式虽然强大,但也存在局限性。例如,若要在非直接连接的步骤间传递数据,就必须通过所有中间步骤来传递,这会导致代码可读性和可维护性降低。

为避免这个问题,工作流中的每个步骤都可以访问 Context 对象。使用时只需在步骤中声明一个 Context 类型的参数即可。具体实现如下:

首先需要导入新的 Context 类型:

from llama_index.core.workflow import (
    StartEvent,
    StopEvent,
    Workflow,
    step,
    Event,
    Context,
)

现在我们定义一个 start 事件,它会检查上下文是否已加载数据。若未加载,则返回触发 setupSetupEvent,由 setup 加载数据后循环回到 start

class SetupEvent(Event):
    query: str


class StepTwoEvent(Event):
    query: str


class StatefulFlow(Workflow):
    @step
    async def start(
        self, ctx: Context, ev: StartEvent
    ) -> SetupEvent | StepTwoEvent:
        db = await ctx.store.get("some_database", default=None)
        if db is None:
            print("Need to load data")
            return SetupEvent(query=ev.query)

        # do something with the query
        return StepTwoEvent(query=ev.query)

    @step
    async def setup(self, ctx: Context, ev: SetupEvent) -> StartEvent:
        # load data
        await ctx.store.set("some_database", [1, 2, 3])
        return StartEvent(query=ev.query)

step_two 中,我们可以直接从上下文中访问数据而无需显式传递。这在生成式AI应用中特别适用于加载索引等大型数据操作:

@step
async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
    # do something with the data
    print("Data is ", await ctx.store.get("some_database"))

    return StopEvent(result=await ctx.store.get("some_database"))


w = StatefulFlow(timeout=10, verbose=False)
result = await w.run(query="Some query")
print(result)

添加类型化状态#

通常,我们会预设某种结构作为工作流的状态。最佳实践是使用 Pydantic 模型来定义状态,这样可以:

  • 获得状态类型提示
  • 自动验证状态
  • (可选)通过验证器序列化器完全控制状态的序列化与反序列化

注意: 应使用所有字段都有默认值的Pydantic模型,这样Context对象才能用默认值自动初始化状态。

以下示例展示如何结合工作流和Pydantic来利用这些特性:

from pydantic import BaseModel, Field, field_validator, field_serializer
from typing import Union

from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)


# 这是一个我们想在状态中使用的随机对象
class MyRandomObject:
    def __init__(self, name: str = "default"):
        self.name = name


# 这是我们的状态模型
# 注意:所有字段必须有默认值
class MyState(BaseModel):
    model_config = {"arbitrary_types_allowed": True}
    my_obj: MyRandomObject = Field(default_factory=MyRandomObject)
    some_key: str = Field(default="some_value")

    # 以下为可选内容,但可用于控制状态的序列化!

    @field_serializer("my_obj", when_used="always")
    def serialize_my_obj(self, my_obj: MyRandomObject) -> str:
        return my_obj.name

    @field_validator("my_obj", mode="before")
    @classmethod
    def deserialize_my_obj(
        cls, v: Union[str, MyRandomObject]
    ) -> MyRandomObject:
        if isinstance(v, MyRandomObject):
            return v
        if isinstance(v, str):
            return MyRandomObject(v)

        raise ValueError(f"Invalid type for my_obj: {type(v)}")


class MyStatefulFlow(Workflow):
    @step
    async def start(self, ctx: Context[MyState], ev: StartEvent) -> StopEvent:
        # 直接返回MyState
        state = await ctx.store.get_state()
        state.my_obj.name = "new_name"
        await ctx.store.set_state(state)

        # 也可直接访问字段
        name = await ctx.store.get("my_obj.name")
        await ctx.store.set("my_obj.name", "newer_name")

        return StopEvent(result="Done!")


w = MyStatefulFlow(timeout=10, verbose=False)

ctx = Context(w)
result = await w.run(ctx=ctx)
state = await ctx.store.get_state()
print(state)

接下来我们将学习如何流式传输事件来自进行中的工作流。