当前位置: 首页 > news >正文

大模型架构记录2 【综述-相关代码】

一 简单聊天机器人搭建 

1.1 openai调用

import os
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv

load_dotenv()
client = OpenAI()

# 打印所支持的模型
model_lst = client.models.list()

for model in model_lst:
    print (model.id)

# 调用API接口
completion = client.chat.completions.create(
    
    model="gpt-4",
    messages=[
        {"role": "system", "content": "你是一名专业的英语助教,给学生提供必要的支持如提供提示、纠正错误等,但对于学生的请求不能直接给出答案,而是要一步步引导学生完成任务。 你仅需要回复跟英语有关的问题,如与英语无关,不要回复。"},
        {"role": "user", "content": "如何学好英语?"},
        {"role": "assistant", "content": "xxxx"},
        {"role": "user", "content": "xx"}
    ],
    max_tokens = 500,
    #n=5,
    temperature=0.7,
)

print (completion)
#print(completion.choices[0].message)

1.2 ai画图

# 如果缺乏这些library, 请安装。 
# pip install openai python-dotenv
from dotenv import load_dotenv

load_dotenv()

from openai import OpenAI

client = OpenAI() # is this a new way of initializing OpenAI API?

response = client.images.generate(
  model="dall-e-3",
  prompt="在一个教室里,很多学生在上数学课,并且激烈地在跟老师讨论问题。",
  size="1024x1024",
  quality="standard",
  n=1,
)
image_url = response.data[0].url

1.3 llm-developing-mygpt

主要是简单搭建一个聊天机器人。

app.py

import gradio as gr
from loguru import logger
from MyGPT import MyGPT
from config import MODELS, DEFAULT_MODEL, MODEL_TO_MAX_TOKENS


mygpt = MyGPT()


def fn_prehandle_user_input(user_input, chat_history):
    # 检查输入
    if not user_input:
        gr.Warning("请输入您的问题")
        logger.warning("请输入您的问题")
        return chat_history

    # 用户消息在前端对话框展示
    chat_history.append([user_input, None])

    logger.info(f"\n用户输入: {user_input}, \n"
                f"历史记录: {chat_history}")
    return chat_history


def fn_predict(
        user_input,
        chat_history,
        model,
        max_tokens,
        temperature,
        stream):

    # 如果用户输入为空,则返回当前的聊天历史
    if not user_input:
        return chat_history

    # 打印日志,记录输入参数信息
    logger.info(f"\n用户输入: {user_input}, \n"
                f"历史记录: {chat_history}, \n"
                f"使用模型: {model}, \n"
                f"要生成的最大token数: {max_tokens}\n"
                f"温度: {temperature}\n"
                f"是否流式输出: {stream}")

    # 构建 messages 参数
    messages = user_input  # or [{"role": "user", "content": user_input}]
    if len(chat_history) > 1:
        messages = []
        for chat in chat_history:
            if chat[0] is not None:
                messages.append({"role": "user", "content": chat[0]})
            if chat[1] is not None:
                messages.append({"role": "assistant", "content": chat[1]})
    print(messages)

    # 生成回复
    bot_response = mygpt.get_completion(
        messages, model, max_tokens, temperature, stream)

    if stream:
        # 流式输出
        chat_history[-1][1] = ""
        for character in bot_response:
            character_content = character.choices[0].delta.content
            if character_content is not None:
                chat_history[-1][1] += character_content
                yield chat_history
    else:
        # 非流式输出
        chat_history[-1][1] = bot_response
        logger.info(f"历史记录: {chat_history}")
        yield chat_history


def fn_update_max_tokens(model, origin_set_tokens):
    """
    更新最大令牌数的函数。

    :param model: 要更新最大令牌数的模型。
    :param origin_set_tokens: 原始滑块组件设置的令牌数。
    :return: 包含新最大令牌数的滑块组件。
    """
    # 获取模型对应的新最大令牌数,如果没有设置则使用传入的最大令牌数
    new_max_tokens = MODEL_TO_MAX_TOKENS.get(model)
    new_max_tokens = new_max_tokens if new_max_tokens else origin_set_tokens

    # 如果原始设置的令牌数超过了新的最大令牌数,将其调整为默认值(这里设置为500,你可以根据需要调整)
    new_set_tokens = origin_set_tokens if origin_set_tokens <= new_max_tokens else 500

    # 创建新的最大令牌数滑块组件
    new_max_tokens_component = gr.Slider(
        minimum=0,
        maximum=new_max_tokens,
        value=new_set_tokens,
        step=1.0,
        label="max_tokens",
        interactive=True,
    )

    return new_max_tokens_component


