지난번에 Triton Inference Server에서 HuggingFace 모델 서빙하기를 포스팅 했었는데,
당시에는 HuggingFace에 한국어로 파인튜닝된 LLaMa3 모델을 서빙했었다.

이번에는 벤치마크 LogicKor, Horangi 리더보드에서 높은 순위를 기록하고 있는 ko-gemma-2-9b-it를 서빙하기 위한 과정을 기록해보았다.

TensorRT-LLM Backend의 기본 예제만 따르면 쉽게 서빙할 수 있는 LLaMa3와는 달리, Gemma2는 비교적 최신 모델이라 TensorRT-LLM 버전, TensorRT-LLM Backend 버전, Triton Inferenece Server Container의 버전을 모두 신경써줘야 했다.
TensorRT-LLM v0.13.0부터 Gemma2를 지원한다는 정보를 토대로, 이 포스트에선 다음을 사용한다.

  • Triton Inference Server Container 24.09 (TensorRT-LLM v0.13.0의 dependent TensorRT version인 10.4.0이 설치되어있음)
  • TensorRT-LLM Backend v0.13.0

이후 버전을 사용해도 무방할 것으로 보인다.
(TensorRT-LLM Backend v0.13.0과, TensorRT-LLM v0.15.0이 설치된 Triton Inference Server Container 24.11을 사용해서 테스트 했을 때 동일하게 동작함)

서빙 환경

  • OS: Ubuntu 20.04
  • GPU: NVIDIA RTX4090 * 2
  • GPU Driver Version: 550.127.05

Update the TensorRT-LLM submodule

git clone -b v0.13.0 https://github.com/triton-inference-server/tensorrtllm_backend.git
cd tensorrtllm_backend
git submodule update --init --recursive
git lfs install
git lfs pull

Launch Triton TensorRT-LLM container

docker run -it --net host --shm-size=2g \
    --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
    -v ~/tensorrtllm_backend:/tensorrtllm_backend \
    -v ~/engines:/engines \
    nvcr.io/nvidia/tritonserver:24.09-trtllm-python-py3

Prepare TensorRT-LLM engines

Download weights from HuggingFace Transformers

cd /tensorrtllm_backend/tensorrt_llm
pip install huggingface-hub
huggingface-cli login
huggingface-cli download "rtzr/ko-gemma-2-9b-it" --local-dir "ko-gemma-2-9b-it"

Convert weights from HF Tranformers to TensorRT-LLM checkpoint

‘world_size’ is the number of GPUs used to build the TensorRT-LLM engine.

CKPT_PATH=ko-gemma-2-9b-it/
UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_ko-gemma-2-9b-it_tensorrt_llm/bf16/tp2/
ENGINE_PATH=/engines/gemma2/9b/bf16/2-gpu/
VOCAB_FILE_PATH=ko-gemma-2-9b-it/tokenizer.model

python3 ./examples/gemma/convert_checkpoint.py \
    --ckpt-type hf \
    --model-dir ${CKPT_PATH} \
    --dtype bfloat16 \
    --world-size 2 \
    --output-model-dir ${UNIFIED_CKPT_PATH}

Build TensorRT engines

trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 3000 \
             --max_seq_len 3100 \
             --output_dir ${ENGINE_PATH}

Prepare the Model Repository

