Python实现MySQL建表语句转换成Clickhouse SQL

发布于:2025-06-22 ⋅ 阅读:(15) ⋅ 点赞:(0)

主程序:**main_converter.py**

import re
import json
import argparse

def load_config(config_path: str) -> dict:
    """Load the conversion rules from a JSON file.

    Args:
        config_path: Path of the UTF-8 encoded JSON configuration file.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

class BaseConverter:
    """Interface shared by every field-conversion rule.

    A rule receives the text of one column definition and returns the
    rewritten text. Subclasses must override :meth:`apply`.
    """

    def apply(self, field: str) -> str:
        """Transform one column definition.

        Args:
            field: Text of a single column definition.

        Returns:
            The transformed column definition.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("子类必须实现apply方法")

class TypeMappingConverter(BaseConverter):
    """Rewrite MySQL data types into their ClickHouse counterparts.

    The mapping keys are inserted into a regular expression unescaped and
    matched case-insensitively against the whole field text, so a bare
    column name that equals a type keyword is rewritten as well.
    """

    def __init__(self, mapping: dict):
        """Store the type mapping.

        Args:
            mapping: MySQL type -> ClickHouse type, e.g.
                ``{"INT": "UInt32", "VARCHAR": "String", ...}``.
        """
        self.mapping = mapping

    def apply(self, field: str) -> str:
        """Return ``field`` with every mapped MySQL type replaced.

        ``DECIMAL(p, s)`` keeps its precision/scale arguments. For every
        other type an optional single-number length such as ``(11)`` is
        dropped together with the type name.
        """
        for mysql_type, ch_type in self.mapping.items():
            if mysql_type.upper() == "DECIMAL":
                # No trailing \b here: after ")" the next character is
                # normally a space, a comma or end-of-string, none of which
                # form a word boundary, so \b would block the match.
                field = re.sub(
                    r'\bDECIMAL\s*\(\s*(\d+\s*,\s*\d+)\s*\)',
                    lambda m, target=ch_type: f"{target}({m.group(1)})",
                    field,
                    flags=re.IGNORECASE,
                )
            else:
                # The boundary sits right after the type name; the optional
                # "(n)" is consumed afterwards so INT(11) becomes UInt32
                # rather than UInt32(11).
                pattern = r'\b' + mysql_type + r'\b(?:\s*\(\d+\))?'
                field = re.sub(pattern, ch_type, field, flags=re.IGNORECASE)
        return field

class DefaultValueConverter(BaseConverter):
    """Rewrite MySQL default-value expressions into ClickHouse ones."""

    def __init__(self, replacements: dict):
        """Store the replacement table.

        Args:
            replacements: MySQL default expression -> ClickHouse expression,
                e.g. ``{"CURRENT_TIMESTAMP": "now()"}``. Keys are inserted
                into a regular expression unescaped.
        """
        self.replacements = replacements

    def apply(self, field: str) -> str:
        """Return ``field`` with each ``DEFAULT <key>`` turned into ``DEFAULT <value>``."""
        for key, val in self.replacements.items():
            field = re.sub(
                r'\bDEFAULT\s+' + key + r'\b',
                # Keep the separating space; without it the output would be
                # e.g. "DEFAULTnow()".
                'DEFAULT ' + val,
                field,
                flags=re.IGNORECASE,
            )
        return field

class IgnoreKeywordConverter(BaseConverter):
    """Remove keywords such as AUTO_INCREMENT, PRIMARY KEY, KEY or FOREIGN KEY."""

    def __init__(self, keywords: list):
        """Store the keywords to drop.

        Args:
            keywords: Keywords (regular-expression fragments) to delete from
                a column definition.
        """
        self.keywords = keywords

    def apply(self, field: str) -> str:
        """Return ``field`` with every configured keyword deleted.

        Keywords are matched case-insensitively as whole tokens, so ``KEY``
        does not eat part of an identifier such as ``monkey``.
        """
        # Longest first, so that "PRIMARY KEY" is removed as a unit before
        # the shorter "KEY" can break it apart.
        for keyword in sorted(self.keywords, key=len, reverse=True):
            # Lookarounds rather than \b so keywords that begin or end with
            # a non-word character still match.
            pattern = r'(?<!\w)(?:' + keyword + r')(?!\w)'
            field = re.sub(pattern, '', field, flags=re.IGNORECASE)
        return field