with gr.Blocks() as demo:
    # 标题
    gr.Markdown("# MyGPT")
    with gr.Row(equal_height=True):
        # 左侧对话栏
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="聊天机器人")
            user_input_textbox = gr.Textbox(label="用户输入框", value="你好")
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear", elem_id="btn")
        # 右侧工具箱
        with gr.Column(scale=1):
            # 创建一个包含三个滑块的选项卡,用于调整模型的温度、最大长度和Top P参数
            with gr.Tab(label="参数"):
                # 选择模型
                model_dropdown = gr.Dropdown(
                    label="model",
                    choices=MODELS,
                    value=DEFAULT_MODEL,
                    multiselect=False,
                    interactive=True,
                )
                max_tokens_slider = gr.Slider(
                    minimum=0,
                    maximum=4096,
                    value=500,
                    step=1.0,
                    label="max_tokens",
                    interactive=True)
                temperature_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.01,
                    label="temperature",
                    interactive=True)
                stream_radio = gr.Radio(
                    choices=[
                        True,
                        False],
                    label="stream",
                    value=True,
                    interactive=True)

    # 模型有改动时,对应的 max_tokens_slider 滑块组件的最大值随之改动。
    # https://www.gradio.app/docs/dropdown
    model_dropdown.change(
        fn=fn_update_max_tokens,
        inputs=[model_dropdown, max_tokens_slider],
        outputs=max_tokens_slider
    )

    # 当用户在文本框处于焦点状态时按 Enter 键时,将触发此侦听器。
    # https://www.gradio.app/docs/textbox
    user_input_textbox.submit(
        fn=fn_prehandle_user_input,
        inputs=[
            user_input_textbox,
            chatbot],
        outputs=[chatbot]
    ).then(
        fn=fn_predict,
        inputs=[
            user_input_textbox,
            chatbot,
            model_dropdown,
            max_tokens_slider,
            temperature_slider,
            stream_radio],
        outputs=[chatbot]
    )

    # 单击按钮时触发。
    # https://www.gradio.app/docs/button
    submit_btn.click(
        fn=fn_prehandle_user_input,
        inputs=[
            user_input_textbox,
            chatbot],
        outputs=[chatbot]
    ).then(
        fn=fn_predict,
        inputs=[
            user_input_textbox,
            chatbot,
            model_dropdown,
            max_tokens_slider,
            temperature_slider,
            stream_radio],
        outputs=[chatbot]
    )

    clear_btn.click(lambda: None, None, chatbot, queue=False)


demo.queue().launch(share=True)

config.py


import os
from dotenv import load_dotenv, find_dotenv

# 加载环境变量
load_dotenv(find_dotenv())

# 获取 OpenAI API 密钥
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']

# 官方文档 - Models:https://platform.openai.com/docs/models
MODELS = [
    # 最新的 GPT-3.5 Turbo 模型,具有改进的指令遵循、JSON 模式、可重现输出、并行函数调用等。最多返回 4,096 个输出标记。
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo",  # 当前指向 gpt-3.5-turbo-0613 。自 2023 年 12 月 11  日开始指向gpt-3.5-turbo-1106。
    "gpt-3.5-turbo-16k",  # 当前指向 gpt-3.5-turbo-0613 。将指向gpt-3.5-turbo-1106 2023 年 12 月 11 日开始。
    # "gpt-3.5-turbo-instruct",  # 与 text-davinci-003功能类似,但兼容遗留的 Completions 端点,而不是 Chat Completions。
    # "gpt-4-1106-preview", # 最新的 GPT-4 模型,具有改进的指令跟踪、JSON 模式、可重现输出、并行函数调用等。最多返回 4,096 个输出标记。此预览模型尚不适合生产流量。
    # "gpt-4-vision-preview", # 除了所有其他 GPT-4 Turbo 功能外,还能够理解图像。最多返回 4,096 个输出标记。这是一个预览模型版本,尚不适合生产流量。
    "gpt-4",  # 当前指向 gpt-4-0613。8192 tokens
    # "gpt-4-32k",  # 当前指向 gpt-4-32k-0613。32768 tokens
    # "gpt-4-0613", # 从 2023 年 6 月 13 日开始的 gpt-4 快照,改进了函数调用支持。
    # "gpt-4-32k-0613", # 从 2023 年 6 月 13 日开始的 gpt-4-32k 快照,改进了函数调用支持。
]
DEFAULT_MODEL = MODELS[1]
MODEL_TO_MAX_TOKENS = {
    "gpt-3.5-turbo-1106": 4096,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16385,
    "gpt-4": 8192,
}

