Langchain+RAG+向量数据库
加载数据
import os

import lancedb
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import BaichuanTextEmbeddings
from langchain_community.vectorstores import LanceDB
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Secrets must not live in source control. The Baichuan key is expected in the
# BAICHUAN_API_KEY environment variable (the embeddings object is constructed
# without an explicit key further down), so fail early with a clear message
# if it has not been exported.
if not os.environ.get('BAICHUAN_API_KEY'):
    raise RuntimeError(
        'Set the BAICHUAN_API_KEY environment variable before running this script.'
    )

# Load the raw text file as a list of LangChain documents.
loader = TextLoader('state_of_the_union.txt', encoding='utf8')
documents = loader.load()
切块
# Split the loaded documents into small chunks for embedding.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,  # maximum chunk size, measured with length_function
    chunk_overlap=0,  # number of characters shared between adjacent chunks
    length_function=len,
    is_separator_regex=False,
    # Candidate separators in listed order: paragraph and line breaks first,
    # then ASCII and full-width (CJK) sentence punctuation, commas, and
    # finally a single space.
    separators=["\n\n", "\n", ".", "?", "!", "。", "!", "?", ",", ",", " "],
)
docs = text_splitter.split_documents(documents)
print('=======', len(docs))
编码
# Baichuan embeddings are a paid API; a free embedding model can be downloaded
# from HuggingFace instead.
embeddings = BaichuanTextEmbeddings()
构建向量数据库
LanceDB:
os.path.join(os.getcwd(), 'lanceDB'):在当前工作目录下的 lanceDB 子目录中存储数据库文件。
# Connect to the vector database; vectors are stored in the local directory
# ./lanceDB under the current working directory.
connect = lancedb.connect(os.path.join(os.getcwd(), 'lanceDB'))

# Embed the chunks and write them into the 'my_vectors' table.
vectorStore = LanceDB.from_documents(
    docs,
    embeddings,
    connection=connect,
    table_name='my_vectors',
)

query = '今年长三角铁路春游运输共经历多少天?'

# Optional check of the vector store on its own (uncomment to run):
# docs_and_score = vectorStore.similarity_search_with_score(query)
# for doc, score in docs_and_score:
#     print('-------------------------')
#     print('Score: ', score)
#     print("Content: ", doc.page_content)
LLM整合
Prompt → Model → Parser → Retriever(将 Question 和检索结果结合)→ Chain
retriever = vectorStore.as_retriever()

# Prompt that restricts the model to the retrieved context.
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Create the chat model against the OpenAI-compatible endpoint at base_url.
# The API key is read from the ZHIPUAI_API_KEY environment variable instead of
# being hard-coded; a missing variable raises KeyError here.
model = ChatOpenAI(
    model='glm-4-0520',
    api_key=os.environ['ZHIPUAI_API_KEY'],
    base_url='https://open.bigmodel.cn/api/paas/v4/',
)
output_parser = StrOutputParser()

# Feed the user's question to the retriever (filling 'context') and pass it
# through unchanged (filling 'question'), producing the prompt's inputs.
start_retriever = RunnableParallel(
    {'context': retriever, 'question': RunnablePassthrough()}
)

# Full pipeline: retrieve -> build prompt -> call model -> parse to string.
chain = start_retriever | prompt | model | output_parser

res = chain.invoke(query)
print(res)