From d0f827e9df1c3ab37683a273c9aae5bb3bd823c0 Mon Sep 17 00:00:00 2001 From: ulleo Date: Tue, 7 Apr 2026 19:05:28 +0800 Subject: [PATCH] feat: improve generate sql quality --- backend/apps/chat/models/chat_model.py | 68 +++++++++++++++---- backend/apps/chat/task/llm.py | 94 ++++++++++++++++++++++---- backend/templates/template.yaml | 40 +++++++---- frontend/src/entity/supplier.ts | 4 ++ 4 files changed, 164 insertions(+), 42 deletions(-) diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index dfa2ea01..14ab59fd 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -1,8 +1,9 @@ from datetime import datetime from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Any, Union from fastapi import Body +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from pydantic import BaseModel from sqlalchemy import Column, Integer, Text, BigInteger, DateTime, Identity, Boolean from sqlalchemy import Enum as SQLAlchemyEnum @@ -230,6 +231,7 @@ class AiModelQuestion(BaseModel): regenerate_record_id: Optional[int] = None def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = True): + templates: dict[str, str] = {} _sql_template = get_sql_example_template(db_type) _base_template = get_sql_template() _process_check = _sql_template.get('process_check') if _sql_template.get('process_check') else _base_template[ @@ -245,22 +247,37 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T 'example_answer_2'] _example_answer_3 = _sql_template['example_answer_3_with_limit'] if enable_query_limit else _sql_template[ 'example_answer_3'] - return _base_template['system'].format(engine=self.engine, schema=self.db_schema, question=self.question, - lang=self.lang, terminologies=self.terminologies, - data_training=self.data_training, custom_prompt=self.custom_prompt, - process_check=_process_check, - base_sql_rules=_base_sql_rules, - basic_sql_examples=_sql_examples, - example_engine=_example_engine, - example_answer_1=_example_answer_1, - example_answer_2=_example_answer_2, - example_answer_3=_example_answer_3) + + templates['system'] = _base_template['system'].format(process_check=_process_check) + templates['rules'] = _base_template['generate_rules'].format(lang=self.lang, + base_sql_rules=_base_sql_rules, + basic_sql_examples=_sql_examples, + example_engine=_example_engine, + example_answer_1=_example_answer_1, + example_answer_2=_example_answer_2, + example_answer_3=_example_answer_3) + templates['schema'] = _base_template['generate_basic_info'].format(engine=self.engine, schema=self.db_schema) + + if self.terminologies: + templates['terminologies'] = _base_template['generate_terminologies_info'].format( + terminologies=self.terminologies) + + if self.data_training: + templates['data_training'] = _base_template['generate_data_training_info'].format( + data_training=self.data_training) + + if self.custom_prompt: + templates['custom_prompt'] = _base_template['generate_custom_prompt_info'].format( + custom_prompt=self.custom_prompt) + + return templates def sql_user_question(self, current_time: str, change_title: bool): _question = self.question if self.regenerate_record_id: _question = get_sql_template()['regenerate_hint'] + self.question - return get_sql_template()['user'].format(engine=self.engine, schema=self.db_schema, question=_question, + return get_sql_template()['user'].format(lang=self.lang, engine=self.engine, schema=self.db_schema, + question=_question, rule=self.rule, current_time=current_time, error_msg=self.error_msg, change_title=change_title) @@ -358,3 +375,30 @@ class McpAssistant(BaseModel): url: str = Body(description='第三方数据接口') authorization: str = Body(description='第三方接口凭证') stream: Optional[bool] = Body(description='是否流式输出,默认为true开启, 关闭false则返回JSON对象', default=True) + + +class SystemPromptMessage(SystemMessage): + sqlbot_system: bool = True + + def __init__( + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any + ) -> None: + super().__init__(content=content, **kwargs) + + +class HumanPromptMessage(HumanMessage): + sqlbot_system: bool = True + + def __init__( + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any + ) -> None: + super().__init__(content=content, **kwargs) + + +class AIPromptMessage(AIMessage): + sqlbot_system: bool = True + + def __init__( + self, content: Union[str, list[Union[str, dict]]], **kwargs: Any + ) -> None: + super().__init__(content=content, **kwargs) diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index ba6540a0..b59efb37 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -14,7 +14,7 @@ import sqlparse from langchain.chat_models.base import BaseChatModel from langchain_community.utilities import SQLDatabase -from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, BaseMessageChunk +from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, BaseMessageChunk from sqlalchemy import and_, select from sqlalchemy.orm import sessionmaker, scoped_session from sqlbot_xpack.config.model import SysArgModel @@ -33,7 +33,7 @@ get_last_execute_sql_error, format_json_data, format_chart_fields, get_chat_brief_generate, get_chat_predict_data, \ get_chat_chart_config, trigger_log_error from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \ - ChatFinishStep, AxisObj + ChatFinishStep, AxisObj, SystemPromptMessage, HumanPromptMessage, AIPromptMessage from apps.data_training.curd.data_training import get_training_template from apps.datasource.crud.datasource import get_table_schema from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user @@ -216,12 +216,31 @@ def init_messages(self, session: Session): filter(lambda obj: obj.pid == self.chat_question.regenerate_record_id, self.generate_sql_logs), None) last_sql_messages: List[dict[str, Any]] = _temp_log.messages if _temp_log else [] + # 排除所有的系统提示词 + last_sql_messages = [obj for obj in last_sql_messages if obj.get("sqlbot_system") != True] + count_limit = self.base_message_round_count_limit self.sql_message = [] # add sys prompt - self.sql_message.append(SystemMessage( - content=self.chat_question.sql_sys_question(self.ds.type, self.enable_sql_row_limit))) + _system_templates = self.chat_question.sql_sys_question(self.ds.type, self.enable_sql_row_limit) + self.sql_message.append(SystemPromptMessage(content=_system_templates['system'])) + self.sql_message.append(HumanPromptMessage(content=_system_templates['rules'])) + self.sql_message.append( + AIPromptMessage(content='我已掌握所有规则,包括表结构、SQL规范、安全限制和输出格式,我会严格遵守这些规则。')) + self.sql_message.append(HumanPromptMessage(content=_system_templates['schema'])) + self.sql_message.append( + AIPromptMessage(content='我已确认您提供的数据库信息与表结构schema,我生成的SQL不会超出您提供的范围。')) + if _system_templates.get('custom_prompt'): + self.sql_message.append(HumanPromptMessage(content=_system_templates['custom_prompt'])) + self.sql_message.append(AIPromptMessage(content='我已确认您提供的额外信息,我会进行参考。')) + if _system_templates.get('terminologies'): + self.sql_message.append(HumanPromptMessage(content=_system_templates['terminologies'])) + self.sql_message.append(AIPromptMessage(content='我已确认您提供的术语信息,我会进行参考。')) + if _system_templates.get('data_training'): + self.sql_message.append(HumanPromptMessage(content=_system_templates['data_training'])) + self.sql_message.append(AIPromptMessage(content='我已确认您提供的SQL示例,我会进行参考。')) + if last_sql_messages is not None and len(last_sql_messages) > 0: last_rounds = get_last_conversation_rounds(last_sql_messages, rounds=count_limit) @@ -246,7 +265,7 @@ def init_messages(self, session: Session): self.chart_message = [] # add sys prompt - self.chart_message.append(SystemMessage(content=self.chat_question.chart_sys_question())) + self.chart_message.append(SystemPromptMessage(content=self.chat_question.chart_sys_question())) if last_chart_messages is not None and len(last_chart_messages) > 0: last_rounds = get_last_conversation_rounds(last_chart_messages, rounds=count_chart_limit) @@ -369,7 +388,7 @@ def generate_analysis(self, _session: Session): self.filter_custom_prompts(_session, CustomPromptTypeEnum.ANALYSIS, self.current_user.oid, ds_id) - analysis_msg.append(SystemMessage(content=self.chat_question.analysis_sys_question())) + analysis_msg.append(SystemPromptMessage(content=self.chat_question.analysis_sys_question())) analysis_msg.append(HumanMessage(content=self.chat_question.analysis_user_question())) self.current_logs[OperationEnum.ANALYSIS] = start_log(session=_session, @@ -379,6 +398,8 @@ def generate_analysis(self, _session: Session): record_id=self.record.id, full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, 'content': msg.content} for msg in analysis_msg]) @@ -400,6 +421,8 @@ def generate_analysis(self, _session: Session): OperationEnum.ANALYSIS], full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, 'content': msg.content} for msg in analysis_msg], reasoning_content=full_thinking_text, @@ -417,7 +440,7 @@ def generate_predict(self, _session: Session): self.filter_custom_prompts(_session, CustomPromptTypeEnum.PREDICT_DATA, self.current_user.oid, ds_id) predict_msg: List[Union[BaseMessage, dict[str, Any]]] = [] - predict_msg.append(SystemMessage(content=self.chat_question.predict_sys_question())) + predict_msg.append(SystemPromptMessage(content=self.chat_question.predict_sys_question())) predict_msg.append(HumanMessage(content=self.chat_question.predict_user_question())) self.current_logs[OperationEnum.PREDICT_DATA] = start_log(session=_session, @@ -427,6 +450,8 @@ def generate_predict(self, _session: Session): record_id=self.record.id, full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, 'content': msg.content} for msg in predict_msg]) @@ -449,6 +474,8 @@ def generate_predict(self, _session: Session): OperationEnum.PREDICT_DATA], full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, 'content': msg.content} for msg in predict_msg], reasoning_content=full_thinking_text, @@ -466,7 +493,7 @@ def generate_recommend_questions_task(self, _session: Session): embedding=False) guess_msg: List[Union[BaseMessage, dict[str, Any]]] = [] - guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number))) + guess_msg.append(SystemPromptMessage(content=self.chat_question.guess_sys_question(self.articles_number))) old_questions = list(map(lambda q: q.strip(), get_old_questions(_session, self.record.datasource))) guess_msg.append( @@ -479,6 +506,9 @@ def generate_recommend_questions_task(self, _session: Session): record_id=self.record.id, full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in guess_msg]) @@ -500,6 +530,9 @@ def generate_recommend_questions_task(self, _session: Session): OperationEnum.GENERATE_RECOMMENDED_QUESTIONS], full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in guess_msg], reasoning_content=full_thinking_text, @@ -512,7 +545,7 @@ def generate_recommend_questions_task(self, _session: Session): def select_datasource(self, _session: Session): datasource_msg: List[Union[BaseMessage, dict[str, Any]]] = [] - datasource_msg.append(SystemMessage(self.chat_question.datasource_sys_question())) + datasource_msg.append(SystemPromptMessage(self.chat_question.datasource_sys_question())) if self.current_assistant and self.current_assistant.type != 4: _ds_list = get_assistant_ds(session=_session, llm_service=self) else: @@ -552,6 +585,9 @@ def select_datasource(self, _session: Session): operate=OperationEnum.CHOOSE_DATASOURCE, record_id=self.record.id, full_message=[{'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in datasource_msg]) @@ -571,6 +607,9 @@ def select_datasource(self, _session: Session): OperationEnum.CHOOSE_DATASOURCE], full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in datasource_msg], reasoning_content=full_thinking_text, @@ -661,7 +700,10 @@ def generate_sql(self, _session: Session): operate=OperationEnum.GENERATE_SQL, record_id=self.record.id, full_message=[ - {'type': msg.type, 'content': msg.content} for msg + {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, + 'content': msg.content} for msg in self.sql_message]) full_thinking_text = '' full_sql_text = '' @@ -678,7 +720,11 @@ def generate_sql(self, _session: Session): self.current_logs[OperationEnum.GENERATE_SQL] = end_log(session=_session, log=self.current_logs[OperationEnum.GENERATE_SQL], - full_message=[{'type': msg.type, 'content': msg.content} + full_message=[{'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, + 'content': msg.content} for msg in self.sql_message], reasoning_content=full_thinking_text, token_usage=token_usage) @@ -690,7 +736,7 @@ def generate_with_sub_sql(self, session: Session, sql, sub_mappings: list): self.chat_question.sql = sql self.chat_question.sub_query = sub_query dynamic_sql_msg: List[Union[BaseMessage, dict[str, Any]]] = [] - dynamic_sql_msg.append(SystemMessage(content=self.chat_question.dynamic_sys_question())) + dynamic_sql_msg.append(SystemPromptMessage(content=self.chat_question.dynamic_sys_question())) dynamic_sql_msg.append(HumanMessage(content=self.chat_question.dynamic_user_question())) self.current_logs[OperationEnum.GENERATE_DYNAMIC_SQL] = start_log(session=session, @@ -699,6 +745,9 @@ def generate_with_sub_sql(self, session: Session, sql, sub_mappings: list): operate=OperationEnum.GENERATE_DYNAMIC_SQL, record_id=self.record.id, full_message=[{'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in dynamic_sql_msg]) @@ -720,6 +769,9 @@ def generate_with_sub_sql(self, session: Session, sql, sub_mappings: list): OperationEnum.GENERATE_DYNAMIC_SQL], full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in dynamic_sql_msg], reasoning_content=full_thinking_text, @@ -748,7 +800,7 @@ def build_table_filter(self, session: Session, sql: str, filters: list): self.chat_question.sql = sql self.chat_question.filter = filter permission_sql_msg: List[Union[BaseMessage, dict[str, Any]]] = [] - permission_sql_msg.append(SystemMessage(content=self.chat_question.filter_sys_question())) + permission_sql_msg.append(SystemPromptMessage(content=self.chat_question.filter_sys_question())) permission_sql_msg.append(HumanMessage(content=self.chat_question.filter_user_question())) self.current_logs[OperationEnum.GENERATE_SQL_WITH_PERMISSIONS] = start_log(session=session, @@ -758,6 +810,9 @@ def build_table_filter(self, session: Session, sql: str, filters: list): record_id=self.record.id, full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in permission_sql_msg]) @@ -778,6 +833,9 @@ def build_table_filter(self, session: Session, sql: str, filters: list): OperationEnum.GENERATE_SQL_WITH_PERMISSIONS], full_message=[ {'type': msg.type, + 'sqlbot_system': getattr(msg, + 'sqlbot_system', + False) is True, 'content': msg.content} for msg in permission_sql_msg], reasoning_content=full_thinking_text, @@ -813,7 +871,10 @@ def generate_chart(self, _session: Session, chart_type: Optional[str] = '', sche operate=OperationEnum.GENERATE_CHART, record_id=self.record.id, full_message=[ - {'type': msg.type, 'content': msg.content} for + {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, + 'content': msg.content} for msg in self.chart_message]) full_thinking_text = '' @@ -834,7 +895,10 @@ def generate_chart(self, _session: Session, chart_type: Optional[str] = '', sche self.current_logs[OperationEnum.GENERATE_CHART] = end_log(session=_session, log=self.current_logs[OperationEnum.GENERATE_CHART], full_message=[ - {'type': msg.type, 'content': msg.content} + {'type': msg.type, + 'sqlbot_system': getattr(msg, 'sqlbot_system', + False) is True, + 'content': msg.content} for msg in self.chart_message], reasoning_content=full_thinking_text, token_usage=token_usage) diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index 0fdf9786..c4a1e596 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -71,18 +71,21 @@ template: 内有等信息; 其中,:提供数据库引擎及版本信息; :以 M-Schema 格式提供数据库表结构信息; - :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件; - :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例。 + :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件; + :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例。 若有块,它会提供一组,可能会是额外添加的背景信息,或者是额外的生成SQL的要求,请结合额外信息或要求后生成你的回答。 - 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 你必须遵守内规定的生成SQL规则 你必须遵守内规定的检查步骤生成你的回答 + 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 + {process_check} + + generate_rules: | + 以下是你必须遵守的规则和可以参考的基础示例: 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 - 即使示例中显示了中文消息,在实际生成时也必须翻译成英文 你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL @@ -186,8 +189,6 @@ template: - {process_check} - {basic_sql_examples} @@ -298,23 +299,32 @@ template: - - 以下是正式的信息: + + generate_terminologies_info: | + 以下是你可以参考的术语: + {terminologies} + + generate_data_training_info: | + 以下是你可以参考的SQL示例: + {data_training} + + generate_custom_prompt_info: | + 以下是你可以参考的额外信息: + {custom_prompt} + + generate_basic_info: | + 以下是数据库与表结构信息,你生成的SQL使用到的表名与字段必须在提供的范围内 {engine} {schema} - - {terminologies} - {data_training} - {custom_prompt} - - ### 响应, 请根据上述要求直接返回JSON结果: - ```json user: | + 请根据上述要求,用语言:{lang} 进行回答 + 如果内的提问与上述要求冲突,你需要停止生成SQL并告知生成SQL失败的原因 + 请输出符合要求的JSON回答 {current_time} diff --git a/frontend/src/entity/supplier.ts b/frontend/src/entity/supplier.ts index 1e5a459d..cd672f9d 100644 --- a/frontend/src/entity/supplier.ts +++ b/frontend/src/entity/supplier.ts @@ -46,6 +46,10 @@ export const supplierList: Array<{ { key: 'extra_body', val: '{"enable_thinking": false}', type: 'json' }, ], model_options: [ + { name: 'qwen3.6-plus' }, + { name: 'qwen3.5-plus' }, + { name: 'qwen3.5-flash' }, + { name: 'qwen3-coder-next' }, { name: 'qwen3-coder-plus' }, { name: 'qwen3-coder-flash' }, { name: 'qwen-plus' },