mygpt.py


import openai
from loguru import logger
from config import *


class MyGPT:
    def __init__(self, api_key=OPENAI_API_KEY):
        """
        初始化 MyGPT 类
        :param api_key: 设置 OpenAI API 密钥
        """
        self.client = openai.OpenAI(api_key=api_key)

    def get_completion(
            self,
            messages,
            model=DEFAULT_MODEL,
            max_tokens=200,
            temperature=0,
            stream=False,
    ):
        """
        Creates a model response for the given chat conversation.
        为给定的聊天对话创建模型响应。

        API官方文档:https://platform.openai.com/docs/api-reference/chat/create

        :param messages: 到目前为止,构成对话的消息列表。
        :param model: 要使用的模型的 ID。
        :param max_tokens: 聊天完成时可以生成的最大令牌数。
        :param temperature: 使用什么采样温度,介于 0 和 2 之间。
        较高的值(如 0.8)将使输出更加随机,而较低的值(如 0.2)将使其更具集中性和确定性。
        :param stream: 是否流式输出。
        :return: chat completion object(聊天完成对象),
        如果请求是流式处理的,则返回chat completion chunk(聊天完成区块对象)的流序列。
        """
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        elif not isinstance(messages, list):
            return "无效的 'messages' 类型。它应该是一个字符串或消息列表。"

        response = self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            stream=stream,
            temperature=temperature,
        )

        if stream:
            # 流式输出
            return response

        # 非流式输出
        logger.debug(response.choices[0].message.content)
        logger.info(f"总token数: {response.usage.total_tokens}")
        return response.choices[0].message.content

    def get_completion(
            self,
            messages,
            model=DEFAULT_MODEL,
            max_tokens=200,
            temperature=0,
            stream=False,
    ):
        """
        Creates a model response for the given chat conversation.
        为给定的聊天对话创建模型响应。

        API官方文档:https://platform.openai.com/docs/api-reference/chat/create

        :param messages: 到目前为止,构成对话的消息列表。
        :param model: 要使用的模型的 ID。
        :param max_tokens: 聊天完成时可以生成的最大令牌数。
        :param temperature: 使用什么采样温度,介于 0 和 2 之间。
        较高的值(如 0.8)将使输出更加随机,而较低的值(如 0.2)将使其更具集中性和确定性。
        :param stream: 是否流式输出。
        :return: chat completion object(聊天完成对象),
        如果请求是流式处理的,则返回chat completion chunk(聊天完成区块对象)的流序列。
        """
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        elif not isinstance(messages, list):
            return "无效的 'messages' 类型。它应该是一个字符串或消息列表。"

        response = self.client.chat.completions.create(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            stream=stream,
            temperature=temperature,
        )

        if stream:
            # 流式输出
            return response

        # 非流式输出
        logger.debug(response.choices[0].message.content)
        logger.info(f"总token数: {response.usage.total_tokens}")
        return response.choices[0].message.content

    def get_embedding(self, input):
        """
        Creates an embedding vector representing the input text.
        创建表示输入文本的嵌入向量。

        API官方文档:https://platform.openai.com/docs/api-reference/embeddings/create

        :param input: 输入要嵌入的文本,编码为字符串或标记数组。若要在单个请求中嵌入多个输入,请传递字符串数组或令牌数组数组。
        输入不得超过模型的最大输入标记数(8192 text-embedding-ada-002 个标记),不能为空字符串,任何数组的维数必须小于或等于 2048。
        :return: 嵌入对象的列表。
        """
        response = self.client.embeddings.create(
            input=input,
            model='text-embedding-ada-002',
        )
        embeddings = [data.embedding for data in response.data]
        return embeddings


if __name__ == "__main__":
    # 测试
    mygpt = MyGPT()

    # prompt
    prompt = '你好'
    response = mygpt.get_completion(prompt, temperature=1)
    print(response)

    # # messages
    # messages = [
    #     {'role': 'user', 'content': '什么是大模型'},
    # ]
    # response = mygpt.get_completion(messages, temperature=1)
    # print(response)

    # vectors = mygpt.get_embedding("input text")
    # print(len(vectors), len(vectors[0]))
    # # 1 1536
    #
    # vectors = mygpt.get_embedding(["input text 1", "input text 2"])
    # print(len(vectors), len(vectors[0]))
    # # 2 1536