class CommentConverter(BaseConverter):
    """Keep or remove the MySQL ``COMMENT '...'`` part of a column definition."""

    def __init__(self, keep: bool = True):
        """Store the policy.

        Args:
            keep: ``True`` leaves the field untouched, ``False`` strips the
                COMMENT clause.
        """
        self.keep = keep

    def apply(self, field: str) -> str:
        """Return ``field`` unchanged, or without its COMMENT clause when ``keep`` is false."""
        if self.keep:
            return field
        # The quoted text may contain doubled quotes ('') or backslash
        # escapes, so the body is matched token by token instead of [^']*.
        return re.sub(
            r"\bCOMMENT\s+'(?:[^'\\]|\\.|'')*'",
            '',
            field,
            flags=re.IGNORECASE,
        )

class RuleRegistry:
    """Ordered collection of conversion rules applied one after another."""

    def __init__(self):
        # Rules run in registration order.
        self.rules = []

    def register(self, rule: "BaseConverter"):
        """Append ``rule`` to the end of the processing chain."""
        self.rules.append(rule)

    def process(self, field: str) -> str:
        """Run every registered rule over ``field`` and normalise whitespace.

        Args:
            field: Text of a single column definition.

        Returns:
            The converted definition with each whitespace run collapsed to
            one space and no leading or trailing whitespace.
        """
        for rule in self.rules:
            field = rule.apply(field)
        # Collapse to a single space, not to nothing: deleting whitespace
        # would glue the column name to its type ("order_idUInt32").
        return re.sub(r'\s+', ' ', field).strip()

class MySQLSyntaxCleaner(BaseConverter):
    """Strip MySQL-only column attributes from a column definition."""

    def apply(self, field: str) -> str:
        """Return ``field`` without MySQL-specific modifiers.

        Removes CHARACTER SET / COLLATE, UNSIGNED / ZEROFILL, ON UPDATE,
        NOT NULL / NULL, every ``DEFAULT ...`` clause and AUTO_INCREMENT.
        Length-limited string types and ``text`` are normalised to
        ``String``, and whitespace runs collapse to a single space.
        """
        # Character set and collation.
        field = re.sub(r'\bCHARACTER\s+SET\s+\w+', '', field, flags=re.IGNORECASE)
        field = re.sub(r'\bCOLLATE\s+\w+', '', field, flags=re.IGNORECASE)
        # Drop length limits, e.g. String(60) -> String.
        field = re.sub(r'\b(String|VARCHAR|CHAR)\s*\(\d+\)', 'String', field, flags=re.IGNORECASE)
        # text columns become String.
        field = re.sub(r'\btext\b', 'String', field, flags=re.IGNORECASE)
        # Integer modifiers that have no ClickHouse spelling.
        field = re.sub(r'\b(?:UNSIGNED|ZEROFILL)\b', '', field, flags=re.IGNORECASE)
        # "ON UPDATE CURRENT_TIMESTAMP" may appear without a DEFAULT clause.
        field = re.sub(r'\bON\s+UPDATE\s+\w+(?:\(\d*\))?', '', field, flags=re.IGNORECASE)
        # NOT NULL, NULL, DEFAULT xxx, AUTO_INCREMENT.
        field = re.sub(r'\bNOT\s+NULL\b', '', field, flags=re.IGNORECASE)
        field = re.sub(r'\bNULL\b', '', field, flags=re.IGNORECASE)
        field = re.sub(r'\bDEFAULT\s+[^,]+', '', field, flags=re.IGNORECASE)
        field = re.sub(r'\bAUTO_INCREMENT\b', '', field, flags=re.IGNORECASE)
        # Collapse (not delete) whitespace so name and type stay separated.
        field = re.sub(r'\s+', ' ', field).strip()
        return field


# -------------------------------
# SQL parsing: split a MySQL CREATE TABLE statement into its parts
# -------------------------------
def split_fields(fields_part: str) -> list:
    """Split the body of a CREATE TABLE statement on top-level commas.

    Commas nested inside parentheses (``DECIMAL(10,2)``, index column lists)
    or inside quoted text (``COMMENT 'a, b'``) do not split.

    Args:
        fields_part: Text between the outer parentheses of CREATE TABLE.

    Returns:
        The stripped, non-empty column/constraint definitions in order.
    """
    fields = []
    buf = []
    depth = 0
    quote = None      # quote character currently open, if any
    escaped = False   # previous character inside a quote was a backslash
    for c in fields_part:
        if quote is not None:
            # Inside quoted text everything is literal.
            buf.append(c)
            if escaped:
                escaped = False
            elif c == '\\':
                escaped = True
            elif c == quote:
                # A doubled quote ('') closes and immediately reopens,
                # which keeps the state correct.
                quote = None
            continue
        if c in ("'", '"', '`'):
            quote = c
        elif c == '(':
            depth += 1
        elif c == ')':
            depth -= 1
        if c == ',' and depth == 0:
            piece = ''.join(buf).strip()
            if piece:
                fields.append(piece)
            buf = []
        else:
            buf.append(c)
    tail = ''.join(buf).strip()
    if tail:
        fields.append(tail)
    return fields


