FastAPI와 LangGraph를 이용하여 RAG의 응답을 Streaming으로 출력하는 방법을 알아보자.
일단 검색 및 생성을 하기 위해서는 RAG pipeline을 구성을 해야 하는데 아주 간단한 방법으로만 할 예정이다.
우선 아래의 방법으로 진행할 예정이다.
- LangGraph로 RAG를 조회하는 node를 구성 + 생성 node 추가
- FastAPI로 조회 api 생성
/search api로 검색 - 응답값을 SSE 형식으로 스트리밍 출력
작업 순서
- 환경설정
- langgraph 생성
- Fastapi 생성
- 테스트
1. 환경설정
.env 설정
.env에 OPENAI_API_KEY를 추가한다.
OPENAI_API_KEY=sk-xxxx
패키지 설치
pip install langchain langchainhub langchain-openai langgraph chromadb python-dotenv uvicorn fastapi langchain_chroma langchain_community beautifulsoup4
아래 예제에서 사용하는 패키지를 설치한다.
2. langgraph 생성
우선 RAG 파이프라인을 만들어 볼 예정이다. 아래와 같이 가장 단순한 방식으로 구성한다.
임베딩
우선 특정 데이터를 chromadb에 임베딩을 한다. 예제에서는 야구규칙이 있는 namu wiki의 내용을 WebBaseLoader를 사용하여 저장한다.
[ingestion.py]
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
load_dotenv()
urls = [
"https://namu.wiki/w/%EC%95%BC%EA%B5%AC/%EA%B2%BD%EA%B8%B0%20%EB%B0%A9%EC%8B%9D",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=500, chunk_overlap=0, model_name="gpt-4o-mini"
)
doc_splits = text_splitter.split_documents(docs_list)
chroma_directory = "./chroma"
vectorstore = Chroma.from_documents(
documents=doc_splits,
collection_name="baseball-chroma",
embedding=OpenAIEmbeddings(),
persist_directory="./chroma",
)
retriever = Chroma(
collection_name="baseball-chroma",
embedding_function=OpenAIEmbeddings(),
persist_directory=chroma_directory,
).as_retriever()
그리고 위의 파일을 실행을 해보자. 그러면 로컬에 chroma 파일이 생성이 되고 url 데이터가 저장되는 것을 확인할 수 있다.
이 파일에서 retriever를 다른 파일에서 읽어올 예정인데 Chroma.from_documents 부분을 주석처리한다. 그렇지 않으면 매번 vector db에 문서가 저장이 되어 중복데이터가 생긴다.
# vectorstore = Chroma.from_documents(
# documents=doc_splits,
# collection_name="baseball-chroma",
# embedding=OpenAIEmbeddings(),
# persist_directory="./chroma",
# )
graph 생성
langgraph를 사용하기 위해서는 state, node, chain을 만들어야 한다.
우선 state를 만든다.
[state.py]
class GraphState(TypedDict):
question: str
generation: str
documents: List[str]
question은 사용자의 prompt 내용이고 generation은 LLM에서 생성된 completion 데이터이다. documents는 vectordb에서 검색한 결과이다.
그리고 retrieve, generate의 두 개의 node를 만든다.
[retrieve.py]
from typing import Any, Dict
from graphs.state import GraphState
from ingestion import retriever
async def retrieve(state: GraphState) -> Dict[str, Any]:
print("---RETRIEVE---")
question = state["question"]
documents = await retriever.ainvoke(question)
return {"documents": documents, "question": question}
[generate.py]
from typing import Any, Dict
from langgraph.types import StreamWriter
from graphs.chains.generation import generation_chain
from graphs.state import GraphState
async def generate(state: GraphState, writer: StreamWriter) -> Dict[str, Any]:
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
chunks = []
async for chunk in generation_chain.astream(
{"context": documents, "question": question}
):
writer(chunk)
chunks.append(chunk)
return {"documents": documents, "question": question, "generation": "".join(chunks)}
api에서 호출방식을 async로 사용할 예정이므로 함수는 async로 지정한다.
위에서 StreamWriter는 각 chunk를 stream으로 출력하기 위해 필요하다. StreamWriter에 각 chunk를 넣어줘야 출력이 된다.
그리고 state에 chunk의 완성체를 추가하여 다음 노드에서 사용할 수 있게 한다.
다음으로 chain을 생성한다.
[generation.py]
from dotenv import load_dotenv
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
load_dotenv()
llm = ChatOpenAI(temperature=0)
prompt = hub.pull("rlm/rag-prompt")
generation_chain = prompt | llm | StrOutputParser()
rag-prompt는 langchain 팀에서 만든 가장 많이 사용하는 rag 프롬프트이다.
마지막으로 graph를 생성한다.
[graph.py]
from dotenv import load_dotenv
load_dotenv()
from langgraph.graph import END, StateGraph
from graphs.nodes.generate import generate
from graphs.nodes.retrieve import retrieve
from graphs.state import GraphState
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
graph_app = workflow.compile()
graph_app.get_graph().draw_mermaid_png(output_file_path="./graph.png")
2개의 노드를 만들고 각 edge를 연결한다.
start -> retrieve
retrieve -> generate
generate -> end
draw_mermaid_png를 실행하면 graph.png로 이미지가 생성되는 것을 확인할 수 있다.
3. Fastapi 생성
이제 api를 만들어보자.
@app.get("/search")
async def search(query: str):
async def event_stream():
try:
async for chunk in graph_app.astream(
input={"question": query}, stream_mode=["custom"]
):
yield f"data: {chunk[1]}\n\n"
except Exception as e:
yield f"data: {str(e)}\n\n"
return StreamingResponse(event_stream(), media_type="text/event-stream")
/search로 query라는 검색어를 파라미터로 받으며 graph_app을 호출한 결과를 SSE 방식으로 사용할 수 있게 stream으로 리턴한다.
여기서 stream_mode를 custom으로 해줘야 실제 답변의 내용만을 받을 수 있다. stream_mode에는 updates, values 등 다른 값이 있는데, 답변 외에 변경된 값 혹은 state에 있는 모든 값을 받고 싶다면 옵션을 조정할 수도 있다.
위에서 chunk[1]은 응답으로 chunk의 내용이 tuple로 넘어온다. 그래서 두번째 값이 실제 응답이라 추가하였다.
chunk 내용
('custom', '내용...')
StreamingResponse는 일반적인 streaming으로 응답을 보낼 때 사용하는 방식이다.
4. 테스트
우선 테스트를 위해서 fastapi를 시작하자.
uvicorn main:app
api 실행
GET http://localhost:8000/search?query=야구 경기방식 알려줘
실행결과
HTTP/1.1 200 OK
date: Thu, 07 Nov 2024 23:23:53 GMT
server: uvicorn
content-type: text/event-stream; charset=utf-8
transfer-encoding: chunked
Response code: 200 (OK); Time: 48ms (48 ms)
data:
data: 야
data: 구
data: 는
data: 각
data:
data: 9
data: 명
...
data: 니다
data: .
data: