A Look at RedisChatMemory in Spring AI Alibaba
Preface
This article takes a look at RedisChatMemory in Spring AI Alibaba.
RedisChatMemory
community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/RedisChatMemory.java
package com.alibaba.cloud.ai.memory.redis;

import java.util.ArrayList;
import java.util.List;

import com.alibaba.cloud.ai.memory.redis.serializer.MessageDeserializer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;

public class RedisChatMemory implements ChatMemory, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);

    private static final String DEFAULT_KEY_PREFIX = "spring_ai_alibaba_chat_memory";

    private static final String DEFAULT_HOST = "127.0.0.1";

    private static final int DEFAULT_PORT = 6379;

    private static final String DEFAULT_PASSWORD = null;

    private final JedisPool jedisPool;

    // One connection borrowed in the constructor and shared by every method
    // (a Jedis instance is not thread-safe)
    private final Jedis jedis;

    private final ObjectMapper objectMapper;

    public RedisChatMemory() {
        this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);
    }

    public RedisChatMemory(String host, int port, String password) {
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        // 2000 ms timeout
        this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);
        this.jedis = jedisPool.getResource();
        this.objectMapper = new ObjectMapper();
        // Message is an interface, so Jackson needs a custom deserializer to rebuild it
        SimpleModule module = new SimpleModule();
        module.addDeserializer(Message.class, new MessageDeserializer());
        this.objectMapper.registerModule(module);
        logger.info("Connected to Redis at {}:{}", host, port);
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        // Prefix and id are concatenated with no separator
        String key = DEFAULT_KEY_PREFIX + conversationId;
        for (Message message : messages) {
            try {
                String messageJson = objectMapper.writeValueAsString(message);
                // One RPUSH per message, appended to the tail of the list
                jedis.rpush(key, messageJson);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error serializing message", e);
            }
        }
        logger.info("Added messages to conversationId: {}", conversationId);
    }

    @Override
    public List<Message> get(String conversationId, int lastN) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        // Negative indexes count from the tail: this selects the last lastN entries
        List<String> messageStrings = jedis.lrange(key, -lastN, -1);
        List<Message> messages = new ArrayList<>();
        for (String messageString : messageStrings) {
            try {
                Message message = objectMapper.readValue(messageString, Message.class);
                messages.add(message);
            }
            catch (JsonProcessingException e) {
                throw new RuntimeException("Error deserializing message", e);
            }
        }
        logger.info("Retrieved {} messages for conversationId: {}", messages.size(), conversationId);
        return messages;
    }

    @Override
    public void clear(String conversationId) {
        String key = DEFAULT_KEY_PREFIX + conversationId;
        jedis.del(key);
        logger.info("Cleared messages for conversationId: {}", conversationId);
    }

    @Override
    public void close() {
        // For a pooled Jedis, close() hands the connection back to the pool
        if (jedis != null) {
            jedis.close();
            logger.info("Redis connection closed.");
        }
        if (jedisPool != null) {
            jedisPool.close();
            logger.info("Jedis pool closed.");
        }
    }

    public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        try {
            String key = DEFAULT_KEY_PREFIX + conversationId;
            List<String> all = jedis.lrange(key, 0, -1);
            // Once the list holds at least maxLimit entries, drop the oldest deleteSize
            if (all.size() >= maxLimit) {
                all = all.stream().skip(Math.max(0, deleteSize)).toList();
            }
            // DEL, then RPUSH the survivors back one by one
            this.clear(conversationId);
            for (String message : all) {
                jedis.rpush(key, message);
            }
        }
        catch (Exception e) {
            logger.error("Error clearing messages from Redis chat memory", e);
            throw new RuntimeException(e);
        }
    }

    public void updateMessageById(String conversationId, String messages) {
        // Note: this key has a ':' separator, unlike the key used by add/get/clear
        String key = "spring_ai_alibaba_chat_memory:" + conversationId;
        try {
            this.jedis.del(key);
            this.jedis.rpush(key, new String[] { messages });
        }
        catch (Exception var6) {
            logger.error("Error updating messages from Redis chat memory", var6);
            throw new RuntimeException(var6);
        }
    }

}
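The RedisChatMemory constructor creates a JedisPool, borrows a single Jedis connection from it, and registers a MessageDeserializer for the org.springframework.ai.chat.messages.Message type with the ObjectMapper. The add method serializes each message to JSON and RPUSHes it onto the list at spring_ai_alibaba_chat_memory{conversationId} (prefix and id are concatenated with no separator). The get method uses LRANGE to fetch the last N entries and deserializes them back into Message objects. The clear method simply deletes the key, and close closes the jedis connection first and then the jedisPool.
Beyond the ChatMemory interface, clearOverLimit reads the whole list, drops the oldest deleteSize entries once the list holds at least maxLimit, and writes the remainder back, while updateMessageById replaces the list with a single string. Note that updateMessageById builds its key as spring_ai_alibaba_chat_memory:{conversationId}, with a colon, so it does not operate on the same key as add, get, and clear.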
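As a rough way to check the list semantics described above without a Redis server, here is a minimal sketch that assumes only the JDK. ListMemorySketch and its rpush, lrange, del, and key methods are hypothetical names, not Jedis or Spring AI APIs; the class mimics Redis's LRANGE index rules and replays the key construction, get, and clearOverLimit logic of RedisChatMemory in memory.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical in-memory stand-in for the Redis lists used by RedisChatMemory.
// This is NOT Jedis: it only mimics RPUSH / LRANGE / DEL so that the behaviour
// of get and clearOverLimit can be checked without a Redis server.
public class ListMemorySketch {

    // Same prefix as RedisChatMemory, concatenated with no separator (as in add/get/clear)
    static final String PREFIX = "spring_ai_alibaba_chat_memory";

    private final Map<String, List<String>> store = new HashMap<>();

    static String key(String conversationId) {
        return PREFIX + conversationId;
    }

    // RPUSH: append to the tail, creating the list if it does not exist yet
    void rpush(String key, String value) {
        store.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
    }

    // LRANGE with Redis index rules: negative indexes count from the tail and
    // out-of-range indexes are clamped instead of raising an error
    List<String> lrange(String key, long start, long stop) {
        List<String> list = store.getOrDefault(key, List.of());
        long len = list.size();
        if (start < 0) start = len + start;
        if (stop < 0) stop = len + stop;
        if (start < 0) start = 0;
        if (start > stop || start >= len) return List.of();
        if (stop >= len) stop = len - 1;
        return new ArrayList<>(list.subList((int) start, (int) stop + 1));
    }

    // DEL: drop the whole list
    void del(String key) {
        store.remove(key);
    }

    // Mirrors get(conversationId, lastN): LRANGE key -lastN -1
    List<String> get(String conversationId, int lastN) {
        return lrange(key(conversationId), -lastN, -1);
    }

    // Mirrors clearOverLimit: read everything, drop the oldest deleteSize entries
    // once size >= maxLimit, then DEL the key and RPUSH the survivors back
    void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        String key = key(conversationId);
        List<String> all = lrange(key, 0, -1);
        if (all.size() >= maxLimit) {
            all = all.stream().skip(Math.max(0, deleteSize)).toList();
        }
        del(key);
        for (String m : all) {
            rpush(key, m);
        }
    }

    public static void main(String[] args) {
        ListMemorySketch mem = new ListMemorySketch();
        for (int i = 1; i <= 5; i++) {
            mem.rpush(key("c1"), "m" + i);
        }
        System.out.println(mem.get("c1", 2));  // last two entries
        System.out.println(mem.get("c1", 0));  // -0 == 0, so this is LRANGE key 0 -1
        mem.clearOverLimit("c1", 5, 2);
        System.out.println(mem.get("c1", 10)); // start index clamped to 0
    }
}
```

Two things fall out of this model. Because -0 equals 0, get(conversationId, 0) turns into LRANGE key 0 -1 and returns the whole conversation rather than nothing. And since clearOverLimit reads the list, DELs the key, and then issues one RPUSH per survivor, it is not atomic: a message added by another caller between the LRANGE and the DEL is silently lost.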
MessageDeserializer
community/memories/spring-ai-alibaba-redis-memory/src/main/java/com/alibaba/cloud/ai/memory/redis/serializer/MessageDeserializer.java
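package com.alibaba.cloud.ai.memory.redis.serializer;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
// Package used by the Spring AI 1.0.0 milestones; later versions moved Media
import org.springframework.ai.model.Media;

public class MessageDeserializer extends JsonDeserializer<Message> {

    private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class);

    @Override
    public Message deserialize(JsonParser p, DeserializationContext ctxt) {
        ObjectMapper mapper = (ObjectMapper) p.getCodec();
        JsonNode node = null;
        Message message = null;
        try {
            node = mapper.readTree(p);
            // The type discriminator stored with each serialized message
            String messageType = node.get("messageType").asText();
            switch (messageType) {
                case "USER" -> message = new UserMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {}),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {}));
                case "ASSISTANT" -> message = new AssistantMessage(node.get("text").asText(),
                        mapper.convertValue(node.get("metadata"), new TypeReference<Map<String, Object>>() {}),
                        (List<AssistantMessage.ToolCall>) mapper.convertValue(node.get("toolCalls"),
                                new TypeReference<Collection<AssistantMessage.ToolCall>>() {}),
                        (List<Media>) mapper.convertValue(node.get("media"), new TypeReference<Collection<Media>>() {}));
                // Any other type (SYSTEM, TOOL, ...) is rejected
                default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
            }
        }
        catch (IOException e) {
            // Parse errors are only logged; the method then returns null
            logger.error("Error deserializing message", e);
        }
        return message;
    }

}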
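MessageDeserializer extends JsonDeserializer. It reads the messageType field and creates a UserMessage for USER and an AssistantMessage for ASSISTANT. Any other type, such as SYSTEM or TOOL, hits the default branch and throws IllegalArgumentException, while an IOException during parsing is only logged and the method returns null.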
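The type dispatch can be illustrated without Jackson or Spring AI on the classpath. In this hypothetical sketch a Map stands in for the parsed JsonNode, and the records UserMsg and AssistantMsg stand in for UserMessage and AssistantMessage; only the shape of the switch mirrors MessageDeserializer.

```java
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the messageType dispatch used by MessageDeserializer.
// A Map stands in for the parsed JsonNode; the records stand in for Spring AI's
// UserMessage and AssistantMessage. None of these are the real classes.
public class MessageDispatchSketch {

    interface Msg {}

    record UserMsg(String text) implements Msg {}

    record AssistantMsg(String text) implements Msg {}

    // Read the discriminator and build the matching type, like the switch in deserialize()
    static Msg fromNode(Map<String, Object> node) {
        String messageType = String.valueOf(node.get("messageType"));
        return switch (messageType) {
            case "USER" -> new UserMsg((String) node.get("text"));
            case "ASSISTANT" -> new AssistantMsg((String) node.get("text"));
            // SYSTEM and TOOL land here, as in the original switch
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }

    public static void main(String[] args) {
        Map<String, Object> user = Map.of("messageType", "USER", "text", "hi");
        Map<String, Object> assistant = Map.of("messageType", "ASSISTANT", "text", "hello");
        Map<String, Object> system = Map.of("messageType", "SYSTEM", "text", "be brief");
        for (Map<String, Object> node : List.of(user, assistant, system)) {
            try {
                System.out.println(fromNode(node));
            }
            catch (IllegalArgumentException e) {
                System.out.println("rejected: " + e.getMessage());
            }
        }
    }
}
```

A practical consequence is that the failure is deferred: a SystemMessage or ToolResponseMessage passed to add would presumably still be serialized and written to Redis, and the error only surfaces later, when get tries to read it back.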
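Summary
spring-ai-alibaba-redis-memory provides a Redis implementation of ChatMemory: through Jedis it appends messages with RPUSH, fetches the last N with LRANGE, and deletes a conversation's messages with DEL.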