langgraph / / 2024. 11. 15. 11:01

[번역][langgraph tutorial] Corrective RAG (CRAG)

langgraph의 공식문서를 번역해 놓은 자료입니다. 예제 일부는 변경하였고, 필요한 경우 부연 설명을 추가하였습니다. 문제가 되면 삭제하겠습니다.

https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_crag/

Corrective-RAG (CRAG)은 검색된 문서에 대한 self-reflection과 self-grading(스스로 평가)를 포함하는 RAG 전략이다.

논문에서는 몇 가지 단계를 수행한다.

  1. 적어도 하나의 문서가 관련성 기준을 초과하면, 생성(generation) 단계로 진행한다.

    1. 관련성 기준은 검색된 문서가 사용자 질문과 관련있는가 체크 (LLM)
  2. 생성 전에 지식 정제(knowledge refinement)를 수행한다.

    1. 관련성 있는 문서만 생성에서 사용한다. 관련성 없는 문서는 제거한다.
  3. 문서를 "지식 스트립(knowledge strips)"으로 분할한다.

  4. 각 스트립을 평가하고 관련 없는 스트립은 필터링한다.

  5. 모든 문서가 관련성 기준 이하이거나 평가자가 확신하지 못하는 경우, 프레임워크는 추가 데이터 소스를 찾는다.

  6. 검색 보완을 위해 웹 검색을 사용한다.

LangGraph를 사용하여 이러한 아이디어 중 일부를 처음부터 구현할 것이다.

  • 첫 번째 단계에서는 지식 정제 단계를 생략한다. 원한다면 나중에 노드로 추가할 수 있다.
  • 관련 없는 문서가 있는 경우, 웹 검색으로 검색을 보완한다.
  • 웹 검색을 위해 Tavily Search를 사용한다.
  • 웹 검색을 최적화하기 위해 쿼리 재작성(query re-writing)을 사용한다.

Index 생성

야구/경기 방식과 관련된 웹사이트를 인덱싱한다.

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
    "https://namu.wiki/w/%EC%95%BC%EA%B5%AC/%EA%B2%BD%EA%B8%B0%20%EB%B0%A9%EC%8B%9D"  # 야구/경기 방식
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# vectorDB에 추가
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

LLMs

문서가 사용자 질문과 관련성이 있는지 체크하는 노드

### Retrieval Grader

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from pydantic import BaseModel, Field


# Data model
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


# LLM with function call
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader
question = "야구 규칙에 대해 설명해주세요."
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
binary_score='yes'

관련성이 있으면 yes, 없으면 no로 리턴한다.

### Generate

from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)


# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# Chain
rag_chain = prompt | llm | StrOutputParser()

# Run
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)

관련성 있는 문서로 LLM에 호출한다. rag-prompt는 langchain 팀에서 만든 프롬프트이다.

야구는 두 팀이 번갈아 가며 공격과 수비를 하는 스포츠로, 각 팀은 9명의 선수로 구성됩니다. 경기는 주로 9이닝으로 진행되며, 각 이닝에서 공격팀은 점수를 얻기 위해 주자를 홈으로 보내야 합니다. 심판의 판단에 따라 타구의 페어와 파울, 투구의 스트라이크와 볼, 주자의 아웃과 세이프가 결정됩니다.
### Question Re-writer

# LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

# Prompt
system = """You a question re-writer that converts an input question to a better version that is optimized \n 
     for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
print(question_rewriter.invoke({"question": question}))

쿼리 재작성 노드이다. 웹 검색에 최적화된 새로운 질문으로 만든다.

야구 규칙에 대한 자세한 설명을 부탁드립니다.

Web 검색 도구

### Search

from langchain_community.tools.tavily_search import TavilySearchResults

web_search_tool = TavilySearchResults(k=3)

Graph 생성

이제 CRAG를 사용할 그래프를 만들자.

Graph 상태 정의

from typing import List

from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        web_search: whether to add search
        documents: list of documents
    """

    question: str
    generation: str
    web_search: str
    documents: List[str]

LangGraph의 노드에서 사용되는 상태값이다.

from langchain.schema import Document


def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue

    if not filtered_docs:
        web_search = "Yes"

    return {"documents": filtered_docs, "question": question, "web_search": web_search}


def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """

    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}


def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with appended web results
    """

    print("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]

    # Web search
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)

    return {"documents": documents, "question": question}


### Edges


def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    state["question"]
    web_search = state["web_search"]
    state["documents"]

    if web_search == "Yes":
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

Graph 컴파일

이는 위의 그림에서 나타낸 흐름을 따른다.

from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generatae
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("web_search_node", web_search)  # web search

# Build graph
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

Graph 사용하기

from pprint import pprint

# 실행
inputs = {"question": "야구에서 홈런이 뭐야?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

# Final generation
pprint(value["generation"])
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
야구에서 홈런이 뭐야?
[Document(metadata={'language': 'ko', 'source': 'https://namu.wiki/w/%EC%95%BC%EA%B5%AC/%EA%B2%BD%EA%B8%B0%20%EB%B0%A9%EC%8B%9D', 'title': '야구/경기 방식 - 나무위키'}, page_content='경우를 홈런으로 기록하고, 후자의 경우는 인사이드 파크 홈런으로 기록한다. 루상에 주자가 있으면 순서대로 진루하여 홈 플레이트를 밟으면 득점한다. 그렇기 때문에 만루홈런을[26] 치면 최대 4점을 득점할')]
---DECISION: GENERATE---
"Node 'grade_documents':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('야구에서 홈런은 타자가 공을 쳐서 외야를 넘어가는 경우를 말합니다. 주자가 있을 경우, 홈런을 치면 주자들이 순서대로 진루하여 득점하게 '
 '됩니다. 만루홈런을 치면 최대 4점을 득점할 수 있습니다.')

위에서 보면 Vector에서 검색을 하고 관련있는 문서가 1개 있는 것을 확인할 수 있다. 그래서 바로 Generation으로 이동하여 답변을 생성한다.

그러면 vectorstore에 있는 문서와 관련없는 질문을 해보자.

Question: 축구에서 파울이 뭐야?

from pprint import pprint

# 실행
inputs = {"question": "축구에서 파울이 뭐야?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
    pprint("\n---\n")

# Final generation
pprint(value["generation"])
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
축구에서 파울이 뭐야?
[]
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---
"Node 'grade_documents':"
'\n---\n'
---TRANSFORM QUERY---
"Node 'transform_query':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search_node':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('축구에서 파울은 상대 선수를 밀거나 차는 행위, 발을 걸거나 잡아당기는 행위, 핸드볼 등이 포함됩니다. 파울의 정도에 따라 심판은 옐로 '
 '카드나 레드 카드를 부여하며, 반복적인 위반은 징계로 이어질 수 있습니다. 이러한 규정은 국제 축구 연맹(FIFA)의 경기 규칙에 따라 '
 '공정하고 안전한 경기를 보장하기 위해 설정됩니다.')

Vectorstore에서 검색하고 문서 관련성 평가에서 모든 문서가 관련이 없다고 나온다.

그래서 쿼리를 rewrite를 하고 웹검색을 수행한 결과를 답변으로 만든다.

관련 자료

반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유