AI 지식 / / 2025. 8. 19. 21:29

[dots.ocr 번역] dots.ocr: 단일 비전-언어 모델에서의 다국어 문서 레이아웃 파싱

소개

dots.ocr은 단일 비전-언어 모델 내에서 레이아웃 감지와 콘텐츠 인식을 통합하면서 우수한 읽기 순서를 유지하는 강력한 다국어 문서 파서입니다. 컴팩트한 17억 매개변수 LLM 기반임에도 불구하고 최첨단(SOTA) 성능을 달성합니다.

주요 기능

  1. 강력한 성능: OmniDocBench에서 텍스트, 테이블, 읽기 순서에 대한 SOTA 결과 달성
  2. 다국어 지원: 저자원 언어에 대한 강력한 파싱 기능
  3. 통합된 간단한 아키텍처: 단일 비전-언어 모델 사용
  4. 효율적이고 빠른 성능: 컴팩트한 17억 LLM 기반

설치

1. dots.ocr 설치

conda create -n dots_ocr python=3.12
conda activate dots_ocr

git clone https://github.com/rednote-hilab/dots.ocr.git
cd dots.ocr

# PyTorch 설치 (CUDA 버전에 맞게 조정)
pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu128
pip install -e .

2. 모델 가중치 다운로드

python3 tools/download_model.py

# 선택사항: ModelScope 사용
python3 tools/download_model.py --type modelscope

배포

vLLM 추론

# vLLM에 모델 등록
python3 tools/download_model.py
export hf_model_path=./weights/DotsOCR
export PYTHONPATH=$(dirname "$hf_model_path"):$PYTHONPATH

# vLLM 서버 실행
CUDA_VISIBLE_DEVICES=0 vllm serve ${hf_model_path} \
    --tensor-parallel-size 1 \
    --gpu-memory-utilization 0.95 \
    --chat-template-content-format string \
    --served-model-name dots-ocr

클라이언트 사용

from openai import OpenAI
import base64

# vLLM 클라이언트 설정
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="token-abc123",
)

# 이미지를 base64로 인코딩
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

base64_image = encode_image("path/to/your/image.jpg")

# OCR 요청
completion = client.chat.completions.create(
    model="dots-ocr",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                },
                {
                    "type": "text",
                    "text": "Parse this document and extract all text content with proper reading order."
                }
            ]
        }
    ]
)

print(completion.choices[0].message.content)

직접 사용법

기본 OCR

import torch
from PIL import Image
from dots_ocr import DotsOCRModel, DotsOCRProcessor

# 모델과 프로세서 로드
model = DotsOCRModel.from_pretrained("./weights/DotsOCR")
processor = DotsOCRProcessor.from_pretrained("./weights/DotsOCR")

# 이미지 로드
image = Image.open("path/to/your/document.jpg")

# 이미지 전처리
inputs = processor(images=image, return_tensors="pt")

# 추론 실행
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.1,
        do_sample=True
    )

# 결과 디코딩
result = processor.decode(outputs[0], skip_special_tokens=True)
print(result)

배치 처리

import torch
from PIL import Image
from dots_ocr import DotsOCRModel, DotsOCRProcessor

# 여러 이미지 처리
images = [
    Image.open("document1.jpg"),
    Image.open("document2.jpg"),
    Image.open("document3.jpg")
]

model = DotsOCRModel.from_pretrained("./weights/DotsOCR")
processor = DotsOCRProcessor.from_pretrained("./weights/DotsOCR")

# 배치 전처리
inputs = processor(images=images, return_tensors="pt", padding=True)

# 배치 추론
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.1,
        do_sample=True,
        num_beams=1
    )

# 결과 디코딩
results = []
for i, output in enumerate(outputs):
    result = processor.decode(output, skip_special_tokens=True)
    results.append(result)
    print(f"Document {i+1}: {result}")

구조화된 출력

import json
from dots_ocr import DotsOCRModel, DotsOCRProcessor

# 구조화된 출력을 위한 프롬프트
def create_structured_prompt():
    return """
    Parse this document and return the result in the following JSON format:
    {
        "title": "document title",
        "sections": [
            {
                "heading": "section heading",
                "content": "section content"
            }
        ],
        "tables": [
            {
                "caption": "table caption",
                "data": [["cell1", "cell2"], ["cell3", "cell4"]]
            }
        ]
    }
    """

