# Script: generate a Word document from MySQL database comments
# 从 MySQL 数据库注释生成 Word 文档的脚本

# -*- coding: utf-8 -*-
"""
将mysql数据库中表结构根据注释信息生成word文档
"""
import sys

if sys.version_info[0] < 3:
    # Python 2 only: make implicit str/unicode conversions of the non-ASCII
    # (Chinese) text in this script use UTF-8 instead of ASCII. site.py removes
    # sys.setdefaultencoding at startup, hence the reload. Neither the builtin
    # ``reload`` nor ``sys.setdefaultencoding`` exists on Python 3, where text
    # is Unicode by default, so calling them unconditionally raises NameError.
    reload(sys)  # noqa: F821
    sys.setdefaultencoding('utf8')

import json
try:
    import pymysql
    from docx import Document
    from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
    from docx.shared import Pt
except ImportError:
    print('请先安装 pymysql python-docx 库!\npip install pymysql python-docx')
    exit(1)



class Table:
    """A database table: its name, its (display) comment and its columns.

    ``columns`` holds the table's column descriptions; the script fills it
    in after construction.
    """

    name = ''
    comment = ''

    def __init__(self, name, comment, columns=None):
        """Create a table description.

        :param name: table name.
        :param comment: table comment text shown next to the name.
        :param columns: optional initial iterable of column descriptions;
            defaults to a new empty list.
        """
        self.name = name
        self.comment = comment
        # Per-instance list: a class-level ``columns = []`` would be one list
        # shared by every Table that never had ``columns`` reassigned.
        self.columns = [] if columns is None else list(columns)

    def __str__(self):
        # ``%s`` formatting so a non-str name/comment does not raise TypeError.
        return "%s\t%s" % (self.name, self.comment)


class Column:
    """One table column as read from INFORMATION_SCHEMA.COLUMNS.

    ``key`` is the short key marker computed by the caller (e.g. ``'PK'``,
    ``'FK'`` or ``''``).
    """

    name = ''
    type = ''
    allow_null = ''
    default_value = ''
    comment = ''
    key = ''

    # NOTE: the ``type`` parameter shadows the builtin; it is kept because it
    # is part of the public constructor signature.
    def __init__(self, name, type, allow_null, default_value, comment, key):
        self.name = name
        self.type = type
        self.allow_null = allow_null
        self.default_value = default_value
        self.comment = comment
        self.key = key

    def __str__(self):
        # COLUMN_DEFAULT is NULL (-> None) for many columns; render None as an
        # empty field instead of raising TypeError on string concatenation.
        fields = (self.name, self.type, self.allow_null,
                  self.default_value, self.comment, self.key)
        return "\t".join("" if f is None else "%s" % f for f in fields)


