RAG 기반 Gemma 기술문서 QA 챗봇 (RAG, Gemma 7B)¶
- 목표: RAG와 Gemma를 활용한 Gemma 기술문서 QA 챗봇을 개발합니다.
1. 환경 설정 및 데이터 로드¶
1-1. 필수 라이브러리 설치¶
In [ ]:
!pip install transformers sentence-transformers langchain openai chromadb bs4 accelerate langchain_community pypdf text_generation
1-2. Hugging Face 토큰 등록¶
In [ ]:
import os
from google.colab import userdata
os.environ['HUGGINGFACEHUB_API_TOKEN'] = userdata.get('HUGGINGFACEHUB_API_TOKEN')
1-3. 데이터 로드¶
- 데이터 유형별로 로드 방식이 다르며, 해당 데이터는 RecursiveUrlLoader 사용
- Gemma 공식 기술문서
- robots.txt에서 크롤링 가능 확인
In [ ]:
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
url = "https://ai.google.dev/gemma/docs?hl=en"
loader = RecursiveUrlLoader(url=url, max_depth=2)
docs = loader.load()
2. 문서 변환 및 Vector DB 저장¶
2-1. 문서 변환 – 문서 분할 및 청킹¶
- 긴 문서를 모델과 호환되고 정확하고 명확한 결과를 생성하는 작은 덩어리로 분할하는 작업
- RecursiveCharacterTextSplitter 사용
In [ ]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_overlap=200, add_start_index=True
)
splits = text_splitter.split_documents(docs)
2-2. Vector DB 저장¶
- 텍스트 청크를 추출한 후 RAG 애플리케이션을 사용하여 향후 검색을 위해 이를 저장하고 색인화함
- 일반적인 접근 방식은 각 분할의 콘텐츠를 임베딩하고 이러한 임베딩을 벡터 저장소에 저장하는 것
- 다국어 모델인 sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 사용
In [ ]:
# Import the HuggingFaceEmbeddings class from the langchain module
from langchain.embeddings import HuggingFaceEmbeddings
# Define the path to the pre-trained model you want to use
# modelPath = "sentence-transformers/all-MiniLM-l6-v2"
modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Create a dictionary with model configuration options, specifying to use the CPU for computations
model_kwargs = {'device':'cpu'}
# Create a dictionary with encoding options, specifically setting 'normalize_embeddings' to False
encode_kwargs = {'normalize_embeddings': False}
# Initialize an instance of HuggingFaceEmbeddings with the specified parameters
embeddings = HuggingFaceEmbeddings(
model_name=modelPath, # Provide the pre-trained model's path
model_kwargs=model_kwargs, # Pass the model configuration options
encode_kwargs=encode_kwargs # Pass the encoding options
)
- Chroma 벡터 저장소와 Lang chain의 오픈 소스 "HuggingFaceEmbeddings"를 사용하면 단일 명령으로 모든 문서 분할을 삽입하고 저장
In [ ]:
# Import the Chroma class from the langchain.vectorstores module
from langchain.vectorstores import Chroma
# Create a Chroma vector store from the previously split documents and using the specified embeddings
vectorstore = Chroma.from_documents(
documents=splits, # Split documents for vectorization
embedding=embeddings # Embeddings instance for encoding
)
- 검색할 때 검색 쿼리를 포함하고 유사성 검색을 수행하여 쿼리 포함과 가장 유사한 포함으로 저장된 분할을 식별합니다.
- 임베딩 간의 각도를 측정하는 코사인 유사성은 간단한 유사성 측정입니다.
In [ ]:
question = "What is Gemma?"
searchDocs = vectorstore.similarity_search(question)
print(searchDocs[0].page_content)
<p>This tutorial shows you how to get started with Gemma using <a href="https://keras.io/keras_nlp/">KerasNLP</a>. Gemma is a family of lightweight, state-of-the art open models built from the same research and technology used to create the Gemini models. KerasNLP is a collection of natural language processing (NLP) models implemented in <a href="https://keras.io/">Keras</a> and runnable on JAX, PyTorch, and TensorFlow.</p>
<p>In this tutorial, you'll use Gemma to generate text responses to several prompts. If you're new to Keras, you might want to read <a href="https://keras.io/getting_started/">Getting started with Keras</a> before you begin, but you don't have to. You'll learn more about Keras as you work through this tutorial.</p>
<h2 id="setup" data-text="Setup" tabindex="-1">Setup</h2>
<h3 id="gemma_setup" data-text="Gemma setup" tabindex="-1">Gemma setup</h3>
2-3. 텍스트 청크 검색기(retriever) 생성¶
- 데이터를 저장하고, LLM 모델을 준비하고, 파이프라인을 구축한 후에는 데이터를 검색해야 합니다.
- 검색기는 쿼리를 기반으로 문서를 반환하는 인터페이스 역할을 합니다.
In [ ]:
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
3. RAG 및 Gemma 7B 기반 답변 생성¶
3-1. LLM 로드¶
- HuggingFaceTextGenInference로 HuggingFace의 Gemma 7B 모델을 사용
In [ ]:
from langchain_community.llms import HuggingFaceTextGenInference
ENDPOINT_URL = "https://api-inference.huggingface.co/models/google/gemma-7b"
HF_TOKEN = userdata.get('HUGGINGFACEHUB_API_TOKEN')
llm = HuggingFaceTextGenInference(
inference_server_url=ENDPOINT_URL,
max_new_tokens=1024,
top_k=50,
temperature=0.1,
repetition_penalty=1.03,
server_kwargs={
"headers": {
"Authorization": f"Bearer {HF_TOKEN}",
"Content-Type": "application/json",
}
},
)
3-2. 프롬프트 템플릿 생성¶
- 검색기로 벡터 저장소에서 관련 문서를 검색하여 context 값으로 생성
In [ ]:
from langchain.prompts import PromptTemplate
question = "Difference between Gemma 7B and 2B model"
# 템플릿
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
3-3. QA 체인 생성¶
- 프롬프트를 사용하여 사용자 쿼리와 함께 llm에 전달할 RetrievalQA 체인 생성
In [ ]:
# Import the RetrievalQA class from the langchain module
from langchain.chains import RetrievalQA
# Create a RetrievalQA instance with specified components
chain = RetrievalQA.from_chain_type(
llm=llm, # Provide the language model
chain_type="stuff", # Specify the type of the language model chain
retriever=retriever, # Provide the document retriever
return_source_documents=True, # Returning source-documents with answers
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
)
4. QA 결과 및 문서 레퍼런스 확인¶
4-1. QA 결과¶
In [ ]:
# Execute the query using the RetrievalQA chain and store the result
result = chain ({ "query" : question })
# Print or use the formatted result text
print(result['result'])
Yes No
Answer: The main difference is the size (and thus the complexity) in the language models. Smaller ones like Gemm...
4-2. 문서 레퍼런스¶
In [ ]:
documents=result['source_documents']
for document in documents:
# Extracting content from the 4 document chunks used for specific query
page_content = document.page_content
metadata = document.metadata
# Now you can use page_content and metadata as needed
print("Page Content:", page_content)
print("Source:", metadata['source'])
print("Start Index:", metadata['start_index'])
print("\n")
In [ ]:
'IT > 인공지능' 카테고리의 다른 글
[생성형AI][LLM] vLLM: LLM 추론 및 배포 최적화 라이브러리 (1) | 2024.03.12 |
---|---|
[GPU] RAPIDS: 대규모 데이터 세트 분석을 위한 GPU 가속 프레임워크 (0) | 2024.02.27 |
[생성형AI][LLM] Gemma 모델 파인튜닝 (Hugging Face) (3) | 2024.02.24 |
[생성형AI][Text2Video] Sora: 콘텐츠 제작의 미래를 선도하는 비디오 생성 모델 (0) | 2024.02.20 |
[생성형AI][RAG] 증상 기반 법정감염병 판별 챗봇 (0) | 2024.02.09 |