langgraph / / 2024. 11. 7. 07:52

[langgraph] Persistence

udemy 강좌 내용[LangGraph- Develop LLM powered AI agents with LangGraph]의 일부를 정리한 내용입니다. 예제의 일부 내용을 수정하였습니다. 문제가 되면 삭제하겠습니다.

https://www.udemy.com/course/langgraph/learn/lecture/44431680#overview

LangGraph의 영속성(Persistence)에 대해서 알아보자. 영속성은 운영모드에서는 아주 중요한 기능이다. LangGraph에서의 영속성은 node가 실행된 이후 특정 상태(state)를 저장할 수 있다는 것을 말한다. 저장 후 필요한 시점에 조회하여 사용될 수 있다는 것을 의미한다.

이것이 왜 중요할까? 사람이 개입(Human in the loop)하는 워크플로우에서 사용될 수 있기 때문이다. 대부분 애플리케이션에서 사용자의 입력 상태는 각 node와는 독립적이어야 한다. 그래서 LangGraph가 제공하는 방법은 그래프 실행을 멈추는 것이다. LangGraph는 영속 저장소에 상태를 저장하고 사용자 입력을 받을 수 있다. 그리고 그 지점부터 그래프 실행을 재개할 수 있는 것이다.

LangGraph에서는 영속성을 위해 다양한 데이터베이스를 지원하지만 여기서는 SQLite를 사용하여 상태를 저장한다.

from langgraph.checkpoint.sqlite import SqliteSaver

memory = SqliteSaver.from_conn_string(":checkpoints.sqlite")
graph = workflow.compile(checkpointer=memory)

여기서 SqliteSaver로 memory를 만들어 workflow.compile할 때 checkpointer로 넘긴다. 이렇게 함으로써 각 node 실행 후에 state를 저장한다. 기본적으로 각 노드 실행 후에 매번 새로운 상태를 가지게 된다. LangGraph는 필요한 경우 접근하여 조회할 수 있게 된다.

MemorySaver + Interrupts = Human in the loop

MemorySaver는 각 node 실행 후 state를 저장하는 checkpointer이다. 하지만 메모리에 저장하기 때문에 graph 실행 후에 사라진다.

우선 State를 정의하자.

class State(TypedDict):
    input: str
    user_feedback: str

    def step_1(state: State) -> None:
    print("---Step 1---")


def human_feedback(state: State) -> None:
    print("---Human Feedback---")


def step_3(state: State) -> None:
    print("---Step 3---")

여기서 사용자의 입력(input)을 받고 그래프 실행 중에 피드백(user_feedback)을 받을 것이다.

그리고 아무 작업도 하지 않는 step_1을 만들고 역시 아무 작업도 하지 않는 사용자 피드백(human_feedback)을 받는 노드를 만든다. 그리고 step_3를 만든다.

모든 node를 연결하자.

builder = StateGraph(State)
builder.add_node("step_1", step_1)
builder.add_node("human_feedback", human_feedback)
builder.add_node("step_3", step_3)
builder.add_edge(START, "step_1")
builder.add_edge("step_1", "human_feedback")
builder.add_edge("human_feedback", "step_3")
builder.add_edge("step_3", END)

그리고 아래와 같이 edge를 연결한다.

START -> step_1
step_1 -> human_feedback
human_feedback -> step_3
step_3 -> END

memory = MemorySaver()

graph = builder.compile(checkpointer=memory, interrupt_before=["human_feedback"])
graph.get_graph().draw_mermaid_png(output_file_path="graph.png")

이제 MemorySaver를 만든다. graph를 컴파일할 때 memorySaver를 만든다. 이것은 각 그래프를 실행할 때 state를 memory에 저장하는 역할을 한다.

builder를 컴파일할 때 checkpointer에 memory를 넣는다. 추가 파라미터로 interrupt_before가 있는 것을 볼 수 있다. 여기서 node을 human_feedback으로 넣는다. 이것은 human_feedback node를 실행하기 전 그래프 실행을 멈춘다는 것을 의미한다. 그래프가 멈춘 지점에서 그래프 상태를 저장했기 때문에 사용자의 입력을 받을 수가 있는 것이다. 그리고 우리가 멈춘 지점에서부터 그래프 실행을 재개할 수 있다.

이 모든 것은 멈춘 지점을 기억할 수 있게 하는 checkpointer 덕분이다.

MemorySaver

if __name__ == "__main__":
    thread = {"configurable": {"thread_id": "1"}}

    initial_input = {"input": "hello world"}

    for event in graph.stream(initial_input, thread, stream_mode="values"):
        print(event)

    print(graph.get_state(thread).next)

    user_input = input("Tell me how you want to update the state: ")

    graph.update_state(thread, {"user_feedback": user_input}, as_node="human_feedback")

    print("--State after update--")
    print(graph.get_state(thread))

    print(graph.get_state(thread).next)

    for event in graph.stream(initial_input, thread, stream_mode="values"):
        print(event)