def parse_mysql_create_table(sql: str) -> tuple:
    """Extract table name, field definitions and engine from MySQL DDL.

    Args:
        sql: A ``CREATE TABLE ... ENGINE=...`` statement.

    Returns:
        ``(table_name, fields, engine)``: the table name without backticks,
        the list produced by :func:`split_fields`, and the engine name
        (empty string if none is found).

    Raises:
        ValueError: If the table name or the field list cannot be located.
    """
    # Collapse whitespace to single spaces; the keyword patterns below rely
    # on tokens still being separated.
    sql = re.sub(r'\s+', ' ', sql.strip())
    # The table name may be wrapped in backticks.
    table_match = re.search(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`?(\w+)`?\s*\(',
        sql,
        re.IGNORECASE,
    )
    if not table_match:
        raise ValueError("无法解析表名")
    table_name = table_match.group(1)
    # Start at the parenthesis that opens the column list so that earlier
    # text cannot be captured; tolerate "ENGINE = x" as well as "ENGINE=x".
    fields_pattern = re.compile(r'\((.*)\)\s*ENGINE\s*=', re.IGNORECASE)
    fields_part_match = fields_pattern.search(sql, table_match.end() - 1)
    if not fields_part_match:
        raise ValueError("无法解析字段部分")
    fields = split_fields(fields_part_match.group(1))
    engine_match = re.search(r'ENGINE\s*=\s*(\w+)', sql, re.IGNORECASE)
    engine = engine_match.group(1) if engine_match else ""
    return table_name, fields, engine


# --------------------
# ClickHouse CREATE TABLE generation
# --------------------
def generate_clickhouse_create_table(mysql_sql: str, config: dict) -> str:
    """Convert a MySQL CREATE TABLE statement into ClickHouse MergeTree DDL.

    Args:
        mysql_sql: The MySQL ``CREATE TABLE`` statement.
        config: Rule configuration. Keys read here: ``type_mapping``,
            ``default_value_replacements``, ``ignore_keywords``,
            ``partition_rule`` (``match``, ``expression``) and
            ``order_rule`` (``default_field``).

    Returns:
        DDL for a table named ``<table>_local`` using ``MergeTree()``,
        partitioned and ordered by the selected field.

    Raises:
        ValueError: If the statement cannot be parsed or yields no columns.
    """
    table_name, fields, _ = parse_mysql_create_table(mysql_sql)
    registry = RuleRegistry()
    registry.register(TypeMappingConverter(config.get("type_mapping", {})))
    registry.register(DefaultValueConverter(config.get("default_value_replacements", {})))
    registry.register(IgnoreKeywordConverter(config.get("ignore_keywords", [])))
    registry.register(CommentConverter(keep=False))
    registry.register(MySQLSyntaxCleaner())

    partition_rule = config.get("partition_rule", {})
    partition_match = partition_rule.get("match", "DateTime|date|create")
    # \b keeps columns such as "key_id" from being mistaken for a KEY clause.
    table_level = re.compile(
        r'^(?:KEY|INDEX|UNIQUE|FULLTEXT|SPATIAL|CONSTRAINT|CHECK'
        r'|FOREIGN\s+KEY|PRIMARY\s+KEY)\b',
        re.IGNORECASE,
    )

    new_fields = []
    partition_field = None

    for field in fields:
        # Skip table-level constraints and indexes.
        if table_level.match(field):
            continue
        conv_field = registry.process(field)
        if conv_field == "" or conv_field == ",":
            continue
        # Drop a dangling trailing comma.
        conv_field = conv_field.rstrip(',')
        new_fields.append(conv_field)
        # The first column matching the partition rule becomes the
        # partition / ordering field.
        if (not partition_field) and re.search(partition_match, conv_field, re.IGNORECASE):
            fld_match = re.match(r'`?(\w+)`?\s+', conv_field)
            if fld_match:
                partition_field = fld_match.group(1)

    if not new_fields:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError("未解析到任何字段定义")

    if not partition_field:
        # Look the first column up lazily, only when no default is configured.
        partition_field = (
            config.get("order_rule", {}).get("default_field")
            or new_fields[0].split()[0]
        )
    ch_table_name = table_name + "_local"
    new_fields_str = ",\n    ".join(new_fields)
    partition_expr = partition_rule.get("expression", "{field}").format(field=partition_field)
    order_by = partition_field

    ch_sql = f"""CREATE TABLE {ch_table_name} (
    {new_fields_str}
) ENGINE = MergeTree()
PARTITION BY {partition_expr}
ORDER BY ({order_by});
"""
    return ch_sql


# ---------------------
# Program entry point: parse command-line arguments and run the conversion
# ---------------------
def main():
    """Read MySQL DDL from ``--input``, convert it and write it to ``--output``."""
    parser = argparse.ArgumentParser(description="将MySQL 建表语句转换为 ClickHouse建表语句")
    parser.add_argument("--config", type=str, default="rules.json", help="转换规则配置文件路径")
    parser.add_argument("--input", type=str, required=True, help="包含 MySQL 建表语句的 SQL 文件路径")
    parser.add_argument("--output", type=str, default="clickhouse.sql", help="输出 ClickHouse 建表语句的文件路径")
    args = parser.parse_args()

    # Load the rules once, from the path given on the command line.
    config = load_config(args.config)
    with open(args.input, mode="r", encoding="utf-8") as f:
        mysql_sql = f.read()
    ch_sql = generate_clickhouse_create_table(mysql_sql, config)
    with open(args.output, mode="w", encoding="utf-8") as f:
        f.write(ch_sql)
    print("ClickHouse 建表语句已生成,文件位置为:", args.output)


if __name__ == '__main__':
    main()

测试用例:**test_converter.py**

import unittest
from main_converter import generate_clickhouse_create_table

class TestMySQLToClickHouseConversion(unittest.TestCase):
    """End-to-end check of generate_clickhouse_create_table on a sample orders table."""

    def setUp(self):
        # The rules are kept in line with the contents of rules.json.
        type_mapping = {
            "INT": "UInt32",
            "BIGINT": "UInt64",
            "VARCHAR": "String",
            "CHAR": "String",
            "DATETIME": "DateTime",
            "DATE": "Date",
            "DECIMAL": "Decimal",
        }
        ignore_keywords = [
            "AUTO_INCREMENT",
            "PRIMARY KEY",
            "KEY",
            "CONSTRAINT",
            "FOREIGN KEY",
            "CHECK",
        ]
        self.config = {
            "type_mapping": type_mapping,
            "default_value_replacements": {"CURRENT_TIMESTAMP": "now()"},
            "ignore_keywords": ignore_keywords,
            "partition_rule": {"match": "DateTime", "expression": "toYYYYMM({field})"},
            "order_rule": {"default_field": "order_date"},
        }
        self.mysql_sql = """CREATE TABLE orders ( order_id INT
AUTO_INCREMENT PRIMARY KEY, user_id INT NOT NULL, region VARCHAR(50) COMMENT '区域', order_date DATETIME DEFAULT
CURRENT_TIMESTAMP, amount DECIMAL(10,2), KEY idx_user (user_id), CONSTRAINT fk_user FOREIGN KEY (user_id) REFERENCES users(id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8;"""

    def test_conversion(self):
        ch_sql = generate_clickhouse_create_table(self.mysql_sql, self.config)
        expected_fragments = (
            # MergeTree engine with partition and ordering clauses.
            "ENGINE = MergeTree()",
            "PARTITION BY toYYYYMM(order_date)",
            "ORDER BY (order_date)",
            # Type mapping: order_id becomes UInt32.
            "order_id UInt32",
            # Default replacement: CURRENT_TIMESTAMP becomes now().
            # NOTE(review): MySQLSyntaxCleaner.apply removes every
            # "DEFAULT ..." clause and generate_clickhouse_create_table
            # registers it last, so this expectation looks unsatisfiable —
            # confirm which side is intended.
            "DEFAULT now()",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, ch_sql)
        # Ignored keywords must not survive in the column definitions.
        for forbidden in ("AUTO_INCREMENT", "PRIMARY KEY"):
            self.assertNotIn(forbidden, ch_sql)

# Run the suite only when this file is executed as a script; an unguarded
# unittest.main() would also run (and exit the interpreter) whenever the
# module is merely imported, e.g. by another test runner.
if __name__ == "__main__":
    unittest.main()

配置文件:**rules.json**

{
    "type_mapping": {
        "INT": "UInt32",
        "BIGINT": "UInt64",
        "VARCHAR": "String",
        "CHAR": "String",
        "DATETIME": "DateTime",
        "DATE": "Date",
        "DECIMAL": "Decimal"
    },
    "default_value_replacements": {
        "CURRENT_TIMESTAMP": "now()"
    },
    "ignore_keywords": [
        "AUTO_INCREMENT",
        "PRIMARY KEY",
        "KEY",
        "CONSTRAINT",
        "FOREIGN KEY",
        "CHECK"
    ],
    "partition_rule":{
        "match": "DateTime",
        "expression": "toYYYYMM({field})"
    },
    "order_rule": {
        "default_field": "order_date"
    }
}

网站公告

今日签到

点亮在社区的每一天
去签到