langgraph / / 2024. 12. 2. 07:03

[langgraph] 많은 수의 도구(tool)를 처리하는 방법

[langgraph] 많은 수의 도구(tool)를 처리하는 방법

LangGraph 공식문서를 번역한 내용입니다. 필요한 경우 부연 설명을 추가하였고 이해하기 쉽게 예제를 일부 변경하였습니다. 문제가 되면 삭제하겠습니다.

https://langchain-ai.github.io/langgraph/how-tos/many-tools/

사용 가능한 도구 호출의 하위 집합은 일반적으로 모델의 재량에 달려 있다(많은 제공자들이 사용자에게 도구 선택을 지정하거나 제한할 수 있는 기능을 제공하기도 한다). 사용 가능한 도구의 수가 증가함에 따라, LLM이 선택할 수 있는 도구의 범위를 제한하려는 경우가 있을 수 있다. 이는 토큰 소비를 줄이고 LLM 추론에서 발생할 수 있는 오류의 원천을 관리하는 데 도움이 된다.

여기서는 모델이 사용할 수 있는 도구를 동적으로 조정하는 방법을 보여준다. 핵심은 RAG 및 유사한 방법들과 마찬가지로, 모델 호출을 시작하기 전에 사용 가능한 도구를 검색하는 것이다. 도구 선택 방법에 대한 구현을 하나의 예시로 보여주지만, 도구 선택의 세부 사항은 필요에 맞게 맞춤화할 수 있다.

준비

우선, 필요한 패키지를 설치하자.

pip install langgraph langchain_openai numpy

도구 정의

S&P 500 지수에 포함된 각 상장 기업에 대해 하나의 도구가 있는 장난감 예제를 고려해 보자. 각 도구는 제공된 연도를 기준으로 회사별 정보를 가져온다.

우리는 먼저 각 도구에 고유 식별자와 스키마를 연결하는 레지스트리를 구성한다. 도구는 JSON 스키마를 사용하여 표현되며, 이는 도구 호출을 지원하는 채팅 모델에 직접 바인딩될 수 있다.

import re
import uuid

from dotenv import load_dotenv
from langchain_core.tools import StructuredTool

load_dotenv()


def create_tool(company: str) -> dict:
    """Create schema for a placeholder tool."""
    # 알파벳이 아닌 문자를 제거하고 도구 이름에 공백을 밑줄로 바꾼다.
    formatted_company = re.sub(r"[^\w\s]", "", company).replace(" ", "_")

    def company_tool(year: int) -> str:
        # 회사와 연도에 대한 정적 수익 정보를 반환하는 플레이스홀더 함수
        return f"{company} had revenues of $100 in {year}."

    return StructuredTool.from_function(
        company_tool,
        name=formatted_company,
        description=f"Information about {company}",
    )


# S&P 500 회사들의 약어 목록 (데모용)
s_and_p_500_companies = [
    "3M",
    "A.O. Smith",
    "Abbott",
    "Accenture",
    "Advanced Micro Devices",
    "Yum! Brands",
    "Zebra Technologies",
    "Zimmer Biomet",
    "Zoetis",
]

# UUID 키를 사용하여 각 회사에 대한 도구를 만들고 레지스트리에 저장한다.
tool_registry = {
    str(uuid.uuid4()): create_tool(company) for company in s_and_p_500_companies
}

그래프 정의

도구 선택

상태에 있는 정보를 바탕으로 사용 가능한 도구의 하위 집합을 검색하는 노드를 구성할 것이다. 예를 들어, 최근 사용자 메시지와 같은 정보가 있을 수 있다. 일반적으로 이 단계에서는 다양한 검색 솔루션을 사용할 수 있다. 간단한 솔루션으로는 도구 설명의 임베딩을 벡터 저장소에 인덱싱하고, 의미론적 검색을 통해 사용자 쿼리를 도구와 연결하는 방법을 사용한다.

from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings

tool_documents = [
    Document(
        page_content=tool.description,
        id=id,
        metadata={"tool_name": tool.name},
    )
    for id, tool in tool_registry.items()
]

vector_store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
document_ids = vector_store.add_documents(tool_documents)

에이전트와 통합하기

우리는 전형적인 React 에이전트 그래프(예: quick start에서 사용된 것)를 일부 수정하여 사용할 것이다.

  • 상태에 selected_tools 키를 추가하여 선택된 도구의 하위 집합을 저장한다.
  • 그래프의 진입점을 select_tools 노드로 설정하여 이 상태 요소를 채운다.
  • 선택된 도구의 하위 집합을 에이전트 노드 내에서 채팅 모델에 바인딩한다.
from typing import Annotated

from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition


class State(TypedDict):
    messages: Annotated[list, add_messages]
    selected_tools: list[str]


builder = StateGraph(State)

tools = list(tool_registry.values())
llm = ChatOpenAI()


def agent(state: State):
    selected_tools = [tool_registry[id] for id in state["selected_tools"]]
    llm_with_tools = llm.bind_tools(selected_tools)
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


def select_tools(state: State):
    last_user_message = state["messages"][-1]
    query = last_user_message.content
    tool_documents = vector_store.similarity_search(query)
    return {"selected_tools": [document.id for document in tool_documents]}


builder.add_node("agent", agent)
builder.add_node("select_tools", select_tools)

tool_node = ToolNode(tools=tools)
builder.add_node("tools", tool_node)