# 프롬프트와 함께 추론
prompt = create_structured_prompt()
inputs = processor(
    images=image, 
    text=prompt,
    return_tensors="pt"
)

outputs = model.generate(**inputs, max_new_tokens=2048)
result = processor.decode(outputs[0], skip_special_tokens=True)

# JSON 파싱 시도
try:
    parsed_result = json.loads(result)
    print(json.dumps(parsed_result, indent=2, ensure_ascii=False))
except json.JSONDecodeError:
    print("Raw output:", result)

고급 기능

다국어 처리

# 다양한 언어 지원
languages = {
    "korean": "이 문서를 한국어로 파싱해주세요.",
    "chinese": "请解析这个文档。",
    "japanese": "この文書を解析してください。",
    "english": "Parse this document.",
    "spanish": "Analiza este documento.",
    "french": "Analysez ce document."
}

for lang, prompt in languages.items():
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=2048)
    result = processor.decode(outputs[0], skip_special_tokens=True)
    print(f"{lang.capitalize()}: {result}")

테이블 전용 추출

def extract_tables_only():
    table_prompt = """
    Extract only the tables from this document. 
    For each table, provide:
    1. Table caption/title (if any)
    2. Table content in markdown format
    3. Table location in the document
    """

    inputs = processor(images=image, text=table_prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=1024)
    result = processor.decode(outputs[0], skip_special_tokens=True)
    return result

tables = extract_tables_only()
print("Extracted tables:", tables)

레이아웃 분석

def analyze_layout():
    layout_prompt = """
    Analyze the layout of this document and provide:
    1. Document type (article, report, form, etc.)
    2. Layout structure (columns, sections, headers)
    3. Reading order
    4. Visual elements (images, charts, tables)
    """

    inputs = processor(images=image, text=layout_prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=1024)
    result = processor.decode(outputs[0], skip_special_tokens=True)
    return result

layout_info = analyze_layout()
print("Layout analysis:", layout_info)

성능 최적화

GPU 메모리 최적화

import torch

# 메모리 효율적인 설정
torch.cuda.empty_cache()

model = DotsOCRModel.from_pretrained(
    "./weights/DotsOCR",
    torch_dtype=torch.float16,  # half precision
    device_map="auto"
)

# 그래디언트 체크포인팅 활성화
model.gradient_checkpointing_enable()

추론 속도 최적화

# 컴파일된 모델 사용 (PyTorch 2.0+)
if hasattr(torch, 'compile'):
    model = torch.compile(model, mode="reduce-overhead")

# 최적화된 생성 매개변수
generation_config = {
    "max_new_tokens": 2048,
    "temperature": 0.1,
    "do_sample": True,
    "num_beams": 1,
    "pad_token_id": processor.tokenizer.eos_token_id,
    "use_cache": True
}

outputs = model.generate(**inputs, **generation_config)

벤치마크 및 평가

성능 측정

import time
from PIL import Image

def benchmark_model(image_paths, num_runs=5):
    model = DotsOCRModel.from_pretrained("./weights/DotsOCR")
    processor = DotsOCRProcessor.from_pretrained("./weights/DotsOCR")

    total_time = 0
    total_images = len(image_paths) * num_runs

    for run in range(num_runs):
        for img_path in image_paths:
            image = Image.open(img_path)

            start_time = time.time()
            inputs = processor(images=image, return_tensors="pt")
            outputs = model.generate(**inputs, max_new_tokens=1024)
            result = processor.decode(outputs[0], skip_special_tokens=True)
            end_time = time.time()

            total_time += (end_time - start_time)

    avg_time = total_time / total_images
    print(f"Average processing time: {avg_time:.2f} seconds per image")
    print(f"Throughput: {1/avg_time:.2f} images per second")

# 벤치마크 실행
image_paths = ["doc1.jpg", "doc2.jpg", "doc3.jpg"]
benchmark_model(image_paths)

API 서버 설정

FastAPI 서버

from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
import io
import base64
import torch