二 构建向量数据库

2.1 文本拆分-英文

Split by Sentence
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

text = ("The Earth's atmosphere is composed of layers, including the troposphere, "
        "stratosphere, mesosphere, thermosphere, and exosphere. The troposphere is "
        "the lowest layer where all weather takes place and contains 75% of the atmosphere's mass. "
        "Above this, the stratosphere contains the ozone layer, which protects the Earth "
        "from harmful ultraviolet radiation.")

# Split the text into sentences
chunks = sent_tokenize(text)

for i, chunk in enumerate(chunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

Fixed length chunks
def fixed_length_chunks(text, chunk_size):
    return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

chunks = fixed_length_chunks(text, 100)  # 假设我们想要100个字符的块

for i, chunk in enumerate(chunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

Chunks with overlapping window
def sliding_window_chunks(text, chunk_size, stride):
    return [text[i:i+chunk_size] for i in range(0, len(text), stride)]

chunks = sliding_window_chunks(text, 100, 50)  # 100个字符的块,步长为50

for i, chunk in enumerate(chunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

RecursiveCharacterTextSplitter from langchain
# pip install langchain may be required
from langchain.text_splitter import RecursiveCharacterTextSplitter

text = "The Earth's atmosphere is a layer of gases surrounding the planet Earth and retained by Earth's gravity. It contains roughly 78% nitrogen and 21% oxygen, with trace amounts of other gases. The atmosphere protects life on Earth by absorbing ultraviolet solar radiation and reducing temperature extremes between day and night."

splitter = RecursiveCharacterTextSplitter(
    chunk_size = 150,
    chunk_overlap = 20,
    length_function = len,
)
trunks = splitter.split_text(text)
for i, chunk in enumerate(trunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

2.2 文本拆分-中文

# 按照sentence来切分
#  %pip install jieba
import re

text = "在这里,我们有一段超过200字的中文文本作为输入例子。这段文本是关于自然语言处理的简介。自然语言处理(NLP)是计算机科学、人工智能和语言学的交叉领域,它旨在让计算机能够理解和处理人类语言。在这一领域中,机器学习技术扮演着核心角色。通过使用各种算法,计算机可以解析、理解、甚至生成人类可以理解的语言。这一技术已广泛应用于机器翻译、情感分析、自动摘要、实体识别等多个方面。随着深度学习技术的发展,自然语言处理的准确性和效率都得到了显著提升。当前,一些高级的NLP系统已经能够完成复杂的语言理解任务,例如问答系统、语音识别和对话系统等。自然语言处理的研究不仅有助于改善人机交互,而且对于提高机器的自主性和智能化水平也具有重要意义。"
# 正则表达式匹配中文句子结束的标点符号
sentences = re.split(r'(。|?|!|\…\…)', text)
# 重新组合句子和结尾的标点符号
chunks = [sentence + (punctuation if punctuation else '') for sentence, punctuation in zip(sentences[::2], sentences[1::2])]
for i, chunk in enumerate(chunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

# 按照固定字符数切分
def split_by_fixed_char_count(text, count):
    return [text[i:i+count] for i in range(0, len(text), count)]

# 假设我们按照每100个字符来切分文本
chunks = split_by_fixed_char_count(text, 100)
for i, chunk in enumerate(chunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

# 按照固定sentence数切分

def split_by_fixed_sentence_count(sentences, count):
    return [sentences[i:i+count] for i in range(0, len(sentences), count)]

# 假设我们按照每5个句子来切分文本
chunks = split_by_fixed_sentence_count(sentences, 5)

for i, chunk in enumerate(chunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

from langchain.text_splitter import RecursiveCharacterTextSplitter

text = """
在这里,我们有一段超过200字的中文文本作为输入例子。这段文本是关于自然语言处理的简介。自然语言处理(NLP)是计算机科学、人工智能和语言学的交叉领域,它旨在让计算机能够理解和处理人类语言。在这一领域中,机器学习技术扮演着核心角色。通过使用各种算法,计算机可以解析、理解、甚至生成人类可以理解的语言。这一技术已广泛应用于机器翻译、情感分析、自动摘要、实体识别等多个方面。随着深度学习技术的发展,自然语言处理的准确性和效率都得到了显著提升。当前,一些高级的NLP系统已经能够完成复杂的语言理解任务,例如问答系统、语音识别和对话系统等。自然语言处理的研究不仅有助于改善人机交互,而且对于提高机器的自主性和智能化水平也具有重要意义。
"""

splitter = RecursiveCharacterTextSplitter(
    chunk_size = 150,
    chunk_overlap = 0,
    length_function = len,
)

trunks = splitter.split_text(text)
for i, chunk in enumerate(trunks):
    print(f"块 {i+1}: {len(chunk)}: {chunk}")

2.3 文本向量化

vector_embedding_and_similarity

from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())  # 读取本地 .env 文件,里面定义了 OPENAI_API_KEY
client = OpenAI()

def get_embedding(text, model="text-embedding-ada-002"):
   text = text.replace("\n", " ")
   return client.embeddings.create(input = [text], model=model).data[0].embedding
import numpy as np

def cosine_similarity(A, B):
    dot_product = np.dot(A, B)
    norm_A = np.linalg.norm(A)
    norm_B = np.linalg.norm(B)
    return dot_product / (norm_A * norm_B)
emb1 = get_embedding("大模型的应用场景很多")
emb2 = get_embedding("大模型")
emb3 = get_embedding("大模型有很多应用场景")
emb4 = get_embedding("Java开发")

cosine_similarity(emb1, emb2)   # 0.9228796051583866
cosine_similarity(emb1, emb4)   # 0.796571635600008

2.4 读取文件写入向量数据库


import json
import os
from ast import literal_eval

from utils import *
from db_qdrant import Qdrant


def preprocess_data(df_news):
    # 数据预处理块

    # 将包含字符串表示的列表转换为实际的列表
    # pd.notna(x) 检查x是否为非缺失值(即不是NaN),确保不对缺失值进行转换。
    # literal_eval(x) 是一个安全的方式来将字符串转换为相应的Python对象
    df_news['title_entities'] = df_news['title_entities'].apply(
        lambda x: literal_eval(x) if pd.notna(x) else [])
    df_news['abstract_entities'] = df_news['abstract_entities'].apply(
        lambda x: literal_eval(x) if pd.notna(x) else [])

    # 使用空字符串填充其他列的 NaN 值
    df_news = df_news.fillna('')

    # 新增 news_info 列,合并`类别、子类别、标题、摘要`字符串
    concatenation_order = ["category", "sub_category", "title", "abstract"]
    df_news['news_info'] = df_news.apply(lambda row: ' '.join(
        f"[{col}]:{row[col]}" for col in concatenation_order), axis=1)
    news_info_list = df_news['news_info'].values.tolist()
    logger.trace(
        f"新增 news_info 列 | len(news_info_list): {len(news_info_list)}")
    return df_news, news_info_list


def store_embeddings_to_json(embeddings, ids, payloads, file_path):
    # 存储嵌入为 JSON 文件
    json_data = {
        "batch_ids": ids,
        "batch_embeddings": embeddings,
        "batch_payloads": payloads
    }
    with open(file_path, 'w') as json_file:
        json.dump(json_data, json_file)


def compute_and_store_embeddings(data_list, embedding_folder, batch_size=1000):
    # 嵌入计算和存储块

    # 分批次向量化
    ids = list(range(1, len(data_list) + 1))  # 生成递增的 ids 列表

    for batch_idx, i in enumerate(range(0, len(data_list), batch_size)):
        # 获取批次数据 batch_ids、batch_payloads
        batch_ids = ids[i:i + batch_size]
        df_news_batch = df_news.iloc[i:i + batch_size]
        batch_payloads = df_news_batch.to_dict(orient='records')

        # 计算嵌入 batch_embeddings
        batch_data = data_list[i:i + batch_size]
        batch_embeddings = embed_sentences(batch_data)

        # 存储为 JSON 文件
        file_path = os.path.join(
            embedding_folder,
            f"batch_{batch_idx + 1}.json")
        store_embeddings_to_json(
            batch_embeddings,
            batch_ids,
            batch_payloads,
            file_path)

        # 打印存储信息
        logger.info(f"批次 {batch_idx} 数据存储成功,文件路径: {file_path}")


def load_embeddings_and_save_to_qdrant(
        collection_name,
        embedding_folder,
        batch_size):
    # 加载嵌入和存储到向量数据库

    qdrant = Qdrant()

    # 创建新的集合
    if qdrant.create_collection(collection_name):
        logger.success(f"创建集合成功 | collection_name: {collection_name}")
    else:
        logger.error(f"创建集合失败 | collection_name: {collection_name}")

    # 分批次存储到向量数据库
    for batch_idx, i in enumerate(range(0, len(news_info_list), batch_size)):
        # 读取 JSON 文件
        file_path = os.path.join(
            embedding_folder,
            f"batch_{batch_idx + 1}.json")
        if os.path.exists(file_path):
            with open(file_path, 'r') as json_file:
                json_data = json.load(json_file)

                batch_ids = json_data["batch_ids"]
                batch_embeddings = json_data["batch_embeddings"]
                batch_payloads = json_data["batch_payloads"]

                # 插入数据到 Qdrant
                if qdrant.add_points(
                        collection_name,
                        batch_ids,
                        batch_embeddings,
                        batch_payloads):
                    logger.success(f"批次 {batch_idx + 1} 插入成功")
                else:
                    logger.error(f"批次 {batch_idx + 1} 插入失败")
        else:
            logger.warning(f"文件 {file_path} 不存在,跳过该批次数据的插入。")

    logger.info("所有数据插入完成。")


# 读取新闻数据
df_news = get_df_news()

# 数据预处理
df_news, news_info_list = preprocess_data(df_news)

# 指定存储 embeddings 的文件夹路径
embedding_folder = 'embeddings_folder'
os.makedirs(embedding_folder, exist_ok=True)

# 计算和存储嵌入
compute_and_store_embeddings(news_info_list, embedding_folder, 1000)

# 加载嵌入和存储到向量数据库
collection_name = "all_news"
# load_embeddings_and_save_to_qdrant(collection_name, embedding_folder, 1000)

2.5 构建向量数据库类


from loguru import logger
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, Batch
from qdrant_client.http.exceptions import UnexpectedResponse  # 捕获错误信息

from config import QDRANT_HOST, QDRANT_PORT, QDRANT_EMBEDDING_DIMS


class Qdrant:
    def __init__(self):
        self.client = QdrantClient(
            host=QDRANT_HOST,
            port=QDRANT_PORT)  # 创建客户端实例
        self.size = QDRANT_EMBEDDING_DIMS  # openai embedding 维度 = 1536

    def get_points_count(self, collection_name):
        """
        先检查集合是否存在。
        - 如果集合存在,返回该集合的 points_count (集合中确切的points_count)
        - 如果集合不存在,创建集合。
            - 创建集合成功,则返回 points_count (0: 刚创建完points_count就是0)
            - 创建集合失败,则返回 points_count (-1: 创建失败了,定义points_count为-1)

        Returns:
            points_count

        Raises:
            UnexpectedResponse: 如果在获取集合信息时发生意外的响应。
            ValueError: Collection test_collection not found
        """
        try:
            collection_info = self.get_collection(collection_name)
        except (UnexpectedResponse, ValueError) as e:  # 集合不存在,创建新的集合
            if self.create_collection(collection_name):
                logger.success(
                    f"创建集合成功 | collection_name: {collection_name} points_count: 0")
                return 0
            else:
                logger.error(
                    f"创建集合失败 | collection_name: {collection_name} 错误信息:{e}")
                return -1
        except Exception as e:
            logger.error(
                f"获取集合信息时发生错误 | collection_name: {collection_name} 错误信息:{e}")
            return -1  # 返回错误码或其他适当的值
        else:
            points_count = collection_info.points_count
            logger.success(
                f"库里已有该集合 | collection_name: {collection_name} points_count:{points_count}")
            return points_count

    def list_all_collection_names(self):
        """
        CollectionsResponse类型举例:
        CollectionsResponse(collections=[
            CollectionDescription(name='GreedyAIEmployeeHandbook'),
            CollectionDescription(name='python')
        ])
        CollectionsResponse(collections=[])
        """
        CollectionsResponse = self.client.get_collections()
        collection_names = [
            CollectionDescription.name for CollectionDescription in CollectionsResponse.collections]
        return collection_names

    # 获取集合信息
    def get_collection(self, collection_name):
        """
        获取集合信息。

        Args:
            collection_name (str, optional): 自定义的集合名称。如果未提供,则使用默认的self.collection_name。

        Returns:
            collection_info: 集合信息。
        """
        collection_info = self.client.get_collection(
            collection_name=collection_name)
        return collection_info

    # 创建集合
    def create_collection(self, collection_name) -> bool:
        """
        创建集合。

        Args:
            collection_name (str, optional): 自定义的集合名称。如果未提供,则使用默认的self.collection_name。

        Returns:
            bool: 如果成功创建集合,则返回True;否则返回False。
        """

        return self.client.recreate_collection(
            collection_name=collection_name, vectors_config=VectorParams(
                size=self.size, distance=Distance.COSINE), )

    def add_points(self, collection_name, ids, vectors, payloads):
        # 将数据点添加到Qdrant
        self.client.upsert(
            collection_name=collection_name,
            wait=True,
            points=Batch(
                ids=ids,
                payloads=payloads,
                vectors=vectors
            )
        )
        return True

    # 搜索
    def search(self, collection_name, query_vector, limit=3):
        return self.client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            with_payload=True
        )

    def search_with_query_filter(
            self,
            collection_name,
            query_vector,
            query_filter,
            limit=3):
        """
        根据向量相似度和指定的过滤条件,在集合中搜索最相似的points。
        API 文档:https://qdrant.github.io/qdrant/redoc/index.html#tag/points/operation/search_points
        :param collection_name:要搜索的集合名称
        :param query_vector:用于相似性比较的向量
        :param query_filter:过滤条件
        :param limit:要返回的结果的最大数量
        :return:
        """
        return self.client.search(
            collection_name=collection_name,
            query_filter=query_filter,
            query_vector=query_vector,
            limit=limit,
            with_payload=True
        )


if __name__ == "__main__":
    qdrant = Qdrant()

    # 创建集合
    # collection_name = "test"

    # 获取集合信息
    # qdrant.get_collection(collection_name)
    # 如果之前没有创建集合,则会报以下错误
    # qdrant_client.http.exceptions.UnexpectedResponse: Unexpected Response: 404 (Not Found)
    # Raw response content:
    # b'{"status":{"error":"Not found: Collection `test` doesn\'t exist!"},"time":0.000198585}'

    # 获取集合信息,如果没有该集合则创建
    collection_name = "all_news"
    count = qdrant.get_points_count(collection_name)
    print(count)
    # 如果之前没有创建集合,且正确创建了该集合,则输出0。例:创建集合成功。集合名:test。节点数量:0。
    # 如果之前创建了该集合,则输出该集合内部的节点数量。例:库里已有该集合。集合名:test。节点数量:0。

    # 删除集合
    # collection_name = "test"
    # qdrant.client.delete_collection(collection_name)

2.6 计算向量相似度

from sentence_transformers import SentenceTransformer


texts1 = ["胡子长得太快怎么办?", "在香港哪里买手表好"]
texts2 = ["胡子长得快怎么办?", "怎样使胡子不浓密!", "香港买手表哪里好", "在杭州手机到哪里买"]

model = SentenceTransformer('../Dmeta-embedding')  # DMetaSoul/Dmeta-embedding
print("Use pytorch device: {}".format(model.device))

# 使用 GPU 加载模型
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print("Use pytorch device: {}".format(device))
# model = SentenceTransformer('../Dmeta-embedding', device=device)

embs1 = model.encode(texts1, normalize_embeddings=True)
embs2 = model.encode(texts2, normalize_embeddings=True)

# 计算两两相似度
similarity = embs1 @ embs2.T
print(similarity)

# 获取 texts1[i] 对应的最相似 texts2[j]
for i in range(len(texts1)):
    scores = []
    for j in range(len(texts2)):
        scores.append([texts2[j], similarity[i][j]])
    scores = sorted(scores, key=lambda x: x[1], reverse=True)

    print(f"查询文本:{texts1[i]}")
    for text2, score in scores:
        print(f"相似文本:{text2},打分:{score}")
    print()

2.7 推荐系统案例


from utils import *
from db_qdrant import Qdrant
from qdrant_client.http import models

# 获取数据
df_behaviors_sample = get_df_behaviors_sample()
df_news = get_df_news()

# 循环 df_behaviors_sample 的每一行
for _, row in df_behaviors_sample.iterrows():
    user_id = row['user_id']
    click_history = row['click_history'].split()

    # 召回

    # 生成历史交互字符串 historical_records_str
    historical_records = generate_historical_records(df_news, click_history)
    historical_records_str = '\n'.join(historical_records)
    logger.info(
        f"历史交互字符串 | historical_records_str: \n{historical_records_str}")

    # 生成用户画像 user_profile
    user_profile = generate_user_profile(historical_records_str)

    # 向量化用户画像 user_emb
    user_emb = embed_sentences([user_profile])[0]

    # 过滤条件 query_filter
    # 统计出当前用户的(新闻类别,新闻子类别)偏好组合
    user_favorite_combinations = get_user_favorite_combinations(
        click_history, df_news)

    should_items = []
    for category, sub_category in user_favorite_combinations:
        should_item = models.Filter(
            must=[
                models.FieldCondition(
                    key="category",
                    match=models.MatchValue(
                        value=category,
                    )
                ),
                models.FieldCondition(
                    key="sub_category",
                    match=models.MatchValue(
                        value=sub_category,
                    )
                )
            ]
        )

        should_items.append(should_item)

    query_filter = models.Filter(
        should=should_items
    )

    # 使用 Qdrant 查询与用户画像字符串最相似的 news 列表
    qdrant = Qdrant()
    scored_point_list = qdrant.search_with_query_filter(
        "all_news", user_emb, query_filter, 20)
    coarse_top_news = [
        scored_point.payload for scored_point in scored_point_list]
    logger.info(f"len(top_news): {len(coarse_top_news)}")

    if coarse_top_news:
        # 排序
        coarse_top_news_str = '\n'.join(
            [f"{idx}. {news}" for idx, news in enumerate(coarse_top_news)])
        fine_top_news = fine_ranking(
            user_profile,
            historical_records_str,
            coarse_top_news_str,
            5)

        for idx in fine_top_news:
            news = coarse_top_news[int(idx)]
            logger.success(int(idx))
            logger.success(news)
    break

结合向量数据库内容构建prompt

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import openai

# 初始化模型和向量数据库
model = SentenceTransformer('all-MiniLM-L6-v2')  # 轻量级嵌入模型
index = faiss.read_index("path/to/your_vector_index.faiss")  # 加载预存索引
documents = ["doc1 text...", "doc2 text...", ...]  # 假设已加载原始文本

def retrieve_context(query: str, top_k: int = 3) -> str:
    """检索相关上下文"""
    query_embedding = model.encode([query])  # 编码为向量
    distances, indices = index.search(query_embedding, top_k)  # FAISS 搜索
    return "\n".join([documents[i] for i in indices[0]])

def generate_answer(query: str, context: str) -> str:
    """调用大模型生成回答"""
    prompt = f"""根据以下上下文回答问题:
{context}

问题:{query}
答案:"""
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=500,
        temperature=0.3
    )
    return response.choices[0].text.strip()

# 用户输入处理
user_query = "如何学习机器学习?"
context = retrieve_context(user_query)
answer = generate_answer(user_query, context)
print(answer)

相关文章:

  • 数据服务赋能数据治理:从“One Service”到QuickAPI的演进
  • redis操作
  • 【搜索】dfs(回溯、剪枝、记忆化)
  • 【C++】类和对象(二)默认成员函数之构造函数、析构函数
  • Springboot集成Debezium监听postgresql变更
  • CQL学习
  • 游戏引擎学习第177天
  • 996引擎-接口测试:背包
  • pnpm 报错 Error: Cannot find matching keyid 解决
  • Mybatis的基础操作——03
  • 西交建筑学本科秋天毕业想转码,自学了Python+408,华为OD社招还是考研更香?
  • 第十四章:模板实例化_《C++ Templates》notes
  • 如何编写SLURM系统的GRES资源插件
  • Lustre 语言的 Rust 生成相关的工作
  • Autosar OS配置-Timing Protection配置及实现--基于ETAS工具
  • 题单:精挑细选
  • 生物化学笔记:医学免疫学原理02 抗原概念+免疫应答+抗原的分类
  • SQL语言——MySQL
  • MuJoCo 仿真 Panda 机械臂!末端位置实时追踪 + 可视化(含缩放交互)
  • 系统架构书单推荐(一)领域驱动设计与面向对象
  • 西班牙葡萄牙突发全国大停电,欧洲近年来最严重停电事故何以酿成
  • 中方会否公布铁线礁的领海基线?外交部:中方执法活动旨在反制菲方侵权挑衅
  • 国家卫健委:工作相关肌肉骨骼疾病、精神和行为障碍成职业健康新挑战
  • 程璧“自由生长”,刘卓辉“被旋律牵着走”
  • 中国海警局新闻发言人就菲律宾非法登临铁线礁发表谈话
  • 当代视角全新演绎,《风雪夜归人》重归首都剧场