Java 使用 LangChain4j 搭建大模型的 RAG 教程
一、引言
随着大语言模型(LLM)的兴起,其在各个领域的应用越来越广泛。然而,预训练模型的知识存在时效性问题,并且对于公司内部的私有数据,出于安全和商业利益考虑,不能直接使用通用的模型。因此,检索增强生成(RAG)技术应运而生。RAG 结合了信息检索和大模型生成的优势,能够在生成答案之前,从知识库中检索相关文档片段,从而生成更准确和信息丰富的文本。本文将详细介绍如何使用 Java 和 LangChain4j 框架搭建一个基于 RAG 的大模型应用。
二、基本概念
(一)什么是 RAG
RAG 的核心思想是将传统的信息检索(IR)技术与现代的生成式大模型结合起来。其工作原理可以分为以下几个步骤:
接收请求:系统接收到用户的查询。
信息检索(R):从知识库中检索出与查询最相关的文档片段。
生成增强(A):将检索到的文档片段与原始查询一起输入到大模型中。
输出生成(G):大模型生成最终的文本答案。
(二)LangChain4j 简介
LangChain4j 是 LangChain 的 Java 版本,旨在封装与 LLM 对接的细节,简化开发流程,提升基于 LLM 开发的效率。
三、环境准备
(一)安装 JDK 21
Oracle 官方网站,找到 JDK 21 的下载页面。
一直点下一步安装完成后执行
java-version
(二)安装 pgvector 不同环境参考链接安装即可
四、项目搭建
创建一个SpringBoot项目
在 pom.xml 文件中添加 LangChain4j 相关依赖:
<dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-jdbc</artifactId></dependency><!-- langchain4j dependencies start --><dependency><groupId>dev.langchain4j</groupId><artifactId>langchain4j</artifactId><version>${langchain.version}</version></dependency><dependency><groupId>dev.langchain4j</groupId><artifactId>langchain4j-community-dashscope-spring-boot-starter</artifactId></dependency><dependency><groupId>dev.langchain4j</groupId><artifactId>langchain4j-ollama-spring-boot-starter</artifactId><version>${langchain.version}</version></dependency><dependency><groupId>dev.langchain4j</groupId><artifactId>langchain4j-web-search-engine-searchapi</artifactId><version>${langchain.version}</version></dependency><dependency><groupId>dev.langchain4j</groupId><artifactId>langchain4j-easy-rag</artifactId><version>${langchain.version}</version></dependency><dependency><groupId>dev.langchain4j</groupId><artifactId>langchain4j-pgvector</artifactId><version>${langchain.version}</version></dependency>
依赖版本仅供参考,可根据实际开发需求修改
<maven.compiler.source>21</maven.compiler.source><maven.compiler.target>21</maven.compiler.target><project.build.sourceEncoding>UTF-8</project.build.sourceEncoding><lombok.version>1.18.30</lombok.version><junit.version>5.11.4</junit.version><log4j2.version>2.24.3</log4j2.version><springboot.version>3.3.2</springboot.version><postgresql.version>42.3.8</postgresql.version><mybatis-plus.version>3.5.8</mybatis-plus.version><oapi-sdk>2.4.8</oapi-sdk><caffeine.version>3.1.8</caffeine.version><httpclient.version>5.4.1</httpclient.version><springai.version>1.0.0-SNAPSHOT</springai.version><langchain.version>1.0.0-beta1</langchain.version>
初始化Assistant,自定义embeddingStore和embeddingModel
@Configuration
@RequiredArgsConstructor
public class AssistantInit {final ChatLanguageModel chatLanguageModel;@Beanpublic Assistant init(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel) {ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder().embeddingStore(embeddingStore).embeddingModel(embeddingModel).maxResults(1).build();return AiServices.builder(Assistant.class).chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(10)).contentRetriever(contentRetriever)
// .tools(new HighLevelCalculator(), new WebSearchTool(engine)).chatLanguageModel(chatLanguageModel).build();}}
初始化pgvector数据库
@Configuration
@RequiredArgsConstructor
public class EmbeddingStoreInit {final PgConfig pgConfig;@Beanpublic EmbeddingStore<TextSegment> initEmbeddingStore() {return PgVectorEmbeddingStore.builder().table(pgConfig.getTable()).dropTableFirst(true).createTable(true).host(pgConfig.getHost()).port(pgConfig.getPort()).user(pgConfig.getUser()).dimension(1024)//向量模型的向量维度,根据模型支持的维度填写.password(pgConfig.getPassword()).database(pgConfig.getDatabase()).build();}
}
yml配置参数读取
@Configuration
@ConfigurationProperties(prefix = "pgvector")
@Data
public class PgConfig {private String host;private int port;private String database;private String user;private String password;private String table;}
API接口设计
@RequestMapping("/rag")
@RequiredArgsConstructor
@RestController
public class RagAPI {final EmbeddingStore<TextSegment> embeddingStore;final EmbeddingModel embeddingModel;final Assistant assistant;@GetMapping("/chat")public String chat(@RequestParam(value = "message") String message) {return assistant.chat(message);}@GetMapping("/load")public String load(@RequestParam(value = "maxSegmentSizeInChars",required = false,defaultValue = "300") int maxSegmentSizeInChars, @RequestParam(value = "maxOverlapSizeInChars",required = false ,defaultValue = "10") int maxOverlapSizeInChars) {List<Document> documents = FileSystemDocumentLoader.loadDocuments("文件夹路径");EmbeddingStoreIngestor.builder().embeddingStore(embeddingStore).embeddingModel(embeddingModel).documentSplitter(new DocumentByLineSplitter(maxSegmentSizeInChars, maxOverlapSizeInChars)).build().ingest(documents);return "数据加载成功";}}