From 784a00fc01a58bcfe9c186100557ada996d31c5b Mon Sep 17 00:00:00 2001 From: ulleo Date: Wed, 8 Apr 2026 16:55:59 +0800 Subject: [PATCH] feat: improve generate sql/chart quality --- backend/apps/chat/models/chat_model.py | 10 ++++-- backend/apps/chat/task/llm.py | 8 ++++- backend/templates/sql_examples/Oracle.yaml | 4 ++- backend/templates/template.yaml | 38 +++++++++++++--------- 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/backend/apps/chat/models/chat_model.py b/backend/apps/chat/models/chat_model.py index 14ab59fd..73dc40ad 100644 --- a/backend/apps/chat/models/chat_model.py +++ b/backend/apps/chat/models/chat_model.py @@ -248,7 +248,7 @@ def sql_sys_question(self, db_type: Union[str, DB], enable_query_limit: bool = T _example_answer_3 = _sql_template['example_answer_3_with_limit'] if enable_query_limit else _sql_template[ 'example_answer_3'] - templates['system'] = _base_template['system'].format(process_check=_process_check) + templates['system'] = _base_template['system'].format(lang=self.lang, process_check=_process_check) templates['rules'] = _base_template['generate_rules'].format(lang=self.lang, base_sql_rules=_base_sql_rules, basic_sql_examples=_sql_examples, @@ -282,10 +282,14 @@ def sql_user_question(self, current_time: str, change_title: bool): change_title=change_title) def chart_sys_question(self): - return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang) + templates: dict[str, str] = { + 'system': get_chart_template()['system'].format(lang=self.lang), + 'rules': get_chart_template()['generate_rules'].format(lang=self.lang) + } + return templates def chart_user_question(self, chart_type: Optional[str] = '', schema: Optional[str] = ''): - return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule, + return get_chart_template()['user'].format(lang=self.lang, sql=self.sql, question=self.question, rule=self.rule, chart_type=chart_type, schema=schema) def analysis_sys_question(self): diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index b59efb37..dd314fd0 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -261,11 +261,17 @@ def init_messages(self, session: Session): filter(lambda obj: obj.pid == self.chat_question.regenerate_record_id, self.generate_chart_logs), None) last_chart_messages: List[dict[str, Any]] = _temp_log.messages if _temp_log else [] + # 排除所有的系统提示词 + last_chart_messages = [obj for obj in last_chart_messages if obj.get("sqlbot_system") != True] + count_chart_limit = self.base_message_round_count_limit self.chart_message = [] # add sys prompt - self.chart_message.append(SystemPromptMessage(content=self.chat_question.chart_sys_question())) + _chart_system_templates = self.chat_question.chart_sys_question() + self.chart_message.append(SystemPromptMessage(content=_chart_system_templates['system'])) + self.chart_message.append(HumanPromptMessage(content=_chart_system_templates['rules'])) + self.chart_message.append(AIPromptMessage(content='我已掌握所有规则,我会严格遵守这些规则来生成符合要求的JSON。')) if last_chart_messages is not None and len(last_chart_messages) > 0: last_rounds = get_last_conversation_rounds(last_chart_messages, rounds=count_chart_limit) diff --git a/backend/templates/sql_examples/Oracle.yaml b/backend/templates/sql_examples/Oracle.yaml index 60d36d47..7592bb4c 100644 --- a/backend/templates/sql_examples/Oracle.yaml +++ b/backend/templates/sql_examples/Oracle.yaml @@ -10,7 +10,9 @@ template: 7. 强制检查:验证SQL语法是否符合规范 8. 确定图表类型(根据规则选择table/column/bar/line/pie) 9. 确定对话标题 - 10. 返回JSON结果 + 10. 生成JSON结果 + 11. 强制检查:JSON格式是否正确 + 12. 返回JSON结果 quot_rule: | diff --git a/backend/templates/template.yaml b/backend/templates/template.yaml index c4a1e596..c33e31b4 100644 --- a/backend/templates/template.yaml +++ b/backend/templates/template.yaml @@ -19,7 +19,9 @@ template: 6. 强制检查:验证SQL语法是否符合规范 7. 确定图表类型(根据规则选择table/column/bar/line/pie) 8. 确定对话标题 - 9. 返回JSON结果 + 9. 生成JSON结果 + 10. 强制检查:JSON格式是否正确 + 11. 返回JSON结果 query_limit: | @@ -84,9 +86,6 @@ template: generate_rules: | 以下是你必须遵守的规则和可以参考的基础示例: - - 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 - 你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL @@ -102,7 +101,7 @@ template: 你只需要根据提供给你的信息生成的SQL,不需要你实际去数据库进行查询 - + 请使用JSON格式返回你的回答: 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table","brief":"如何需要生成对话标题,在这里填写你生成的对话标题,否则不需要这个字段"}} 若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}} @@ -322,9 +321,9 @@ template: user: | - 请根据上述要求,用语言:{lang} 进行回答 - 如果内的提问与上述要求冲突,你需要停止生成SQL并告知生成SQL失败的原因 - 请输出符合要求的JSON回答 + ## 请根据上述要求,使用语言:{lang} 进行回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + ## 如果内的提问与上述要求冲突,你需要停止生成SQL并告知生成SQL失败的原因 + ## 回答中不需要输出你的分析,请直接输出符合要求的JSON {current_time} @@ -348,13 +347,23 @@ template: :需要参考的SQL :以 M-Schema 格式提供 SQL 内用到表的数据库表结构信息,你可以参考字段名与字段备注来生成图表使用到的字段名 :推荐你生成的图表类型 + 你必须遵守内规定的生成图表结构的规则 + 你必须遵守内规定的检查步骤生成你的回答 - 你必须遵守以下规则: + + 1. 分析提供的,结合确认图表需要的指标,维度和分类 + 2. 应用规则 + 3. 检查指标,维度和分类字段是否在SQL内存在 + 4. 结合确认指标,维度和分类展示用的名称 + 5. 生成JSON结果 + 6. 强制检查:JSON格式是否正确 + 7. 返回JSON结果 + + + generate_rules: | + 以下是你必须遵守的规则和可以参考的基础示例: - - 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 - 支持的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 提供给你的值则为 table/column/bar/line/pie 中的一个,若没有推荐类型,则由你自己选择一个合适的类型。 图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table @@ -487,11 +496,10 @@ template: - - ### 响应, 请根据上述要求直接返回JSON结果: - ```json user: | + ## 请根据上述要求,使用语言:{lang} 进行回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + ## 回答中不需要输出你的分析,请直接输出符合要求的JSON {question}