langgraph / / 2024. 11. 29. 17:14

[langgraph] 병렬 실행 브랜치 생성하는 방법

LangGraph 공식문서를 번역한 내용입니다. 필요한 경우 부연 설명을 추가하였고 이해하기 쉽게 예제를 일부 변경하였습니다. 문제가 되면 삭제하겠습니다.

https://langchain-ai.github.io/langgraph/how-tos/branching/

노드의 병렬 실행은 전체 그래프 작업 속도를 높이는 데 필수적이다. LangGraph는 노드의 병렬 실행을 기본적으로 지원하여 그래프 기반 워크플로우의 성능을 크게 향상시킬 수 있다. 이러한 병렬 처리는 fan-out과 fan-in 메커니즘을 통해 이루어지며, 표준 에지와 조건부 에지를 모두 활용한다. 아래는 효과적인 분기 데이터 흐름을 생성하는 방법을 보여주는 몇 가지 예시이다.

준비

우선, 필요한 패키지를 설치하자.

pip install -U langgraph

병렬 노드 fan-in / fan-out

이 예제에서는 Node A에서 Node BC로 fan-out 한 다음, Node D로 fan-in 한다. 상태(State)에서는 리듀서 add 연산을 지정하여 특정 키에 대한 값을 단순히 덮어쓰지 않고 결합하거나 누적한다. 리스트의 경우, 새로운 리스트를 기존 리스트에 연결(concatenate)하는 것을 의미한다.

LangGraph는 상태(State)의 특정 키에 대한 리듀서 함수를 지정하기 위해 Annotated 타입을 사용한다. 타입 검사에서는 원래 타입(리스트)을 유지하면서도, 타입 자체를 변경하지 않고 리듀서 함수(add)를 타입에 연결할 수 있도록 한다.

import operator
from typing import Annotated, Any

from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END


class State(TypedDict):
    aggregate: Annotated[list, operator.add]


class ReturnNodeValue:
    def __init__(self, node_secret: str):
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        print(f"Adding {self._value} to {state['aggregate']}")
        return {"aggregate": [self._value]}


builder = StateGraph(State)
builder.add_node("a", ReturnNodeValue("I'm A"))
builder.add_edge(START, "a")
builder.add_node("b", ReturnNodeValue("I'm B"))
builder.add_node("c", ReturnNodeValue("I'm C"))
builder.add_node("d", ReturnNodeValue("I'm D"))

builder.add_edge("a", "b")
builder.add_edge("a", "c")
builder.add_edge("b", "d")
builder.add_edge("c", "d")
builder.add_edge("d", END)
graph = builder.compile()
from IPython.display import Image, display

display(
    Image(graph.get_graph().draw_mermaid_png(output_file_path="./parallel_execute.png"))
)

리듀서를 사용하면 각 노드에서 추가된 값들이 누적되는 것을 볼 수 있다.

result = graph.invoke({"aggregate": []}, {"configurable": {"thread_id": "foo"}})
print(result)
Adding I'm A to []
Adding I'm B to ["I'm A"]
Adding I'm C to ["I'm A"]
Adding I'm D to ["I'm A", "I'm B", "I'm C"]
{'aggregate': ["I'm A", "I'm B", "I'm C", "I'm D"]}

추가 단계로 병렬 노드 fan-out과 fan-in

위 예제에서는 각 경로가 한 단계만 있는 경우에 fan-out과 fan-in을 구현하는 방법을 보여준다. 하지만 한 경로에 여러 단계가 있는 경우에는 어떻게 될까?

import operator
from typing import Annotated

from langgraph.constants import START, END
from typing_extensions import TypedDict

from langgraph.graph import StateGraph


class State(TypedDict):
    aggregate: Annotated[list, operator.add]


class ReturnNodeValue:
    def __init__(self, node_secret: str):
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        print(f"Adding {self._value} to {state['aggregate']}")
        return {"aggregate": [self._value]}


