SQLite 数据库工具功能添加

发布于:2024-10-11 ⋅ 阅读:(10) ⋅ 点赞:(0)
#AUTHOR HongDaYu
from collections.abc import Callable
import sqlite3
from loguru import logger
import time
import os
import re

# Module docstring, assigned explicitly because a string literal placed after
# the imports would not be picked up as the module docstring automatically.
__doc__ = "\ncompatible: Linux or Windows\n"


class Sqlite3Database:
    __info__ = dict()
    __path_database__ = ""
    __database_self_diff__ = dict()
    __database_other_diff__ = dict()

    def __init__(self, database: str = ""):
        try:
            self.__path_database__ = database
            self.__connect__ = sqlite3.connect(database)
            self.__cur__ = self.__connect__.cursor()
            logger.info("connect {} successful", database)
            self.__cur__.execute('SELECT name from sqlite_master where type="table"')
            TableName = self.__cur__.fetchall()
            __TableName_ = list(map(lambda x: list(x)[0], TableName))
            self.__cur__.execute(
                'SELECT sql from sqlite_master where type="table" and name in ({0})'.format(
                    ", ".join("?" for _ in __TableName_)
                ),
                __TableName_,
            )
            TableSql = self.__cur__.fetchall()
            __TableSql_ = list(map(lambda x: list(x)[0], TableSql))
            self.__info__ = dict(
                map(lambda key, value: [key, value], __TableName_, __TableSql_)
            )

        except sqlite3.OperationalError as err:
            logger.error("OperationalError {} failed", err.args)
        except sqlite3.ProgrammingError as err:
            logger.error("ProgrammingError {} failed", err.args)

    # TODO: adapt windows str schema info
    def overrideSchemaTableInfo(self):
        self.__info__ = {}
        self.__cur__.execute('SELECT name from sqlite_master where type="table"')
        TableName = self.__cur__.fetchall()
        __TableName_ = list(map(lambda x: list(x)[0], TableName))
        self.__cur__.execute(
            'SELECT sql from sqlite_master where type="table" and name in ({0})'.format(
                ", ".join("?" for _ in __TableName_)
            ),
            __TableName_,
        )
        TableSql = self.__cur__.fetchall()
        __TableSql_ = list(map(lambda x: list(x)[0], TableSql))
        self.__info__ = dict(
            map(lambda key, value: [key, value], __TableName_, __TableSql_)
        )

    def addData(self, table: str, rows: list) -> tuple:
        try:
            res = self.__info__.get(table)
            if res is None:
                logger.error(
                    "can't find {} in {}".format(table, self.__path_database__)
                )
                return sqlite3.SQLITE_ERROR, None
            STRIP = "".join(list(filter(lambda x: x != "\n", res)))
            schema = re.findall(r"[(](.*?)[)]", STRIP)
            if len(schema) == 0:
                logger.error("can't match table schema ")
                return sqlite3.SQLITE_ERROR, None
            else:
                schema = schema[0]
            schema = schema.split(",")
            __res__ = filter(lambda x: len(schema) == len(x), rows)
            __failed_res__ = []
            __flag__ = False
            for it in rows:
                if len(schema) != len(it):
                    logger.error("This add data not Match schema")
                    logger.error("SCHEMA {} ROWS {}", len(schema), len(it))
                    logger.error(it)
                    __failed_res__.append(it)
                    __flag__ = True

            self.__cur__.executemany(
                "INSERT INTO "
                + table
                + " VALUES({})".format(",".join("?" for _ in schema)),
                __res__,
            )

        except (
            sqlite3.OperationalError,
            sqlite3.IntegrityError,
            sqlite3.InterfaceError,
            sqlite3.ProgrammingError,
            sqlite3.InternalError,
        ) as err:
            logger.error("addData {}", err.args[0])
            return sqlite3.SQLITE_ERROR, None
        return (
            (sqlite3.SQLITE_OK, None)
            if __flag__ is False
            else (sqlite3.SQLITE_ERROR, __failed_res__)
        )

    def deleteData(self, table: str, index: list):
        try:
            self.__cur__.execute(
                "DELETE FROM "
                + table
                + " WHERE ID in ({}) ".format(",".join(i for i in index))
            )
        except sqlite3.OperationalError as err:
            logger.error("delete {} data {}", table, err.args)

    def deleteDataCond(self, table, condition: str):
        try:
            self.__cur__.execute("DELETE FROM " + table + " " + condition)
        except sqlite3.OperationalError as err:
            logger.error("delete {} data {}", table, err.args)

    def updateData(self, table: str, updateValue: list, condition: str):
        try:
            self.__cur__.execute(
                "UPDATE "
                + table
                + " "
                + "SET {} {}".format(",".join(i for i in updateValue), condition)
            )
        except (sqlite3.OperationalError, sqlite3.IntegrityError) as err:
            logger.error("update {} data {}", table, err.args[0],f"err: {updateValue} condition: {condition}")

    def readData(self, table: str, cols: str = "*", condition: str = "") -> tuple:
        try:
            self.__cur__.execute("SELECT " + cols + " FROM " + table + " " + condition)
            return sqlite3.SQLITE_OK, self.__cur__.fetchall()
        except sqlite3.OperationalError as err:
            logger.error("readData {}", err.args)
        return sqlite3.SQLITE_ERROR, []

    def rawExec(self, sql: str, params: tuple = ()):
        try:
            self.__cur__.execute(sql, params)
        except sqlite3.OperationalError as err:
            logger.error("OperationalError {}", err.args)

    def backup(
        self,
        name: str = time.strftime("%Y%m%d-%H%M%S", time.localtime()),
        progress: Callable[[int, int, int], object] | None = None,
    ) -> str:
        newDatabase = name + "-" + os.path.basename(self.__path_database__)
        if not os.path.exists(newDatabase):
            with open(newDatabase, "w+") as fd:
                fd.close()
        __data_base = sqlite3.connect(newDatabase)
        self.__connect__.backup(__data_base, pages=1, progress=progress)
        return newDatabase

    # TODO: adapt windows file \r\n
    def sqlScripts(self, sqlScriptsPath: str):
        try:
            with open(sqlScriptsPath, "r") as fd:
                self.__connect__.executescript(
                    "".join(filter(lambda x: x != "\n", fd.read()))
                )
        except sqlite3.OperationalError as err:
            logger.error("OperationalError {}", err.args)
        except OSError as err:
            logger.error("openSqlScriptsPath err {}", err.args)

    def handle(self) -> sqlite3.Connection:
        return self.__connect__

    def write2database(self):
        self.__connect__.commit()

    def rollBack(self):
        self.__connect__.rollback()

    def tableInfo(self) -> dict:
        return self.__info__

    # Look all table data
    def __eq__(self, other) -> bool:
        __flags__ = True
        if isinstance(other, Sqlite3Database):
            if self.tableInfo() != other.tableInfo():
                return False
            for it in self.__info__:
                V1 = tuple(self.readData(it)[1])
                V2 = tuple(other.readData(it)[1])
                if V1 != V2:
                    self.__database_self_diff__[it] = tuple(set(V1) - set(V2))
                    self.__database_other_diff__[it] = tuple(set(V2) - set(V1))
                    __flags__ = False
            return __flags__
        return False

    # set 1 - set 2
    def __sub__(self, other:'Sqlite3Database') -> dict:
        if self.tableInfo() != other.tableInfo():
            return {}
        for it in self.__info__:
            V1 = tuple(self.readData(it)[1])
            V2 = tuple(other.readData(it)[1])
            if V1 != V2:
                self.__database_self_diff__[it] = tuple(set(V1) - set(V2))
                self.__database_other_diff__[it] = tuple(set(V2) - set(V1))
        return self.__database_self_diff__

    # colName col Value Text
    def patch_database(self,updateColNameType:str,updateColNumber:int,selIDColName:str,selIDColNameIndex:int,
    other:'Sqlite3Database') -> bool:
        if self.tableInfo() != other.tableInfo():
            return False
        for it in self.__info__:
            data = tuple(self.readData(it)[1])
            for __it in data:
                other.updateData(it,[f"{updateColNameType}=\"{__it[updateColNumber]}\"",],f"WHERE {selIDColName}"
                                                                          f"=\"{__it[selIDColNameIndex]}\"")
        other.write2database()
        return True

    def getDiffSelf(self) -> dict:
        return self.__database_self_diff__

    def getDiffOther(self) -> dict:
        return self.__database_other_diff__
from db import Sqlite3Database


# Open the source database and the database to be patched.
db = Sqlite3Database('./drurmu.db')
db_n = Sqlite3Database('./latest_db/drurmu.db')

# Copy column "V0" (row index 6) of every table from ./drurmu.db into
# ./latest_db/drurmu.db, matching rows on "ID" (row index 0).
print(
    "Database patched"
    if db.patch_database(
        updateColNameType="V0",
        updateColNumber=6,
        selIDColName="ID",
        selIDColNameIndex=0,
        other=db_n,
    )
    else "Database not patched"
)