rm -rf /triton_model_repo
mkdir /triton_model_repo
cp -r /tensorrtllm_backend/all_models/inflight_batcher_llm/* /triton_model_repo/

Modify the Model Configuration

ENGINE_DIR=/engines/gemma2/9b/bf16/2-gpu/
TOKENIZER_DIR=/tensorrtllm_backend/tensorrt_llm/ko-gemma-2-9b-it/
MODEL_FOLDER=/triton_model_repo
TRITON_MAX_BATCH_SIZE=4
INSTANCE_COUNT=1
MAX_QUEUE_DELAY_MS=0
MAX_QUEUE_SIZE=0
FILL_TEMPLATE_SCRIPT=/tensorrtllm_backend/tools/fill_template.py
DECOUPLED_MODE=false

python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${INSTANCE_COUNT}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},engine_dir:${ENGINE_DIR},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},batching_strategy:inflight_fused_batching,max_queue_size:${MAX_QUEUE_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${INSTANCE_COUNT},max_queue_size:${MAX_QUEUE_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT}${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT}

Serving with Triton

‘world_size’ is the number of GPUs you want to use for serving. This should be aligned with the number of GPUs used to build the TensorRT-LLM engine.

python3 /tensorrtllm_backend/scripts/launch_triton_server.py --world_size=2 --model_repo=/triton_model_repo

To Stop Triton Server insider the container

pkill tritonserver

Send an Inference Request

curl -X POST http://localhost:8000/v2/models/ensemble/generate -d '{"text_input": "안녕?", "max_tokens": 100, "bad_words": "", "stop_words": ""}'

Output Example

{"batch_index":0,"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":0.0,"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"안녕? \n\n저는 한국어를 배우는 중인 AI입니다. \n\n오늘은 한국어로 대화를 나누고 싶어요. \n\n어떤 주제로 이야기해볼까요? \n\n😊\n"}

참고 링크

GitHub triton-inference-server/tensorrtllm_backend 저장소의 튜토리얼과 NVIDIA/TensorRT-LLM 저장소의 LLaMa 예제를 참고하여
HuggingFace의 MLP-KTLim/llama-3-Korean-Bllossom-8B 모델의 엔진을 생성하고 서빙하는 과정을 정리한 글입니다.

Update the TensorRT-LLM submodule

git clone -b v0.11.0 <https://github.com/triton-inference-server/tensorrtllm_backend.git>
cd tensorrtllm_backend
git submodule update –init –recursive
git lfs install
git lfs pull
cd ..

Launch Triton TensorRT-LLM container

docker run -it --net host --shm-size=2g \\
	--ulimit memlock=-1 --ulimit stack=67108864 --gpus all \\
	-v $(pwd)/tensorrtllm_backend:/tensorrtllm_backend \\
	-v $(pwd)/engines:/engines \\
	nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3

Prepare TensorRT-LLM engines

cd /tensorrllm_backend/tensorrt_llm

Download weights from HuggingFace Transformers

pip install huggingface-hub
huggingface-cli login
huggingface-cli download "MLP-KTLim/llama-3-Korean-Bllossom-8B" --local-dir "llama-3-Korean-Bllossom-8B"

Build LLaMA v3 8B TP=1 using HF checkpoints directly.

cd /tensorrtllm_backend/tensorrt_llm/examples/llama

Convert weights from HF Tranformers to TensorRT-LLM checkpoint

python3 convert_checkpoint.py --model_dir /tensorrtllm_backend/tensorrt_llm/llama-3-Korean-Bllossom-8B \\
	--output_dir ./tllm_checkpoint_1gpu_tp1 \\
	--dtype float16 \\
  --tp_size 1

Build TensorRT engines

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_tp1 \\
	--output_dir ./tmp/llama/8B/trt_engines/fp16/1-gpu/ \\
	--gemm_plugin auto

Prepare the Model Repository

rm -rf /triton_model_repo
mkdir /triton_model_repo
cp -r /tensorrtllm_backend/all_models/inflight_batcher_llm/* /triton_model_repo/

Modify the Model Configuration

ENGINE_DIR=/tensorrtllm_backend/tensorrt_llm/examples/llama/tmp/llama/8B/trt_engines/fp16/1-gpu
TOKENIZER_DIR=/tensorrtllm_backend/tensorrt_llm/llama-3-Korean-Bllossom-8B
MODEL_FOLDER=/triton_model_repo
TRITON_MAX_BATCH_SIZE=4
INSTANCE_COUNT=1
MAX_QUEUE_DELAY_MS=0
MAX_QUEUE_SIZE=0
FILL_TEMPLATE_SCRIPT=/tensorrtllm_backend/tools/fill_template.py
DECOUPLED_MODE=false

python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/ensemble/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/preprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},preprocessing_instance_count:${INSTANCE_COUNT}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},engine_dir:${ENGINE_DIR},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MS},batching_strategy:inflight_fused_batching,max_queue_size:${MAX_QUEUE_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/postprocessing/config.pbtxt tokenizer_dir:${TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},postprocessing_instance_count:${INSTANCE_COUNT},max_queue_size:${MAX_QUEUE_SIZE}
python3 ${FILL_TEMPLATE_SCRIPT} -i ${MODEL_FOLDER}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},bls_instance_count:${INSTANCE_COUNT}

Serving with Triton

‘world_size’ is the number of GPUs you want to use for serving. This should be aligned with the number of GPUs used to build the TensorRT-LLM engine.

python3 /tensorrtllm_backend/scripts/launch_triton_server.py --world_size=1 --model_repo=/triton_model_repo

To stop Triton Server inside the container

**pkill tritonserver**

Send an Inference Request

curl -X POST <http://localhost:8000/v2/models/ensemble/generate> -d '{"text_input": "한강 작가를 알고 있니?", "max_tokens": 100, "bad_words": "", "stop_words": ""}'

Output Example

{"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":[0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0],"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":"한강 작가를 알고 있니? 한강 작가는 한국의 대표적인 현대 소설가 중 한 명으로, '채식주의', '적도의 남자', '황무지' 등 많은 작품을 발표한 작가다. 그녀의 작품은 주로 인간의 본성, 사회적 규범, 그리고 개인의 자유와 억압을 탐구하는 내용을 담고 있다. 특히 '채식주의'는 한강 작가의 대표작으로, 주인공이 채식을 시작하면서 "}

참고 링크

Overview

  • ML모델 서빙을 위한 오픈소스 플랫폼으로, AI모델의 대규모 배포를 간단하게 하기 위해 설계됨
  • 훈련된 AI모델이 효율적으로 추론할 수 있도록 함
  • 다양한 H/W(NVIDIA GPU, CPU 등)와 프레임워크(TensorFlow, PyTorch, ONNX 등) 지원

Key Features

  • 다양한 프레임워크 지원
    • TensorFlow, PyTorch, ONNX, TensorRT 등 다양한 ML/DL프레임워크를 지원
    • 각기 다른 환경에서 훈련된 모델을 변환 없이 배포 가능
  • 동적 배칭 및 모델 최적화
    • 다수의 추론 요청을 자동으로 배칭하여 지연을 줄이고 처리량을 최적화
    • TensorRT와 같은 모델 최적화 기술 제공 (TensorRT: NVIDIA GPU에서 추론을 가속화함)
  • 확장성
    • 다수의 GPU나 서버에 걸쳐 확장을 제공하여, 높은 처리량의 워크로드를 처리함
    • 데이터센터등 큰 규모의 배포에 적합
  • 배포 용이성
    • 모델은 모델 저장소에 저장되며, 최소한의 구성(configuration)으로 모델을 서빙할 수 있음
    • 새로운 모델을 업데이트하거나 배포하기 쉬움
  • 평가지표 및 모니터링
    • 성능, 처리량, 메모리 사용에 대한 상세한 평가지표 제공
    • Prometheus 또는 Grafana를 통해 접근 가능
  • 커스텀 백엔드 지원
    • Triton의 기능을 확장하기 위한 커스텀 백엔드를 만들 수 있음
    • 특별한 요구사항 충족 가능

Architecture

https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#triton-architecture

  • Model Repository는 Triton에서 추론 가능한 모델들의 저장소 (파일시스템 기반)
  • 서버에 도착한 추론 요청은 HTTP/REST, gRPC 또는 C API를 통해 적절한 Per-Model Scheduler로 라우팅 됨
  • 모델별로 설정 가능한 Multiple Scheduling과 Batching Algorithm이 제공됨
  • 각 모델의 스케줄러는 선택적으로 추론 요청의 배칭을 실행하고, 모델 타입에 맞는 백엔드에 요청을 전달함
  • 백엔드는 배칭된 요청에서 제공된 입력을 사용하여 추론을 실행하여 요청된 결과를 생성함
  • Triton은 백엔드 C API를 제공하여 커스텀 전/후처리, 새로운 딥러닝 프레임워크 등 새로운 기능으로 확장 가능함
  • HTTP/REST, gRPC 또는 C API를 통해 사용 가능한 Model Management API으로 Triton이 서브하는 모델에 쿼리를 보내고 컨트롤을 할 수 있음

Ensemble Model vs. BLS(Business Logic Scripting)

  • Ensemble Model (🔗)
    • 모델간의 입출력 텐서를 연결하는 파이프라인
    • 여러 모델을 포함하는 과정을 캡슐화하기 위해 사용됨
      (Data Preprocessing → Inferene → data Postprocessing)
    • 중간 텐서들의 전송 오버헤드를 피하고, Triton에 전송될 요청 갯수를 최소화함
  • BLS(Business Logic Scripting) (🔗)
    • 커스텀 로직과 모델 실행의 결합
    • 반복문, 조건문, 데이터 의존적인 제어흐름과 커스텀로직을 모델 파이프라인에 포함하여, 모델 실행과 결합할 수 있음
  • Examples (🔗)
    • Ensemble 방식
      • 전처리, 추론, 후처리가 각각의 모델로 구성
        이미지 입력
        → 얼굴 인식을 위한 전처리 (Python Backend)
        → 얼굴 인식 추론 (ONNX)
        → 후처리 (Python Backend)
        → 결과 출력
    • BLS 방식
      • 전처리, 추론, 후처리가 하나의 파일로 처리되며, 인식 결과가 bls 파일에서 요청되고 얻어짐
        이미지 입력
        → Business Logic Script (Python Backend): 전처리 → 얼굴 인식 추론 → 후처리
        → 결과 출력

+ Recent posts