builder = StateGraph(State)
builder.add_node("a", ReturnNodeValue("I'm A"))
builder.add_edge(START, "a")
builder.add_node("b", ReturnNodeValue("I'm B"))
builder.add_node("b2", ReturnNodeValue("I'm B2"))
builder.add_node("c", ReturnNodeValue("I'm C"))
builder.add_node("d", ReturnNodeValue("I'm D"))
builder.add_edge("a", "b")
builder.add_edge("a", "c")
builder.add_edge("b", "b2")
builder.add_edge(["b2", "c"], "d")
builder.add_edge("d", END)
graph = builder.compile()
from IPython.display import Image, display

display(
    Image(
        graph.get_graph().draw_mermaid_png(output_file_path="./parallel_execute2.png")
    )
)

result = graph.invoke({"aggregate": []})
print(result)
Adding I'm A to []
Adding I'm C to ["I'm A"]
Adding I'm B to ["I'm A"]
Adding I'm B2 to ["I'm A", "I'm B", "I'm C"]
Adding I'm D to ["I'm A", "I'm B", "I'm C", "I'm B2"]
{'aggregate': ["I'm A", "I'm B", "I'm C", "I'm B2", "I'm D"]}

조건부 분기 (Conditional Branching)

fan-out이 결정되지 않은 경우, add_conditional_edges를 직접 사용할 수 있다.

조건부 분기가 이후에 연결될 "sink" 노드가 이미 정해져 있다면, 조건부 에지를 생성할 때 then=<final-node-name>을 지정할 수 있다.

import operator
from typing import Annotated, Sequence, Any

from typing_extensions import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    aggregate: Annotated[list, operator.add]
    which: str


class ReturnNodeValue:
    def __init__(self, node_secret: str):
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        print(f"Adding {self._value} to {state['aggregate']}")
        return {"aggregate": [self._value]}


builder = StateGraph(State)
builder.add_node("a", ReturnNodeValue("I'm A"))
builder.add_edge(START, "a")
builder.add_node("b", ReturnNodeValue("I'm B"))
builder.add_node("c", ReturnNodeValue("I'm C"))
builder.add_node("d", ReturnNodeValue("I'm D"))
builder.add_node("e", ReturnNodeValue("I'm E"))


def route_bc_or_cd(state: State) -> Sequence[str]:
    if state["which"] == "cd":
        return ["c", "d"]
    return ["b", "c"]


intermediates = ["b", "c", "d"]
builder.add_conditional_edges(
    "a",
    route_bc_or_cd,
    intermediates,
)
for node in intermediates:
    builder.add_edge(node, "e")


builder.add_edge("e", END)
graph = builder.compile()

ReturnNodeValue는 추가한 값을 그대로 리턴하는 클래스이다.

from IPython.display import Image, display

display(
    Image(
        graph.get_graph().draw_mermaid_png(
            output_file_path="./parallel_execute_conditional_branching.png"
        )
    )
)

result = graph.invoke({"aggregate": [], "which": "bc"})
print(result)
Adding I'm A to []
Adding I'm B to ["I'm A"]
Adding I'm C to ["I'm A"]
Adding I'm E to ["I'm A", "I'm B", "I'm C"]
{'aggregate': ["I'm A", "I'm B", "I'm C", "I'm E"], 'which': 'bc'}

which를 bc로 주면 b와 c로 분기한다.

result = graph.invoke({"aggregate": [], "which": "cd"})
print(result)
Adding I'm A to []
Adding I'm C to ["I'm A"]
Adding I'm D to ["I'm A"]
Adding I'm E to ["I'm A", "I'm C", "I'm D"]
{'aggregate': ["I'm A", "I'm C", "I'm D", "I'm E"], 'which': 'cd'}

which를 cd로 주면 c와 d로 분기한다.

Stable Sorting

fan-out이 발생하면, 노드들은 하나의 "슈퍼스텝(superstep)"으로 병렬 실행된다. 각 슈퍼스텝에서의 업데이트는 슈퍼스텝이 완료된 후 순차적으로 상태(State)에 적용된다.

