langgraph / / 2025. 2. 21. 10:10

[langgraph] checkpointer로 SqlAlchemy 사용

langgraph에서 기본으로 제공하는 checkpointer는 MemorySaver와 Redis, PostgreSQL용 정도만 존재한다. 더 다양한 DB를 지원하기 위해 SQLAlchemy 기반 checkpointer를 만들어 보았다.

import threading
from typing import Any, Optional, Sequence, Tuple

from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    CheckpointMetadata,
    CheckpointTuple,
    get_checkpoint_id,
    get_checkpoint_metadata,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from sqlalchemy import Column, Integer, LargeBinary, String, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.sql import text

# Declarative base shared by the ORM models in this module; their tables are
# registered on Base.metadata, which the saver passes to create_all().
Base = declarative_base()


class Checkpoint(Base):
    """ORM model for one persisted checkpoint (table ``checkpoints``).

    A row is identified by the composite primary key
    ``(thread_id, checkpoint_ns, checkpoint_id)``.
    """

    __tablename__ = "checkpoints"

    # Composite primary key.
    thread_id = Column(String(100), primary_key=True)
    checkpoint_ns = Column(String(100), primary_key=True, default="")
    checkpoint_id = Column(String(100), primary_key=True)
    # Id of the checkpoint this one was derived from; NULL when there is none.
    parent_checkpoint_id = Column(String(100))
    # Serializer type tag stored alongside the ``checkpoint`` payload.
    type = Column(String(50))
    # An explicit length is given because a bare LargeBinary renders as a plain
    # BLOB on MySQL, which caps values at 64 KiB. With a length, MySQL picks the
    # smallest BLOB variant that fits (LONGBLOB here).
    # NOTE(review): other backends are expected to ignore the length -- confirm
    # for the dialect in use. Existing tables are not altered by create_all().
    checkpoint = Column(LargeBinary(length=2**32 - 1))
    checkpoint_metadata = Column(LargeBinary(length=2**32 - 1))


class Write(Base):
    """ORM model for one stored channel write (table ``writes``).

    A row is identified by the composite primary key
    ``(thread_id, checkpoint_ns, checkpoint_id, task_id, idx)``.
    """

    __tablename__ = "writes"

    # Composite primary key.
    thread_id = Column(String(100), primary_key=True)
    checkpoint_ns = Column(String(100), primary_key=True, default="")
    checkpoint_id = Column(String(100), primary_key=True)
    task_id = Column(String(100), primary_key=True)
    idx = Column(Integer, primary_key=True)
    # Name of the channel the value was written to.
    channel = Column(String(100))
    # Serializer type tag stored alongside the ``value`` payload.
    type = Column(String(50))
    # An explicit length is given because a bare LargeBinary renders as a plain
    # BLOB on MySQL, which caps values at 64 KiB. With a length, MySQL picks the
    # smallest BLOB variant that fits (LONGBLOB here).
    # NOTE(review): other backends are expected to ignore the length -- confirm
    # for the dialect in use. Existing tables are not altered by create_all().
    value = Column(LargeBinary(length=2**32 - 1))


