LangGraph에서 체크포인트를 저장하는 checkpointer는 MemorySaver와 Redis, PostgreSQL용만 제공된다. 더 다양한 DB를 지원하기 위해 SQLAlchemy 기반 checkpointer를 만들어 보았다.
import threading
from typing import Any, Optional, Sequence, Tuple
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
BaseCheckpointSaver,
ChannelVersions,
CheckpointMetadata,
CheckpointTuple,
get_checkpoint_id,
get_checkpoint_metadata,
)
from langgraph.checkpoint.serde.base import SerializerProtocol
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from sqlalchemy import Column, Integer, LargeBinary, String, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.sql import text
# Shared declarative base: the Checkpoint and Write models subclass it, and
# SqlAlchemySaver.__init__ creates their tables via Base.metadata.create_all().
Base = declarative_base()
class Checkpoint(Base):
    """ORM model for the ``checkpoints`` table (one row per saved checkpoint).

    Rows are keyed by ``(thread_id, checkpoint_ns, checkpoint_id)``.
    """

    __tablename__ = "checkpoints"

    # --- composite primary key ---
    thread_id = Column(String(length=100), primary_key=True)
    checkpoint_ns = Column(String(length=100), primary_key=True, default="")
    checkpoint_id = Column(String(length=100), primary_key=True)

    # --- payload ---
    # Id of the checkpoint this one was derived from; NULL when there is none.
    parent_checkpoint_id = Column(String(length=100))
    # Type tag produced by the serializer's dumps_typed(); needed to decode
    # the ``checkpoint`` bytes again.
    type = Column(String(length=50))
    # Serialized checkpoint bytes.
    checkpoint = Column(LargeBinary())
    # Serialized checkpoint metadata bytes.
    checkpoint_metadata = Column(LargeBinary())
class Write(Base):
    """ORM model for the ``writes`` table (one row per stored channel write).

    Rows are keyed by
    ``(thread_id, checkpoint_ns, checkpoint_id, task_id, idx)``.
    """

    __tablename__ = "writes"

    # --- composite primary key ---
    thread_id = Column(String(length=100), primary_key=True)
    checkpoint_ns = Column(String(length=100), primary_key=True, default="")
    checkpoint_id = Column(String(length=100), primary_key=True)
    task_id = Column(String(length=100), primary_key=True)
    # Position of the write within its task (or a reserved index taken from
    # WRITES_IDX_MAP, see SqlAlchemySaver.put_writes).
    idx = Column(Integer(), primary_key=True)

    # --- payload ---
    channel = Column(String(length=100))
    # Type tag produced by the serializer's dumps_typed(); needed to decode
    # the ``value`` bytes again.
    type = Column(String(length=50))
    # Serialized written value.
    value = Column(LargeBinary())