병렬 슈퍼스텝의 업데이트를 일관되고 사전에 정해진 순서로 처리해야 한다면, 출력을 상태의 별도 필드에 식별 키와 함께 기록한 후, "sink" 노드에서 이를 결합해야 한다. 이를 위해 fan-out 노드 각각에서 랑데부 지점으로 가는 일반 에지를 추가하면 된다.

예를 들어, 병렬 단계의 출력을 "신뢰도(reliability)"에 따라 정렬하고 싶다고 가정해 보자.

import operator
from tkinter import END
from typing import Annotated, Sequence, Any

from langgraph.constants import START
from typing_extensions import TypedDict

from langgraph.graph import StateGraph


def reduce_fanouts(left, right):
    if left is None:
        left = []
    if not right:
        return []
    return left + right # 배열을 병합
  # 예: left: [{'reliability': 0.9, 'value': ["I'm B"]}], right: [{'reliability': 0.1, 'value': ["I'm C"]}]
  # left + right : [{'reliability': 0.9, 'value': ["I'm B"]}, {'reliability': 0.1, 'value': ["I'm C"]}]


class State(TypedDict):
    aggregate: Annotated[list, operator.add]
    fanout_values: Annotated[list, reduce_fanouts]
    which: str


class ReturnNodeValue:
    def __init__(self, node_secret: str):
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        print(f"Adding {self._value} to {state['aggregate']}")
        return {"aggregate": [self._value]}


builder = StateGraph(State)
builder.add_node("a", ReturnNodeValue("I'm A"))
builder.add_edge(START, "a")


class ParallelReturnNodeValue:
    def __init__(
        self,
        node_secret: str,
        reliability: float,
    ):
        self._value = node_secret
        self._reliability = reliability

    def __call__(self, state: State) -> Any:
        print(f"Adding {self._value} to {state['aggregate']} in parallel.")
        return {
            "fanout_values": [
                {
                    "value": [self._value],
                    "reliability": self._reliability,
                }
            ]
        }


builder.add_node("b", ParallelReturnNodeValue("I'm B", reliability=0.9))
builder.add_node("c", ParallelReturnNodeValue("I'm C", reliability=0.1))
builder.add_node("d", ParallelReturnNodeValue("I'm D", reliability=0.3))


def aggregate_fanout_values(state: State) -> Any:
    # reliability 기준으로 정렬
    ranked_values = sorted(
        state["fanout_values"], key=lambda x: x["reliability"], reverse=True
    )
    return {
        "aggregate": [x["value"] for x in ranked_values] + ["I'm E"],
        "fanout_values": [],
    }


builder.add_node("e", aggregate_fanout_values)


def route_bc_or_cd(state: State) -> Sequence[str]:
    if state["which"] == "cd":
        return ["c", "d"]
    return ["b", "c"]


intermediates = ["b", "c", "d"]
builder.add_conditional_edges("a", route_bc_or_cd, intermediates)

for node in intermediates:
    builder.add_edge(node, "e")

builder.add_edge("e", END)
graph = builder.compile()
display(
    Image(
        graph.get_graph().draw_mermaid_png(
            output_file_path="./parallel_execute_stable_sorting.png"
        )
    )
)

graph.invoke({"aggregate": [], "which": "bc", "fanout_values": []})
Adding I'm A to []
Adding I'm B to ["I'm A"] in parallel.
Adding I'm C to ["I'm A"] in parallel.
{'aggregate': ["I'm A", ["I'm B"], ["I'm C"], "I'm E"],
 'fanout_values': [],
 'which': 'bc'}
graph.invoke({"aggregate": [], "which": "cd"})
Adding I'm A to []
Adding I'm C to ["I'm A"] in parallel.
Adding I'm D to ["I'm A"] in parallel.
{'aggregate': ["I'm A", ["I'm D"], ["I'm C"], "I'm E"],
 'fanout_values': [],
 'which': 'cd'}

LangGraph 참고 자료

반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유