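How to use a Pydantic model as graph state¶
A `StateGraph` accepts a `state_schema` argument on initialization that specifies the "shape" of the state that the nodes in the graph can access and update.
In our examples, we typically use a Python-native `TypedDict` for `state_schema` (or, in the case of `MessageGraph`, a `list`), but `state_schema` can be any type.
In this how-to guide, we'll see how a Pydantic `BaseModel` can be used for `state_schema` to add run-time validation on **inputs**.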
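Known limitations
- This notebook uses Pydantic v2 `BaseModel`, which requires `langchain-core >= 0.3`. Using `langchain-core < 0.3` results in errors because Pydantic v1 and v2 `BaseModel`s get mixed.
- Currently, the `output` of the graph will **NOT** be an instance of a Pydantic model.
- Run-time validation only occurs on **inputs** into nodes, not on their outputs.
- The validation error trace from Pydantic does not show which node the error arose in.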
Setup¶
First, install the required packages, then set the `OPENAI_API_KEY` environment variable:
import getpass
import os


def _set_env(var: str):
    # Prompt for the value only if it is not already set in the environment
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


_set_env("OPENAI_API_KEY")
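Set up LangSmith for LangGraph development
Sign up for LangSmith to quickly spot and fix issues in your LangGraph projects and improve their performance. LangSmith lets you use trace data to debug, test, and monitor LLM apps built with LangGraph. For more on how to get started, see the LangSmith documentation.
Input validation¶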
API Reference: StateGraph | START | END
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


# The overall state of the graph (this is the public state shared across nodes)
class OverallState(BaseModel):
    a: str


def node(state: OverallState):
    return {"a": "goodbye"}


# Build the state graph
builder = StateGraph(OverallState)
builder.add_node(node)  # Registered under the function's name, "node"
builder.add_edge(START, "node")  # Start the graph at "node"
builder.add_edge("node", END)  # End the graph after "node"
graph = builder.compile()

# Test the graph with a valid input
graph.invoke({"a": "hello"})
Invoke the graph with an **invalid** input:
try:
    graph.invoke({"a": 123})  # Should be a string
except Exception as e:
    print("An exception was raised because `a` is an integer rather than a string.")
    print(e)
An exception was raised because `a` is an integer rather than a string.
1 validation error for OverallState
a
Input should be a valid string [type=string_type, input_value=123, input_type=int]
For further information visit https://errors.pydantic.dev/2.9/v/string_type
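The message above is an ordinary Pydantic validation error, so the same check can be reproduced without a graph. The sketch below validates the same payloads directly against `OverallState` and inspects the structured error details. It uses only Pydantic v2 calls (`model_validate` and `ValidationError.errors()`) and assumes nothing about how LangGraph invokes them internally; `validation_errors` is a helper name introduced here for illustration.

```python
from pydantic import BaseModel, ValidationError


class OverallState(BaseModel):
    a: str


def validation_errors(payload: dict) -> list:
    """Return Pydantic's structured error list, or [] if the payload is valid."""
    try:
        OverallState.model_validate(payload)
    except ValidationError as exc:
        return exc.errors()
    return []


ok = validation_errors({"a": "hello"})
bad = validation_errors({"a": 123})

for err in bad:
    # Each entry names the failing field ("loc") and the error kind ("type").
    print(err["loc"], err["type"], err["msg"])
```

Working with `errors()` instead of the printed message makes it easier to report exactly which field failed.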
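Multiple nodes¶
Run-time validation also works in a multi-node graph. In the example below, `bad_node` updates `a` to an integer.
Because run-time validation occurs on **inputs**, the validation error is raised when `ok_node` is called, not when `bad_node` returns a state update that is inconsistent with the schema.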
API Reference: StateGraph | START | END
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


# The overall state of the graph (this is the public state shared across nodes)
class OverallState(BaseModel):
    a: str


def bad_node(state: OverallState):
    return {
        "a": 123  # Invalid
    }


def ok_node(state: OverallState):
    return {"a": "goodbye"}


# Build the state graph
builder = StateGraph(OverallState)
builder.add_node(bad_node)
builder.add_node(ok_node)
builder.add_edge(START, "bad_node")
builder.add_edge("bad_node", "ok_node")
builder.add_edge("ok_node", END)
graph = builder.compile()

# The input is valid; the error comes from bad_node's update reaching ok_node
try:
    graph.invoke({"a": "hello"})
except Exception as e:
    print("An exception was raised because bad_node sets `a` to an integer.")
    print(e)
An exception was raised because bad_node sets `a` to an integer.
1 validation error for OverallState
a
Input should be a valid string [type=string_type, input_value=123, input_type=int]
For further information visit https://errors.pydantic.dev/2.9/v/string_type
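Because the error only surfaces at the next node's input, it can help during debugging to validate an update before returning it. The following is a hypothetical workaround, not a LangGraph feature: `checked_update` is a name introduced here, and it merges the update into the current state and re-validates the result with Pydantic, naming the node in the error message. It assumes updates overwrite keys, so it does not account for reducer functions.

```python
from pydantic import BaseModel, ValidationError


class OverallState(BaseModel):
    a: str


def checked_update(state: BaseModel, update: dict, node_name: str) -> dict:
    """Validate `update` against the state's schema before returning it.

    Hypothetical helper: raises a ValueError that names the offending node.
    """
    # Assumption: an update simply overwrites existing keys (no reducers).
    merged = {**state.model_dump(), **update}
    try:
        type(state).model_validate(merged)
    except ValidationError as exc:
        raise ValueError(f"Invalid update from {node_name!r}: {exc}") from exc
    return update


def bad_node(state: OverallState):
    return checked_update(state, {"a": 123}, "bad_node")


def ok_node(state: OverallState):
    return checked_update(state, {"a": "goodbye"}, "ok_node")
```

With this in place, the failure is reported from inside `bad_node` itself instead of one step later.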
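Advanced Pydantic model usage¶
This section covers more advanced topics when using Pydantic models with LangGraph.
Serialization behavior¶
When using Pydantic models as state schemas, it is important to understand how serialization works, especially when:
- Passing Pydantic objects as inputs
- Receiving outputs from the graph
- Working with nested Pydantic models
Let's see these behaviors in action: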
API Reference: StateGraph | START | END
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


class NestedModel(BaseModel):
    value: str


class ComplexState(BaseModel):
    text: str
    count: int
    nested: NestedModel


def process_node(state: ComplexState):
    # Node receives a validated Pydantic object
    print(f"Input state type: {type(state)}")
    print(f"Nested type: {type(state.nested)}")
    # Return a dictionary update
    return {"text": state.text + " processed", "count": state.count + 1}


# Build the graph
builder = StateGraph(ComplexState)
builder.add_node("process", process_node)
builder.add_edge(START, "process")
builder.add_edge("process", END)
graph = builder.compile()

# Create a Pydantic instance for input
input_state = ComplexState(text="hello", count=0, nested=NestedModel(value="test"))
print(f"Input object type: {type(input_state)}")

# Invoke graph with a Pydantic instance
result = graph.invoke(input_state)
print(f"Output type: {type(result)}")
print(f"Output content: {result}")

# Convert back to Pydantic model if needed
output_model = ComplexState(**result)
print(f"Converted back to Pydantic: {type(output_model)}")
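Converting the output back into a model is plain Pydantic validation, so it can be tried without running a graph. The sketch below assumes the graph output is a dict-like mapping, as in the example above. It shows that `model_validate` accepts the nested value either as a `NestedModel` instance or as a plain dict. The two `result_*` dictionaries are stand-ins written for illustration, not real graph output.

```python
from pydantic import BaseModel


class NestedModel(BaseModel):
    value: str


class ComplexState(BaseModel):
    text: str
    count: int
    nested: NestedModel


# Stand-in for a result that still holds a nested model instance.
result_with_instance = {
    "text": "hello processed",
    "count": 1,
    "nested": NestedModel(value="test"),
}
# Stand-in for a result whose nested model was serialized to a dict.
result_with_dict = {
    "text": "hello processed",
    "count": 1,
    "nested": {"value": "test"},
}

from_instance = ComplexState.model_validate(result_with_instance)
from_dict = ComplexState.model_validate(result_with_dict)

print(type(from_dict.nested).__name__)
print(from_instance == from_dict)
# model_dump() goes the other way: nested models become plain dicts.
print(from_dict.model_dump())
```

Either form validates to the same model, so code that consumes graph output does not need to know which one it received.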
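Run-time type coercion¶
Pydantic performs run-time type coercion for certain data types. This can be helpful, but it can also lead to unexpected behavior if you are not aware of it.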
API Reference: StateGraph | START | END
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel


class CoercionExample(BaseModel):
    # Pydantic will coerce string numbers to integers
    number: int
    # Pydantic will parse string booleans to bool
    flag: bool


def inspect_node(state: CoercionExample):
    print(f"number: {state.number} (type: {type(state.number)})")
    print(f"flag: {state.flag} (type: {type(state.flag)})")
    return {}


builder = StateGraph(CoercionExample)
builder.add_node("inspect", inspect_node)
builder.add_edge(START, "inspect")
builder.add_edge("inspect", END)
graph = builder.compile()

# Demonstrate coercion with string inputs that will be converted
result = graph.invoke({"number": "42", "flag": "true"})

# This would fail with a validation error
try:
    graph.invoke({"number": "not-a-number", "flag": "true"})
except Exception as e:
    print(f"\nExpected validation error: {e}")
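If coercion is not wanted, Pydantic v2 offers a strict mode that rejects values of the wrong type instead of converting them. The sketch below compares the default (lax) behavior with a model configured with `ConfigDict(strict=True)`, using Pydantic alone. `LaxState` and `StrictState` are names introduced for illustration. Whether a given LangGraph version honors the model's strict configuration when validating node inputs is an assumption to verify before relying on it.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class LaxState(BaseModel):
    number: int
    flag: bool


class StrictState(BaseModel):
    # strict=True disables lax-mode coercion for every field of this model.
    model_config = ConfigDict(strict=True)

    number: int
    flag: bool


payload = {"number": "42", "flag": "true"}

# Default mode: the strings are converted to int and bool.
lax = LaxState.model_validate(payload)
print(lax.number, lax.flag)

# Strict mode: the same payload is rejected.
try:
    StrictState.model_validate(payload)
    strict_errors = []
except ValidationError as exc:
    strict_errors = exc.errors()

for err in strict_errors:
    print(err["loc"], err["type"])
```

Strict mode trades convenience for predictability: callers must send correctly typed values, but a string never silently becomes a number.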
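Working with message models¶
When using LangChain message types in your state schema, there are important serialization considerations. When message objects are sent over the wire, use `AnyMessage` (rather than `BaseMessage`) so that they are serialized and deserialized correctly: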
API Reference: StateGraph | START | END | HumanMessage | AIMessage | AnyMessage
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage
from typing import List


class ChatState(BaseModel):
    messages: List[AnyMessage]
    context: str


def add_message(state: ChatState):
    return {"messages": state.messages + [AIMessage(content="Hello there!")]}


builder = StateGraph(ChatState)
builder.add_node("add_message", add_message)
builder.add_edge(START, "add_message")
builder.add_edge("add_message", END)
graph = builder.compile()

# Create input with a message
initial_state = ChatState(
    messages=[HumanMessage(content="Hi")], context="Customer support chat"
)

result = graph.invoke(initial_state)
print(f"Output: {result}")

# Convert back to Pydantic model to see message types
output_model = ChatState(**result)
for i, msg in enumerate(output_model.messages):
    print(f"Message {i}: {type(msg).__name__} - {msg.content}")