app = FastAPI(title="DotsOCR API Server")

# 모델 로드 (전역)
model = DotsOCRModel.from_pretrained("./weights/DotsOCR")
processor = DotsOCRProcessor.from_pretrained("./weights/DotsOCR")

@app.post("/ocr")
async def ocr_endpoint(file: UploadFile = File(...)):
    try:
        # 이미지 읽기
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))

        # OCR 처리
        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=2048)

        result = processor.decode(outputs[0], skip_special_tokens=True)

        return {
            "status": "success",
            "filename": file.filename,
            "content": result
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/ocr/batch")
async def batch_ocr_endpoint(files: list[UploadFile] = File(...)):
    results = []

    try:
        images = []
        filenames = []

        for file in files:
            contents = await file.read()
            image = Image.open(io.BytesIO(contents))
            images.append(image)
            filenames.append(file.filename)

        # 배치 처리
        inputs = processor(images=images, return_tensors="pt", padding=True)

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=2048)

        for i, output in enumerate(outputs):
            result = processor.decode(output, skip_special_tokens=True)
            results.append({
                "filename": filenames[i],
                "content": result
            })

        return {
            "status": "success",
            "results": results
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

서버 실행

# FastAPI 서버 실행
python api_server.py

# 또는 uvicorn 직접 사용
uvicorn api_server:app --host 0.0.0.0 --port 8000 --reload

클라이언트 사용 예제

import requests
import json

# 단일 파일 OCR
with open("document.jpg", "rb") as f:
    files = {"file": f}
    response = requests.post("http://localhost:8000/ocr", files=files)
    result = response.json()
    print(json.dumps(result, indent=2, ensure_ascii=False))

# 배치 OCR
files_to_upload = []
for filename in ["doc1.jpg", "doc2.jpg", "doc3.jpg"]:
    with open(filename, "rb") as f:
        files_to_upload.append(("files", f))

response = requests.post("http://localhost:8000/ocr/batch", files=files_to_upload)
results = response.json()
print(json.dumps(results, indent=2, ensure_ascii=False))

문제 해결

일반적인 문제

  1. 메모리 부족 오류

    # 메모리 사용량 감소
    torch.cuda.empty_cache()
    model = model.half()  # FP16 사용
  2. 느린 추론 속도

    # 배치 크기 조정
    # 더 작은 max_new_tokens 사용
    # GPU 최적화 활성화
  3. 모델 로딩 실패

    # 모델 재다운로드
    rm -rf ./weights/DotsOCR
    python3 tools/download_model.py

디버깅 모드

import logging

# 로깅 활성화
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# 상세한 오류 정보
try:
    outputs = model.generate(**inputs)
except Exception as e:
    logger.error(f"Generation failed: {e}")
    logger.error(f"Input shape: {inputs['input_ids'].shape}")
    logger.error(f"Available GPU memory: {torch.cuda.memory_reserved()}")

모델 정보

아키텍처

  • 백본: 17억 매개변수 언어 모델
  • 비전 인코더: 고해상도 이미지 처리
  • 멀티모달 융합: 비전-언어 정렬
  • 출력 형식: 구조화된 텍스트

지원 언어

  • 영어, 중국어, 일본어, 한국어
  • 유럽 언어들 (독일어, 프랑스어, 스페인어, 이탈리아어 등)
  • 아랍어, 힌디어 등 기타 언어

지원 문서 유형

  • 학술 논문
  • 비즈니스 문서
  • 양식 및 표
  • 웹페이지
  • 책 및 잡지
  • 손글씨 문서

라이선스 및 인용

이 프로젝트는 Apache 2.0 라이선스 하에 배포됩니다.

연구에서 사용하시는 경우 다음과 같이 인용해 주세요:

@article{dots_ocr_2024,
  title={dots.ocr: Multilingual Document Layout Parsing in a Single Vision-Language Model},
  author={RedNote HiLab Team},
  journal={arXiv preprint},
  year={2024}
}

출처: https://github.com/rednote-hilab/dots.ocr

반응형
  • 네이버 블로그 공유
  • 네이버 밴드 공유
  • 페이스북 공유
  • 카카오스토리 공유