builder.add_conditional_edges("agent", tools_condition, path_map=["tools", "__end__"])
builder.add_edge("tools", "agent")
builder.add_edge("select_tools", "agent")
builder.add_edge(START, "select_tools")
graph = builder.compile()
from IPython.display import Image, display

try:
    display(
        Image(
            graph.get_graph().draw_mermaid_png(
                output_file_path="how-to-handle-large-numbers-of-tools.png"
            )
        )
    )
except Exception:
    pass

user_input = "Can you give me some information about AMD in 2022?"

result = graph.invoke({"messages": [("user", user_input)]})
print(result["selected_tools"])
['ba3a91f2-2ad0-454e-a598-500533e04cb6', '76eaab8b-9866-446f-890f-e68d942b2d69', 'd7d6e5cd-ade8-4d77-b370-1fe03c84dceb', 'ff01a6e8-42c0-44f2-9bf9-b989e62a11f8']
for message in result["messages"]:
    message.pretty_print()
================================ Human Message =================================

Can you give me some information about AMD in 2022?
================================== Ai Message ==================================
Tool Calls:
  Advanced_Micro_Devices (call_SeO5IqMAd2jXGT4eOytUkXhP)
 Call ID: call_SeO5IqMAd2jXGT4eOytUkXhP
  Args:
    year: 2022
================================= Tool Message =================================
Name: Advanced_Micro_Devices

Advanced Micro Devices had revenues of $100 in 2022.
================================== Ai Message ==================================

In 2022, Advanced Micro Devices (AMD) had revenues of $100.

도구 선택 반복하기

잘못된 도구 선택으로 인한 오류를 관리하기 위해 select_tools 노드를 다시 살펴볼 수 있다. 이를 구현하는 한 가지 방법은 select_tools를 수정하여 상태의 모든 메시지를 사용하여 벡터 저장소 쿼리를 생성하고, 도구에서 select_tools로 가는 엣지 라우팅을 추가하는 것이다.

아래에서 이 변경 사항을 구현한다. 시연을 위해, select_tools 노드에 hack_remove_tool_condition을 추가하여 초기 도구 선택에서 오류를 시뮬레이션한다. 이 조건은 노드의 첫 번째 반복에서 올바른 도구를 제거한다. 두 번째 반복에서 에이전트는 올바른 도구에 접근할 수 있어 실행을 완료한다.

from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langgraph.pregel.retry import RetryPolicy

from pydantic import BaseModel, Field


class QueryForTools(BaseModel):
    """Generate a query for additional tools."""

    query: str = Field(..., description="Query for additional tools.")


def select_tools(state: State):
    """Selects tools based on the last message in the conversation state.

    If the last message is from a human, directly uses the content of the message
    as the query. Otherwise, constructs a query using a system message and invokes
    the LLM to generate tool suggestions.
    """
    last_message = state["messages"][-1]
    hack_remove_tool_condition = False  # Simulate an error in the first tool selection

    if isinstance(last_message, HumanMessage):
        query = last_message.content
        hack_remove_tool_condition = True  # Simulate wrong tool selection
    else:
        assert isinstance(last_message, ToolMessage)
        system = SystemMessage(
            "Given this conversation, generate a query for additional tools. "
            "The query should be a short string containing what type of information "
            "is needed. If no further information is needed, "
            "set more_information_needed False and populate a blank string for the query."
        )
        input_messages = [system] + state["messages"]
        response = llm.bind_tools([QueryForTools], tool_choice=True).invoke(
            input_messages
        )
        query = response.tool_calls[0]["args"]["query"]

    # Search the tool vector store using the generated query
    tool_documents = vector_store.similarity_search(query)
    if hack_remove_tool_condition:
        # Simulate error by removing the correct tool from the selection
        selected_tools = [
            document.id
            for document in tool_documents
            if document.metadata["tool_name"] != "Advanced_Micro_Devices"
        ]
    else:
        selected_tools = [document.id for document in tool_documents]
    return {"selected_tools": selected_tools}


graph_builder = StateGraph(State)
graph_builder.add_node("agent", agent)
graph_builder.add_node("select_tools", select_tools, retry=RetryPolicy(max_attempts=3))

tool_node = ToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)

graph_builder.add_conditional_edges(
    "agent",
    tools_condition,
)
graph_builder.add_edge("tools", "select_tools")
graph_builder.add_edge("select_tools", "agent")
graph_builder.add_edge(START, "select_tools")
graph = graph_builder.compile()
from IPython.display import Image, display

try:
    display(
        Image(
            graph.get_graph().draw_mermaid_png(
                output_file_path="how-to-handle-large-numbers-of-tools-repeating.png"
            )
        )
    )
except Exception:
    pass

user_input = "Can you give me some information about AMD in 2022?"

result = graph.invoke({"messages": [("user", user_input)]})
for message in result["messages"]:
    message.pretty_print()
================================ Human Message =================================

Can you give me some information about AMD in 2022?
================================== Ai Message ==================================
Tool Calls:
  Advanced_Micro_Devices (call_RXoTR2YYPwUWdbI6glz5kz0v)
 Call ID: call_RXoTR2YYPwUWdbI6glz5kz0v
  Args:
    year: 2022
================================= Tool Message =================================
Name: Advanced_Micro_Devices

Advanced Micro Devices had revenues of $100 in 2022.
================================== Ai Message ==================================

In 2022, Advanced Micro Devices had revenues of $100.

LangGraph 참고 자료

반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유