class SqlAlchemySaver(BaseCheckpointSaver[str]):
    """Checkpoint saver that stores checkpoints and writes through SQLAlchemy.

    Checkpoints go to the ``checkpoints`` table and channel writes to the
    ``writes`` table. Both tables are created on construction if missing.
    """

    def __init__(self, engine_url: str, *, serde: Optional[SerializerProtocol] = None):
        """Create the engine for ``engine_url`` and ensure the tables exist.

        Args:
            engine_url: SQLAlchemy database URL passed to ``create_engine``.
            serde: Optional serializer forwarded to the base saver; it is used
                for checkpoint and write payloads.
        """
        super().__init__(serde=serde)
        self.engine = create_engine(engine_url)
        # Metadata is always (de)serialized with JsonPlusSerializer,
        # independently of the ``serde`` argument.
        self.jsonplus_serde = JsonPlusSerializer()
        Base.metadata.create_all(self.engine)
        # Not acquired by any method of this class; kept so existing code that
        # reads ``saver.lock`` keeps working.
        self.lock = threading.Lock()

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch a checkpoint together with its stored writes.

        If ``config`` carries a checkpoint id, that exact checkpoint is looked
        up; otherwise the row with the greatest ``checkpoint_id`` for the
        thread and namespace is used.

        Args:
            config: Runnable config whose ``configurable`` mapping contains
                ``thread_id`` and optionally ``checkpoint_ns`` and a
                checkpoint id.

        Returns:
            The matching ``CheckpointTuple`` including ``pending_writes``, or
            ``None`` when no checkpoint row matches.
        """
        configurable = config["configurable"]
        thread_id = str(configurable["thread_id"])
        checkpoint_ns = configurable.get("checkpoint_ns", "")

        with Session(self.engine) as session:
            query = session.query(Checkpoint).filter_by(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
            )
            if checkpoint_id := get_checkpoint_id(config):
                query = query.filter_by(checkpoint_id=checkpoint_id)
            else:
                # No explicit id: take the highest checkpoint_id.
                query = query.order_by(Checkpoint.checkpoint_id.desc())
            record = query.first()

            if record is None:
                return None

            write_rows = (
                session.query(Write)
                .filter_by(
                    thread_id=record.thread_id,
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=record.checkpoint_id,
                )
                .order_by(Write.task_id, Write.idx)
                .all()
            )
            # Previously these rows were fetched and then dropped, so the
            # returned tuple never carried the stored writes.
            pending_writes = [
                (
                    row.task_id,
                    row.channel,
                    self.serde.loads_typed((row.type, row.value)),
                )
                for row in write_rows
            ]

            parent_config = None
            if record.parent_checkpoint_id:
                parent_config = {
                    "configurable": {
                        "thread_id": record.thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": record.parent_checkpoint_id,
                    }
                }

            if record.checkpoint_metadata:
                metadata = self.jsonplus_serde.loads(record.checkpoint_metadata)
            else:
                metadata = {}

            return CheckpointTuple(
                config={
                    "configurable": {
                        "thread_id": record.thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": record.checkpoint_id,
                    }
                },
                checkpoint=self.serde.loads_typed((record.type, record.checkpoint)),
                metadata=metadata,
                parent_config=parent_config,
                pending_writes=pending_writes,
            )

    # NOTE(review): the ``Checkpoint`` annotation below resolves to the ORM
    # model defined in this module, while the body subscripts the argument
    # (``checkpoint["id"]``) like a mapping. It presumably should be
    # LangGraph's checkpoint type -- confirm and alias the import if so.
    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        checkpoint_metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Insert or update one checkpoint row.

        Args:
            config: Runnable config; ``thread_id`` is required, and
                ``checkpoint_ns`` and ``checkpoint_id`` (stored as the parent
                id) are optional.
            checkpoint: Checkpoint payload; its ``"id"`` entry becomes the
                row's ``checkpoint_id``.
            checkpoint_metadata: Metadata combined with ``config`` via
                ``get_checkpoint_metadata`` before being serialized.
            new_versions: Not used by this implementation.

        Returns:
            A config pointing at the stored checkpoint.
        """
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
        serialized_metadata = self.jsonplus_serde.dumps(
            get_checkpoint_metadata(config, checkpoint_metadata)
        )

        with Session(self.engine) as session:
            session.merge(
                Checkpoint(
                    # Stored as a string so the key matches what get_tuple and
                    # put_writes use for the same thread.
                    thread_id=str(thread_id),
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=checkpoint["id"],
                    parent_checkpoint_id=configurable.get("checkpoint_id"),
                    type=type_,
                    checkpoint=serialized_checkpoint,
                    checkpoint_metadata=serialized_metadata,
                )
            )
            session.commit()

        # The caller's original thread_id value is returned unchanged.
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint["id"],
            }
        }

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Insert or update the writes of one task for a checkpoint.

        Args:
            config: Runnable config; ``thread_id`` and ``checkpoint_id`` are
                required, ``checkpoint_ns`` is optional.
            writes: ``(channel, value)`` pairs to store.
            task_id: Id of the task that produced the writes.
            task_path: Not used by this implementation.
        """
        configurable = config["configurable"]
        thread_id = str(configurable["thread_id"])
        checkpoint_ns = str(configurable.get("checkpoint_ns", ""))
        checkpoint_id = str(configurable["checkpoint_id"])

        with Session(self.engine) as session:
            for position, (channel, value) in enumerate(writes):
                type_, serialized_value = self.serde.dumps_typed(value)
                session.merge(
                    Write(
                        thread_id=thread_id,
                        checkpoint_ns=checkpoint_ns,
                        checkpoint_id=checkpoint_id,
                        task_id=task_id,
                        # Channels listed in WRITES_IDX_MAP use their fixed
                        # index; all others use their position in ``writes``.
                        idx=WRITES_IDX_MAP.get(channel, position),
                        channel=channel,
                        type=type_,
                        value=serialized_value,
                    )
                )
            session.commit()

사용할 때는 아래와 같이 쓰면 된다.

# When using MemorySaver
memory = MemorySaver()

# Using the SQLAlchemy-based saver with MySQL
memory = SqlAlchemySaver(
    "mysql+pymysql://<user>:<password>@<url>:<port>/<dbname>"
 )

다른 DB를 사용할 때도 해당 DB의 SQLAlchemy 접속 URL을 넣으면 된다.

단, 사용하려는 DB의 드라이버 패키지(예: 위 MySQL 예시의 경우 pymysql)는 별도로 설치해야 한다.

반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유