用Function Calling让GPT查询数据库(含示例)
在本文中,我们通过一个简单的示例,介绍了 GPT模型结合Function Calling技术查询数据库 的基本流程。
Function Calling 是OpenAI推出的一项功能,允许大模型根据用户提问,自动生成函数调用指令,并由程序端实际执行外部操作(如数据库查询、API调用),再将结果返回给模型,最终组织成自然语言回复用户。
主要实现步骤包括:
1、定义数据库结构:描述当前可查询的数据库表和字段。
2、注册可调用的函数:告诉GPT模型有哪些可以用的函数,例如ask_database函数,用于执行SQL查询。
3、发送用户提问:将用户的问题封装为消息,提交给GPT模型。
4、模型生成调用指令:GPT识别问题,自动生成对应SQL语句,并以Function Calling的形式返回。
5、实际执行查询:程序接收指令,调用本地数据库执行查询,拿到真实数据。
6、将查询结果反馈给模型:GPT根据查询结果,组织自然语言进行回复。
这种模式大大拓展了大模型的应用场景,让AI不仅能理解问题,还能结合真实世界数据,完成更复杂、实时的任务处理。
import sqlite3
import json
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())# 初始化OpenAI客户端
client = OpenAI()# JSON 打印工具
# 更安全的 JSON 打印工具
def print_json(data):def default(o):if hasattr(o, "model_dump"):return o.model_dump()elif hasattr(o, "__dict__"):return o.__dict__else:return str(o)print(json.dumps(data, indent=2, ensure_ascii=False, default=default))# 描述数据库表结构
database_schema_string = """
CREATE TABLE orders (id INT PRIMARY KEY NOT NULL, -- 主键,不允许为空customer_id INT NOT NULL, -- 客户ID,不允许为空product_id STR NOT NULL, -- 产品ID,不允许为空price DECIMAL(10,2) NOT NULL, -- 价格,不允许为空status INT NOT NULL, -- 订单状态,整数类型,不允许为空。0代表待支付,1代表已支付,2代表已退款create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 创建时间,默认为当前时间pay_time TIMESTAMP -- 支付时间,可以为空
);
"""def get_sql_completion(messages, model="gpt-4o-mini"):response = client.chat.completions.create(model=model,messages=messages,temperature=0,tools=[{ # 摘自 OpenAI 官方示例 https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb"type": "function","function": {"name": "ask_database","description": "Use this function to answer user questions about business. \Output should be a fully formed SQL query.","parameters": {"type": "object","properties": {"query": {"type": "string","description": f"""SQL query extracting info to answer the user's question.SQL should be written using this database schema:{database_schema_string}The query should be returned in plain text, not in JSON.The query should only contain grammars supported by SQLite.""",}},"required": ["query"],}}}],)return response.choices[0].messageimport sqlite3# 创建数据库连接
conn = sqlite3.connect(':memory:')
cursor = conn.cursor()# 创建orders表
cursor.execute(database_schema_string)# 插入5条明确的模拟记录
mock_data = [(1, 1001, 'TSHIRT_1', 50.00, 0, '2023-09-12 10:00:00', None),(2, 1001, 'TSHIRT_2', 75.50, 1, '2023-09-16 11:00:00', '2023-08-16 12:00:00'),(3, 1002, 'SHOES_X2', 25.25, 2, '2023-10-17 12:30:00', '2023-08-17 13:00:00'),(4, 1003, 'SHOES_X2', 25.25, 1, '2023-10-17 12:30:00', '2023-08-17 13:00:00'),(5, 1003, 'HAT_Z112', 60.75, 1, '2023-10-20 14:00:00', '2023-08-20 15:00:00'),(6, 1002, 'WATCH_X001', 90.00, 0, '2023-10-28 16:00:00', None)
]for record in mock_data:cursor.execute('''INSERT INTO orders (id, customer_id, product_id, price, status, create_time, pay_time)VALUES (?, ?, ?, ?, ?, ?, ?)''', record)# 提交事务
conn.commit()def ask_database(query):cursor.execute(query)records = cursor.fetchall()return recordsprompt = "10月的销售额"
# prompt = "统计每月每件商品的销售额"
# prompt = "哪个用户消费最高?消费多少?"messages = [{"role": "system", "content": "你是一个数据分析师,基于数据库的数据回答问题"},{"role": "user", "content": prompt}
]
response = get_sql_completion(messages)
if response.content is None:response.content = ""
messages.append(response)
print("====Function Calling====")
print_json(response)if response.tool_calls is not None:tool_call = response.tool_calls[0]if tool_call.function.name == "ask_database":arguments = tool_call.function.argumentsargs = json.loads(arguments)print("====SQL====")print(args["query"])result = ask_database(args["query"])print("====DB Records====")print(result)messages.append({"tool_call_id": tool_call.id,"role": "tool","name": "ask_database","content": str(result)})response = get_sql_completion(messages)messages.append(response)print("====最终回复====")print(response.content)print("=====对话历史=====")
print_json(messages)
在这一部分示例中,我们演示了如何通过GPT的Function Calling技术,实现跨多张表的联合查询。
与简单查询单张表不同,多表查询通常涉及外键关联、多表联动分析,需要生成更复杂的SQL语句(比如 JOIN、聚合统计等)。
通过在系统消息中描述好数据库的多表结构,并注册可以生成SQL的 ask_database 函数,大模型能够根据自然语言提问,自动推理出正确的跨表SQL查询语句,极大提升了查询的智能化程度。
import sqlite3
import json
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())# 初始化OpenAI客户端
client = OpenAI()# JSON 打印工具
# 更安全的 JSON 打印工具
def print_json(data):def default(o):if hasattr(o, "model_dump"):return o.model_dump()elif hasattr(o, "__dict__"):return o.__dict__else:return str(o)print(json.dumps(data, indent=2, ensure_ascii=False, default=default))# 描述数据库表结构
database_schema_string = """
CREATE TABLE customers (id INT PRIMARY KEY NOT NULL, -- 主键,不允许为空customer_name VARCHAR(255) NOT NULL, -- 客户名,不允许为空email VARCHAR(255) UNIQUE, -- 邮箱,唯一register_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- 注册时间,默认为当前时间
);
CREATE TABLE products (id INT PRIMARY KEY NOT NULL, -- 主键,不允许为空product_name VARCHAR(255) NOT NULL, -- 产品名称,不允许为空price DECIMAL(10,2) NOT NULL -- 价格,不允许为空
);
CREATE TABLE orders (id INT PRIMARY KEY NOT NULL, -- 主键,不允许为空customer_id INT NOT NULL, -- 客户ID,不允许为空product_id INT NOT NULL, -- 产品ID,不允许为空price DECIMAL(10,2) NOT NULL, -- 价格,不允许为空status INT NOT NULL, -- 订单状态,整数类型,不允许为空。0代表待支付,1代表已支付,2代表已退款create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- 创建时间,默认为当前时间pay_time TIMESTAMP -- 支付时间,可以为空
);
"""def get_sql_completion(messages, model="gpt-4o-mini"):response = client.chat.completions.create(model=model,messages=messages,temperature=0,tools=[{ # 摘自 OpenAI 官方示例 https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb"type": "function","function": {"name": "ask_database","description": "Use this function to answer user questions about business. \Output should be a fully formed SQL query.","parameters": {"type": "object","properties": {"query": {"type": "string","description": f"""SQL query extracting info to answer the user's question.SQL should be written using this database schema:{database_schema_string}The query should be returned in plain text, not in JSON.The query should only contain grammars supported by SQLite.""",}},"required": ["query"],}}}],)return response.choices[0].messageprompt = "统计每月每件商品的销售额"
prompt = "这星期消费最高的用户是谁?他买了哪些商品? 每件商品买了几件?花费多少?"
messages = [{"role": "system", "content": "你是一个数据分析师,基于数据库中的表回答用户问题"},{"role": "user", "content": prompt}
]
response = get_sql_completion(messages)
sql = json.loads(response.tool_calls[0].function.arguments)["query"]
print(sql)