thread 변수를 dictionary로 만든다. thread_id를 가지고 있는 configurable 하나의 값만을 가지고 있다. thread_id는 session_id 혹은 conversation_id로 생각할 수도 있다. thread_id는 그래프를 실행할 때 다른 것과 구별해주는 역할을 한다. 만일 다른 사용자가 다른 대화를 하고 있다면 이것이 구분하는 역할을 하는 것이다. 여기에 uuid같은 것을 넣을 수 있고 간단한 값을 넣을 수도 있다.

최초 입력을 hello world가 넣고 graph.stream이 입력 값을 받는다. 여기서 출력되는 값은 아래와 같다.

{'input': 'hello world'}
---Step 1---

입력값이 출력되고 Step 1이 출력된다.

graph.get_state(thread).next는 현재 thread의 그래프 상태에서 다음 노드를 표시해준다. 여기서 출력해보면 ('human_feedback',)으로 나온다. human_input에서 그래프가 실행을 중지한다. 그리고 사람의 입력을 받을 것이다.

input으로 입력을 받고 그래프를 update한다. 그리고 다음 node를 실행한다.

그래프의 현재 상태를 debug 하려면 graph.get_state(thread)를 하면 된다.

user_feedback을 받고 --State after update--를 출력해보면 아래와 같다.

--State after update--
StateSnapshot(values={'input': 'hello world', 'user_feedback': 'koko'}, next=('step_3',), config={'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef9c8a1-650b-6cec-8002-f96e54ac5578'}}, metadata={'source': 'update', 'step': 2, 'writes': {'human_feedback': {'user_feedback': 'koko'}}, 'parents': {}}, created_at='2024-11-06T21:57:16.558417+00:00', parent_config={'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef9c896-0265-6dc8-8001-0dda33ef9e0b'}}, tasks=(PregelTask(id='edb830e7-24b6-01de-c248-91581cfe334b', name='step_3', path=('__pregel_pull', 'step_3'), error=None, interrupts=(), state=None, result=None),))

다음 graph state는 ('step_3', ) 이며 graph.stream을 출력하면 아래와 같다.

{'input': 'hello world', 'user_feedback': 'koko'}
---Step 1---

LangGraph를 확인해보자.

최초 start가 시작되어 user_feedback을 통해 LangGraphUpdateState가 실행되었고 graph가 재개되었다.

SqliteSaver

SqliteSaver를 구현하기 위해서 우선 아래 패키지를 설치하자.

pip install langgraph-checkpointer-sqlite

그리고 위의 코드에서 아래 코드를 추가한다.

import sqlite3

conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
memory = SqliteSaver(conn)

그리고 실행을 하고 sqlite를 확인하면 데이터가 잘 저장된 것을 확인할 수 있다.

최종 완성된 코드는 아래와 같다.

import sqlite3
from typing import TypedDict

from dotenv import load_dotenv

load_dotenv()

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.constants import START, END
from langgraph.graph import StateGraph


class State(TypedDict):
    input: str
    user_feedback: str


def step_1(state: State) -> None:
    print("---Step 1---")


def human_feedback(state: State) -> None:
    print("---Human Feedback---")


def step_3(state: State) -> None:
    print("---Step 3---")


builder = StateGraph(State)
builder.add_node("step_1", step_1)
builder.add_node("human_feedback", human_feedback)
builder.add_node("step_3", step_3)
builder.add_edge(START, "step_1")
builder.add_edge("step_1", "human_feedback")
builder.add_edge("human_feedback", "step_3")
builder.add_edge("step_3", END)

conn = sqlite3.connect("checkpoints.sqlite", check_same_thread=False)
memory = SqliteSaver(conn)

graph = builder.compile(checkpointer=memory, interrupt_before=["human_feedback"])
graph.get_graph().draw_mermaid_png(output_file_path="graph.png")

if __name__ == "__main__":
    thread = {"configurable": {"thread_id": "1"}}

    initial_input = {"input": "hello world"}

    for event in graph.stream(initial_input, thread, stream_mode="values"):
        print(event)

    print(graph.get_state(thread).next)

    user_input = input("Tell me how you want to update the state: ")

    graph.update_state(thread, {"user_feedback": user_input}, as_node="human_feedback")

    print("--State after update--")
    print(graph.get_state(thread))

    print(graph.get_state(thread).next)

    for event in graph.stream(initial_input, thread, stream_mode="values"):
        print(event)
반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유