diff --git a/Dockerfile b/Dockerfile index b1c1e925..cb7faf67 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,14 @@ ENV DEBIAN_FRONTEND=noninteractive # Create necessary directories RUN mkdir -p ${APP_HOME} ${UI_HOME} +# ODBC build/runtime deps for pyodbc + SQL Server +RUN apt-get update && apt-get install -y --no-install-recommends \ + unixodbc \ + unixodbc-dev \ + freetds-bin \ + tdsodbc \ + && rm -rf /var/lib/apt/lists/* + WORKDIR ${APP_HOME} COPY --from=sqlbot-ui-builder ${UI_HOME} ${UI_HOME} @@ -70,6 +78,12 @@ FROM registry.cn-qingdao.aliyuncs.com/dataease/sqlbot-python-pg:latest RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ echo "Asia/Shanghai" > /etc/timezone +RUN apt-get update && apt-get install -y --no-install-recommends \ + unixodbc \ + freetds-bin \ + tdsodbc \ + && rm -rf /var/lib/apt/lists/* + # Set runtime environment variables ENV PYTHONUNBUFFERED=1 ENV SQLBOT_HOME=/opt/sqlbot diff --git a/Dockerfile-base b/Dockerfile-base index c78ab7c1..0ae3a7c2 100644 --- a/Dockerfile-base +++ b/Dockerfile-base @@ -16,6 +16,10 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ gnupg \ gcc \ g++ \ + unixodbc \ + unixodbc-dev \ + freetds-bin \ + tdsodbc \ libcairo2-dev \ libpango1.0-dev \ libjpeg-dev \ diff --git a/backend/apps/db/db.py b/backend/apps/db/db.py index bbf9e97e..147f3758 100644 --- a/backend/apps/db/db.py +++ b/backend/apps/db/db.py @@ -9,7 +9,7 @@ import oracledb import psycopg2 -import pymssql +import pyodbc from apps.db.db_sql import get_table_sql, get_field_sql, get_version_sql from common.error import ParseSQLResultError @@ -63,10 +63,8 @@ def get_uri_from_config(type: str, conf: DatasourceConf) -> str: else: db_url = f"mysql+pymysql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" elif equals_ignore_case(type, "sqlServer"): - if conf.extraJdbc is not None and conf.extraJdbc != '': - db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" - else: - db_url = f"mssql+pymssql://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}" + odbc_connect = urllib.parse.quote_plus(get_sqlserver_odbc_conn_str(conf)) + db_url = f"mssql+pyodbc:///?odbc_connect={odbc_connect}" elif equals_ignore_case(type, "pg", "excel"): if conf.extraJdbc is not None and conf.extraJdbc != '': db_url = f"postgresql+psycopg2://{urllib.parse.quote(conf.username)}:{urllib.parse.quote(conf.password)}@{conf.host}:{conf.port}/{conf.database}?{conf.extraJdbc}" @@ -98,7 +96,7 @@ def get_extra_config(conf: DatasourceConf): if conf.extraJdbc: config_arr = conf.extraJdbc.split("&") for config in config_arr: - kv = config.split("=") + kv = config.split("=", 1) if len(kv) == 2 and kv[0] and kv[1]: config_dict[kv[0]] = kv[1] else: @@ -106,31 +104,41 @@ def get_extra_config(conf: DatasourceConf): return config_dict -def get_origin_connect(type: str, conf: DatasourceConf): +def _escape_odbc_value(value: str) -> str: + # ODBC uses ';' as a separator and '}' as an escape marker in braced values. + text = '' if value is None else str(value) + return '{' + text.replace('}', '}}') + '}' + + +def get_sqlserver_odbc_conn_str(conf: DatasourceConf) -> str: extra_config_dict = get_extra_config(conf) + driver = conf.driver.strip() if conf.driver and conf.driver.strip() else 'FreeTDS' + driver = driver.replace('}', '}}') + conn_parts = [ + f"DRIVER={{{driver}}}", + f"SERVER={_escape_odbc_value(conf.host)}", + f"PORT={_escape_odbc_value(conf.port)}", + f"DATABASE={_escape_odbc_value(conf.database)}", + f"UID={_escape_odbc_value(conf.username)}", + f"PWD={_escape_odbc_value(conf.password)}", + ] + + # Keep compatibility for SQL Server 2008+ when lowVersion is enabled. + if conf.lowVersion is None or conf.lowVersion: + conn_parts.append('TDS_Version=7.0') + + if conf.timeout is not None and conf.timeout > 0: + conn_parts.append(f"Timeout={conf.timeout}") + + for key, value in extra_config_dict.items(): + conn_parts.append(f"{key}={_escape_odbc_value(value)}") + + return ';'.join(conn_parts) + + +def get_origin_connect(type: str, conf: DatasourceConf): if equals_ignore_case(type, "sqlServer"): - # none or true, set tds_version = 7.0 - if conf.lowVersion is None or conf.lowVersion: - return pymssql.connect( - server=conf.host, - port=str(conf.port), - user=conf.username, - password=conf.password, - database=conf.database, - timeout=conf.timeout, - tds_version='7.0', # options: '4.2', '7.0', '8.0' ..., - **extra_config_dict - ) - else: - return pymssql.connect( - server=conf.host, - port=str(conf.port), - user=conf.username, - password=conf.password, - database=conf.database, - timeout=conf.timeout, - **extra_config_dict - ) + return pyodbc.connect(get_sqlserver_odbc_conn_str(conf), timeout=conf.timeout) # use sqlalchemy @@ -150,7 +158,7 @@ def get_engine(ds: CoreDatasource, timeout: int = 0) -> Engine: else: engine = create_engine(get_uri(ds), connect_args={"connect_timeout": conf.timeout}, poolclass=NullPool) elif equals_ignore_case(ds.type, 'sqlServer'): - engine = create_engine('mssql+pymssql://', creator=lambda: get_origin_connect(ds.type, conf), + engine = create_engine('mssql+pyodbc://', creator=lambda: get_origin_connect(ds.type, conf), poolclass=NullPool) elif equals_ignore_case(ds.type, 'oracle'): engine = create_engine(get_uri(ds), poolclass=NullPool) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 06ae1a4f..d010c96f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "pymysql (>=1.1.1,<2.0.0)", "cryptography (>=44.0.3,<45.0.0)", "llama_index>=0.12.35", - "pymssql (>=2.3.4,<3.0.0)", + "pyodbc (>=5.1.0,<6.0.0)", "pandas (>=2.2.3,<3.0.0)", "openpyxl (>=3.1.5,<4.0.0)", "psycopg2-binary (>=2.9.10,<3.0.0)",