class Mysql2docx(object):
    """Generate a Word (.docx) design document from a MySQL schema.

    Table and column comments are read from ``information_schema`` and
    rendered as one heading plus one 5-column table per base table.
    """

    # Schema whose tables are documented; set by ``do``.
    db_name = ''

    def __init__(self):
        # Schema name read by the query methods; ``do`` overwrites it.
        self.db_name = ''

    @staticmethod
    def _query(db, sql, args=None):
        """Execute *sql* with bound *args* on a new cursor and return all rows.

        The cursor is closed even if the query raises.
        """
        cursor = db.cursor()
        try:
            cursor.execute(sql, args)
            return cursor.fetchall()
        finally:
            cursor.close()

    def get_comment(self, comment):
        """Return the display text for a table/column comment.

        A comment may be plain text, or a JSON array whose first element has a
        ``value`` key (e.g. ``[{"value": "users table"}]``), in which case
        that value is returned. ``None`` becomes ``""``; anything that does
        not match the JSON shape is returned unchanged.
        """
        if comment is None:
            return ""
        try:
            value = json.loads(comment)[0]['value']
        except (ValueError, TypeError, KeyError, IndexError):
            # Not JSON, or JSON of a different shape: use the raw comment.
            return comment
        # The result is written into docx cell text, so always return a
        # string even when the JSON value is a number or null.
        return "" if value is None else "%s" % value

    def get_tables(self, db):
        """Return a ``Table`` for every base table in ``self.db_name``.

        Tables are ordered by name; their ``columns`` are filled separately
        (see ``get_columns``).
        """
        sql = ("select table_name, TABLE_COMMENT from information_schema.tables "
               "where table_schema = %s and table_type = 'BASE TABLE' "
               "order by table_name")
        rows = self._query(db, sql, (self.db_name,))
        return [Table(row[0], self.get_comment(row[1])) for row in rows]

    def get_columns(self, db, table_name):
        """Return the ``Column`` objects of *table_name* in definition order.

        ``key`` is ``'PK'`` for primary-key columns, otherwise ``'FK'`` for
        columns that reference another table, otherwise ``''``.
        """
        sql = ("SELECT "
               "COLUMN_NAME 列名, "
               "COLUMN_TYPE 数据类型, "
               "IS_NULLABLE 是否为空, "
               "COLUMN_DEFAULT 默认值, "
               "COLUMN_COMMENT 备注, "
               "COLUMN_KEY 键值 "
               "FROM INFORMATION_SCHEMA.COLUMNS "
               "WHERE table_schema = %s AND table_name = %s "
               # Without ORDER BY the row order is not guaranteed.
               "ORDER BY ORDINAL_POSITION")
        rows = self._query(db, sql, (self.db_name, table_name))
        fks = set(self.get_fks(db, table_name))
        columns = []
        for name, col_type, nullable, default, comment, col_key in rows:
            if col_key == 'PRI':
                key_type = 'PK'
            elif name in fks:
                key_type = 'FK'
            else:
                key_type = ''
            columns.append(Column(name, col_type, nullable, default,
                                  self.get_comment(comment), key_type))
        return columns

    def get_fks(self, db, table_name):
        """Return the names of columns of *table_name* that reference another
        table (rows of KEY_COLUMN_USAGE with a non-NULL REFERENCED_TABLE_NAME).
        """
        sql = ("SELECT TABLE_NAME, COLUMN_NAME, CONSTRAINT_NAME, "
               "REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME "
               "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
               "WHERE CONSTRAINT_SCHEMA = %s AND TABLE_NAME = %s "
               "AND REFERENCED_TABLE_NAME IS NOT NULL")
        rows = self._query(db, sql, (self.db_name, table_name))
        return [row[1] for row in rows]

    @staticmethod
    def _add_title(document):
        """Add the centred 24pt document title."""
        p = document.add_paragraph()
        run = p.add_run(u'数据库设计文档')
        run.font.size = Pt(24)
        p.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER

    def _add_table(self, document, table):
        """Add a heading and a 5-column description table for *table*."""
        print("table:%s" % table)
        document.add_heading("%s %s" % (table.name, table.comment), 2)
        print("len(table.columns) %d" % len(table.columns))
        grid = document.add_table(rows=len(table.columns) + 1,
                                  cols=5,
                                  style="Light List Accent 5")
        header = grid.rows[0].cells
        header[0].text = u'字段名'
        header[1].text = u'字段类型'
        header[2].text = u'是否为空'
        header[3].text = u'键引用'
        header[4].text = u'注释'
        # Row 0 is the header, so data rows start at index 1.
        for index, column in enumerate(table.columns, 1):
            cells = grid.rows[index].cells
            cells[0].text = column.name
            cells[1].text = column.type
            cells[2].text = column.allow_null
            cells[4].text = column.comment
            if column.key:
                cells[3].text = column.key

    def do(self, db_host, db_user, db_password, db_name, db_port, doc='数据库设计文档.docx'):
        """Read the schema *db_name* and write its design document to *doc*.

        :param db_host: MySQL host.
        :param db_user: MySQL user.
        :param db_password: MySQL password (never printed).
        :param db_name: schema to document.
        :param db_port: MySQL port (int).
        :param doc: output .docx path.
        """
        # Mask the password so credentials do not end up in console/CI logs.
        print("dbHost:%s,dbUser:%s,dbPassword:%s,db_name:%s,dbPort:%d" %
              (db_host, db_user, '******', db_name, db_port))
        self.db_name = db_name
        # Keyword arguments: PyMySQL >= 1.0 rejects positional parameters.
        db = pymysql.connect(host=db_host,
                             user=db_user,
                             password=db_password,
                             database=db_name,
                             port=db_port,
                             charset="utf8")
        try:
            tables = self.get_tables(db)
            for table in tables:
                table.columns = self.get_columns(db, table.name)
        finally:
            db.close()

        document = Document()
        self._add_title(document)
        for table in tables:
            self._add_table(document, table)
            document.add_page_break()
        document.save(doc)

if __name__ == '__main__':
    # Example invocation with hard-coded connection settings;
    # adjust host/user/password/schema/port before running.
    m = Mysql2docx()
    m.do(db_host='127.0.0.1',
         db_user='root',
         db_password='root',
         db_name='portal',
         db_port=3306)