class SqlAlchemySaver(BaseCheckpointSaver[str]):
    """LangGraph checkpoint saver backed by any database SQLAlchemy can reach.

    Checkpoints are stored in the ``checkpoints`` table and intermediate
    channel writes in the ``writes`` table (see the ``Checkpoint`` and
    ``Write`` ORM models). Both tables are created on construction if they do
    not exist yet.
    """

    def __init__(self, engine_url: str, *, serde: Optional[SerializerProtocol] = None):
        """Create the engine and make sure the tables exist.

        Args:
            engine_url: SQLAlchemy database URL, e.g.
                ``"mysql+pymysql://user:password@host:port/dbname"``.
            serde: Serializer used for checkpoints and write values; passed
                through to ``BaseCheckpointSaver``.
        """
        super().__init__(serde=serde)
        self.engine = create_engine(engine_url)
        # Metadata is always (de)serialized with JsonPlusSerializer,
        # independently of the ``serde`` chosen for checkpoint payloads.
        self.jsonplus_serde = JsonPlusSerializer()
        Base.metadata.create_all(self.engine)
        # NOTE: not acquired by any method of this class; kept as a public
        # attribute so existing code that references it keeps working.
        self.lock = threading.Lock()

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Fetch one checkpoint together with its pending writes.

        If ``config`` carries a checkpoint id, exactly that checkpoint is
        looked up; otherwise the row with the greatest ``checkpoint_id`` for
        the thread / namespace is returned.

        Args:
            config: Runnable config whose ``"configurable"`` mapping holds
                ``thread_id`` and optionally ``checkpoint_ns`` /
                ``checkpoint_id``.

        Returns:
            The matching ``CheckpointTuple``, or ``None`` if no row matches.

        Raises:
            KeyError: If ``config["configurable"]["thread_id"]`` is missing.
        """
        configurable = config["configurable"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        thread_id = str(configurable["thread_id"])

        with Session(self.engine) as session:
            query = session.query(Checkpoint).filter_by(
                thread_id=thread_id,
                checkpoint_ns=checkpoint_ns,
            )
            if checkpoint_id := get_checkpoint_id(config):
                query = query.filter_by(checkpoint_id=checkpoint_id)
            else:
                # No explicit id requested: pick the latest checkpoint.
                query = query.order_by(Checkpoint.checkpoint_id.desc())

            record = query.first()
            if record is None:
                return None

            write_rows = (
                session.query(Write)
                .filter_by(
                    thread_id=record.thread_id,
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=record.checkpoint_id,
                )
                .order_by(Write.task_id, Write.idx)
                .all()
            )

            parent_config = None
            if record.parent_checkpoint_id:
                parent_config = {
                    "configurable": {
                        "thread_id": record.thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": record.parent_checkpoint_id,
                    }
                }

            metadata = (
                self.jsonplus_serde.loads(record.checkpoint_metadata)
                if record.checkpoint_metadata
                else {}
            )

            # Everything is materialized while the session is still open so
            # no attribute is read from a detached ORM instance afterwards.
            return CheckpointTuple(
                config={
                    "configurable": {
                        "thread_id": record.thread_id,
                        "checkpoint_ns": checkpoint_ns,
                        "checkpoint_id": record.checkpoint_id,
                    }
                },
                checkpoint=self.serde.loads_typed((record.type, record.checkpoint)),
                metadata=metadata,
                parent_config=parent_config,
                # Previously the writes were queried but dropped; hand them
                # back as (task_id, channel, value) triples.
                pending_writes=[
                    (
                        row.task_id,
                        row.channel,
                        self.serde.loads_typed((row.type, row.value)),
                    )
                    for row in write_rows
                ],
            )

    def put(
        self,
        config: RunnableConfig,
        # The LangGraph checkpoint mapping (read via checkpoint["id"]); it is
        # NOT an instance of the ``Checkpoint`` ORM model defined in this file.
        checkpoint: Any,
        checkpoint_metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Insert or update a checkpoint row.

        Args:
            config: Runnable config; its ``checkpoint_id`` (if any) is stored
                as the parent of the checkpoint being saved.
            checkpoint: Checkpoint mapping to persist; must contain ``"id"``.
            checkpoint_metadata: Metadata stored alongside the checkpoint.
            new_versions: Unused by this implementation.

        Returns:
            A config pointing at the checkpoint that was just stored.

        Raises:
            KeyError: If ``config["configurable"]["thread_id"]`` is missing.
        """
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
        serialized_metadata = self.jsonplus_serde.dumps(
            get_checkpoint_metadata(config, checkpoint_metadata)
        )

        with Session(self.engine) as session:
            # merge() acts as an upsert on the primary key.
            session.merge(
                Checkpoint(
                    # Stored as str so the key matches what get_tuple() and
                    # put_writes() use, even for non-string thread ids.
                    thread_id=str(thread_id),
                    checkpoint_ns=checkpoint_ns,
                    checkpoint_id=checkpoint["id"],
                    parent_checkpoint_id=configurable.get("checkpoint_id"),
                    type=type_,
                    checkpoint=serialized_checkpoint,
                    checkpoint_metadata=serialized_metadata,
                )
            )
            session.commit()

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint["id"],
            }
        }

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[Tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store the channel writes produced by one task for a checkpoint.

        Args:
            config: Runnable config identifying the checkpoint
                (``thread_id``, ``checkpoint_id`` and optional
                ``checkpoint_ns``).
            writes: ``(channel, value)`` pairs to persist.
            task_id: Id of the task that produced the writes.
            task_path: Unused by this implementation.

        Raises:
            KeyError: If ``thread_id`` or ``checkpoint_id`` is missing from
                ``config["configurable"]`` and ``writes`` is non-empty.
        """
        if not writes:
            # Nothing to store; the original loop would not touch the DB.
            return

        configurable = config["configurable"]
        thread_id = str(configurable["thread_id"])
        checkpoint_ns = str(configurable.get("checkpoint_ns", ""))
        checkpoint_id = str(configurable["checkpoint_id"])

        with Session(self.engine) as session:
            for position, (channel, value) in enumerate(writes):
                type_, serialized_value = self.serde.dumps_typed(value)
                # merge() upserts on the primary key, so storing the same
                # (task_id, idx) again overwrites the earlier row.
                session.merge(
                    Write(
                        thread_id=thread_id,
                        checkpoint_ns=checkpoint_ns,
                        checkpoint_id=checkpoint_id,
                        task_id=task_id,
                        # Channels listed in WRITES_IDX_MAP use their reserved
                        # index; all others use their position in ``writes``.
                        idx=WRITES_IDX_MAP.get(channel, position),
                        channel=channel,
                        type=type_,
                        value=serialized_value,
                    )
                )
            session.commit()
사용할 때는 아래와 같이 쓰면 된다.
# Option 1: keep checkpoints in memory with MemorySaver
memory = MemorySaver()
# Option 2: persist checkpoints to MySQL through SqlAlchemySaver
memory = SqlAlchemySaver(
    "mysql+pymysql://<user>:<password>@<url>:<port>/<dbname>"
)
다른 DB를 사용할 때도 해당 DB의 SQLAlchemy 접속 URL을 넣으면 된다.
단, 해당 DB용 드라이버 패키지(예: MySQL의 경우 pymysql)는 별도로 설치해야 한다.
반응형