langgraph / / 2024. 12. 1. 18:11

[langgraph] 최종 노드(final node)에서 스트리밍하는 방법

LangGraph 공식문서를 번역한 내용입니다. 필요한 경우 부연 설명을 추가하였고 이해하기 쉽게 예제를 일부 변경하였습니다. 문제가 되면 삭제하겠습니다.

https://langchain-ai.github.io/langgraph/how-tos/streaming-from-final-node/

에이전트에서 스트리밍을 사용할 때 일반적인 사용형태는 최종 노드 내부에서 LLM 토큰을 스트리밍하는 것이다. 여기서는 이를 수행하는 방법을 보여준다.

준비

우선, 필요한 패키지를 설치하자.

pip install langgraph langchain-openai langchain-community

모델과 도구 정의

import asyncio
from typing import Literal

from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode

load_dotenv()


@tool
def get_weather(city: Literal["서울", "부산"]):
    """Use this to get weather information."""
    if city == "서울":
        return "맑아요~"
    elif city == "부산":
        return "비와요~"
    else:
        raise AssertionError("Unknown city")


tools = [get_weather]
model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
final_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

model = model.bind_tools(tools)
# 마지막 노드에서 호출된 모델만 필터링하려면 모델 스트림 이벤트를 필터링하는 데 사용할 수 있는 태그를 추가하는 곳이다.
# 단일 LLM을 호출하는 경우 필요하지 않지만 노드 내에서 여러 모델을 호출하고 그 중 하나의 이벤트만 필터링하려는 경우 중요할 수 있다.
final_model = final_model.with_config(tags=["final_node"])
tool_node = ToolNode(tools=tools)

그래프 정의

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import MessagesState
from langchain_core.messages import SystemMessage, HumanMessage


def should_continue(state: MessagesState) -> Literal["tools", "final"]:
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.tool_calls:
        return "tools"
    return "final"


def call_model(state: MessagesState):
    messages = state["messages"]
    response = model.invoke(messages)
    return {"messages": [response]}


def call_final_model(state: MessagesState):
    messages = state["messages"]
    last_ai_message = messages[-1]
    response = final_model.invoke(
        [
            SystemMessage("AI 로커 목소리로 다시 작성해주세요"),
            HumanMessage(last_ai_message.content),
        ]
    )
    response.id = last_ai_message.id
    return {"messages": [response]}


builder = StateGraph(MessagesState)

builder.add_node("agent", call_model)
builder.add_node("tools", tool_node)
builder.add_node("final", call_final_model)

builder.add_edge(START, "agent")
builder.add_conditional_edges(
    "agent",
    should_continue,
)

builder.add_edge("tools", "agent")
builder.add_edge("final", END)

graph = builder.compile()
from IPython.display import display, Image

display(
    Image(
        graph.get_graph().draw_mermaid_png(
            output_file_path="how-to-stream-from-final-node.png"
        )
    )
)

최종 노드에서 출력값 스트리밍

이벤트 메터데이터 필터링

특정 노드(이 경우 최종 노드) 내부에서 LLM 이벤트를 가져오는 첫 번째 방법은 이벤트 메타데이터의 langgraph_node 필드를 필터링하는 것이다. 이 방법은 노드 내부에서 발생하는 모든 LLM 호출의 이벤트를 스트리밍해야 하는 경우 충분하다. 즉, 노드 내부에서 여러 다른 LLM이 호출되더라도 이 필터는 모든 LLM에서 발생한 이벤트를 포함한다.

from langchain_core.messages import HumanMessage

inputs = {"messages": [HumanMessage(content="서울 날씨 어때?")]}
for msg, metadata in graph.stream(inputs, stream_mode="messages"):
    if (
        msg.content
        and not isinstance(msg, HumanMessage)
        and metadata["langgraph_node"] == "final"
    ):
        print(msg.content, end="|", flush=True)
서울|의| 날|씨|는| 맑|습니다|!| 기|분| 좋은| 하루|가| 될| 것| 같|네요|!|

커스텀 태그 필터링

또는 초기 단계에서처럼 final_model.with_config(tags=["final_node"])를 추가하여 LLM에 사용자 정의 태그를 사용하는 구성을 추가할 수 있다. 이를 통해 해당 모델에서만 발생한 이벤트를 더 정확히 필터링할 수 있다.

async def stream_content():
    async for event in graph.astream_events(inputs, version="v2"):
        kind = event["event"]
        tags = event.get("tags", [])
        # 커스텀 태그를 기반으로 필터링
        if kind == "on_chat_model_stream" and "final_node" in event.get("tags", []):
            data = event["data"]
            if data["chunk"].content:
                print(data["chunk"].content, end="|", flush=True)


asyncio.run(stream_content())
부|산|에서는| 현재| 비|가| 내|리고| 있습니다|.| 외|출|하|실| 때|는| 꼭| 우|산|을| 챙|기|시|기| 바랍니다|!|
반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유