GO学习记录五——数据库表的增删改查

发布于:2025-08-15 ⋅ 阅读:(11) ⋅ 点赞:(0)

一、通过http接口实现简单的增删改查

main.go

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"regexp"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

var db *pgxpool.Pool

// 启动函数
func main() {
	// 初始化数据库连接
	db = InitDB()

	// 注册路由
	RegisterRouter()
	// 启动 HTTP 服务
	StartHTTPSevrver()

}

// 启动 HTTP 服务
func StartHTTPSevrver() {
	address := "127.0.0.1:8080" //配置连接ip端口
	log.Printf("启动 HTTP 服务,监听端口 %s\n", address)
	err := http.ListenAndServe(address, nil)
	if err != nil {
		log.Fatalf("服务器启动失败:%v", err)
	}
}

// 注册路由
func RegisterRouter() {
	http.HandleFunc("/", helloHandler)    //http://localhost:8080/
	http.HandleFunc("/time", timeHandler) //http://localhost:8080/time
	//查询
	http.HandleFunc("/findTable", findTableNameHandler) //http://localhost:8080/findTable?tableName=name
	//添加
	http.HandleFunc("/addTable1", addTable1Handler) //http://localhost:8080/addTable1
	//删除
	http.HandleFunc("/deleteTableValue", deleteTableHandler) //http://localhost:8080/deleteTableValue?tableName=table1&fieldName=test1&fieldValue=123test
	//修改
	http.HandleFunc("/updateTableValue", updateTableHandler) //http://localhost:8080/updateTableValue?tableName=table1&findFieldName=test1&findFieldValue=hello&setFieldName=test3&setFieldValue=456

}

// 根路径处理器
func helloHandler(w http.ResponseWriter, r *http.Request) {
	log.Printf("访问路径:%s,来源:%s\n", r.URL.Path, r.RemoteAddr)
	fmt.Fprintf(w, "Hello, World! 👋")
}

// /time 路径处理器
func timeHandler(w http.ResponseWriter, r *http.Request) {
	log.Printf("访问路径:%s,来源:%s\n", r.URL.Path, r.RemoteAddr)
	currentTime := time.Now().Format("2006-01-02 15:04:05")
	fmt.Fprintf(w, "当前服务器时间:%s", currentTime)
}

// 修改指定表名中,find字段名等于指定值的set字段名的数据
func updateTableHandler(w http.ResponseWriter, r *http.Request) {
	// 解析请求参数
	tableName := r.URL.Query().Get("tableName")
	findFieldName := r.URL.Query().Get("findFieldName")
	findFieldValue := r.URL.Query().Get("findFieldValue")
	setFieldName := r.URL.Query().Get("setFieldName")
	setFieldValue := r.URL.Query().Get("setFieldValue")

	// 完整的参数验证
	if tableName == "" || findFieldName == "" || setFieldName == "" {
		http.Error(w, "缺少必要参数", http.StatusBadRequest)
		return
	}

	// 🔐 白名单验证 - 只允许预定义的表和字段
	allowedTables := map[string]bool{"table1": true, "table2": true}
	allowedFields := map[string]bool{
		"test1": true, "test2": true, "test3": true,
		"test4": true, "test5": true, "test6": true, "test7": true,
	}

	if !allowedTables[tableName] {
		http.Error(w, "不允许的表名", http.StatusBadRequest)
		return
	}
	if !allowedFields[findFieldName] || !allowedFields[setFieldName] {
		http.Error(w, "不允许的字段名", http.StatusBadRequest)
		return
	}

	// ✅ 使用参数化查询,表名和字段名通过白名单验证后拼接
	query := fmt.Sprintf(
		"UPDATE %s SET %s = $1 WHERE %s = $2",
		tableName, setFieldName, findFieldName,
	)

	result, err := db.Exec(context.Background(), query, setFieldValue, findFieldValue)
	if err != nil {
		http.Error(w, "更新数据失败: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// 检查是否实际更新了数据
	rowsAffected := result.RowsAffected()
	if rowsAffected == 0 {
		http.Error(w, "未找到匹配的数据进行更新", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(map[string]string{
		"message": "success",
		"updated": fmt.Sprintf("%d 行已更新", rowsAffected),
	})

}

// 删除指定表名中,指定字段名等于指定值的数据
func deleteTableHandler(w http.ResponseWriter, r *http.Request) {
	// 解析请求参数
	tableName := r.URL.Query().Get("tableName")
	fieldName := r.URL.Query().Get("fieldName")
	fieldValue := r.URL.Query().Get("fieldValue")
	if tableName == "" || fieldName == "" || fieldValue == "" {
		http.Error(w, "参数错误", http.StatusBadRequest)
		return
	}
	// 执行 SQL 语句,使用参数化查询
	query := fmt.Sprintf("DELETE FROM %s WHERE %s = $1", tableName, fieldName)
	_, err := db.Exec(context.Background(), query, fieldValue)
	if err != nil {
		http.Error(w, "删除数据失败: "+err.Error(), http.StatusInternalServerError)
		return
	}
	// 设置响应头
	w.Header().Set("Content-Type", "application/json")
	// 返回 JSON
	json.NewEncoder(w).Encode(map[string]string{"message": "success"})
}

// 向table1表中添加数据,字段名=test1,test2,test3,test4,test5,test6,test7
func addTable1Handler(w http.ResponseWriter, r *http.Request) {
	// 定义需要插入的数据结构
	type requestData struct {
		Test1 string    `json:"test1"`
		Test2 time.Time `json:"test2"`
		Test3 uint32    `json:"test3"`
		Test4 string    `json:"test4"`
		Test5 float64   `json:"test5"`
		Test6 int32     `json:"test6"`
		Test7 float64   `json:"test7"`
	}

	// 解析请求参数
	var data requestData
	err := json.NewDecoder(r.Body).Decode(&data)
	if err != nil {
		http.Error(w, "解析请求参数失败: "+err.Error(), http.StatusBadRequest)
		return
	}

	// 执行 SQL 语句,使用参数化查询
	query := "INSERT INTO table1 (test1, test2, test3, test4, test5, test6, test7) VALUES ($1, $2, $3, $4, $5, $6, $7)"
	_, err = db.Exec(context.Background(), query, data.Test1, data.Test2, data.Test3, data.Test4, data.Test5, data.Test6, data.Test7)
	if err != nil {
		http.Error(w, "插入数据失败: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// 设置响应头
	w.Header().Set("Content-Type", "application/json")

	// 返回 JSON
	json.NewEncoder(w).Encode(map[string]string{"message": "success"})
}

// 查下指定表名的全部数据
func findTableNameHandler(w http.ResponseWriter, r *http.Request) {
	tableName := r.URL.Query().Get("tableName")
	if tableName == "" {
		http.Error(w, "tableName is empty", http.StatusBadRequest)
		return
	}

	// ✅ 安全校验表名(防止 SQL 注入)
	if !isValidTableName(tableName) {
		http.Error(w, "invalid table name", http.StatusBadRequest)
		return
	}

	// ✅ 使用参数化方式拼接表名(仅限对象名,如表、字段)
	query := fmt.Sprintf("SELECT * FROM %s", tableName)

	rows, err := db.Query(context.Background(), query)
	if err != nil {
		http.Error(w, "查询失败: "+err.Error(), http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// ✅ 使用 pgx 内置工具自动转为 []map[string]interface{}
	data, err := pgx.CollectRows(rows, pgx.RowToMap)
	if err != nil {
		http.Error(w, "解析数据失败: "+err.Error(), http.StatusInternalServerError)
		return
	}

	// ✅ 设置响应头
	w.Header().Set("Content-Type", "application/json")

	// ✅ 返回 JSON
	json.NewEncoder(w).Encode(data)
}

// 安全校验表名(防止 SQL 注入)
func isValidTableName(name string) bool {
	// 只允许字母、数字、下划线,且不能以数字开头
	matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, name)
	return matched
}

tools.go

package main

//引用的包
import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"strconv"
	"strings"

	"github.com/jackc/pgx/v5/pgxpool" //pgsql数据库组件
	"github.com/xuri/excelize/v2"     //解析excel文件包
)

// 定义数据库相关配置
const (
	host     = "localhost"        //数据库ip
	port     = 5432               //数据库端口
	user     = "postgres"         //数据库用户名
	password = "postgres"         //数据库密码
	dbname   = "postgresLearning" //数据库名
)

// 初始化数据库连接
func InitDB() *pgxpool.Pool {
	// 构建连接字符串
	psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
		host, port, user, password, dbname)

	// 连接数据库
	pool, err := pgxpool.New(context.Background(), psqlInfo)
	//db, err := sql.Open("postgres", psqlInfo)
	if err != nil {
		log.Fatal(err)
	}
	//defer pool.Close()  这里有一个注意点,这块代码回直接关闭数据库连接
	// 检查连接
	err = pool.Ping(context.Background())
	if err != nil {
		log.Fatal(err)
	}
	LogInfo("Successfully connected to PostgreSQL database!")
	return pool
}

// 读取Excel文件
func ReadExcel(path string, showLog bool) []Table {
	createTable := []Table{}
	// 打开Excel文件
	f, err := excelize.OpenFile(path)
	if err != nil {
		fmt.Println(err)
		return createTable
	}

	// 获取工作表名称列表
	sheetNames := f.GetSheetList()
	//遍历sheet列表
	for _, sheetName := range sheetNames {
		if showLog {
			LogInfo("开始处理%s工作表", sheetName)
		}
		itemTable := Table{
			Name: sheetName,
		}
		// 读取指定工作表的所有行
		rows, err := f.GetRows(sheetName)
		if err != nil {
			LogError("读取%s工作表失败,原因: %v", sheetName, err)
			continue
		}
		for _, row := range rows[1:] {
			itemColumns := Column{
				Name:    row[0],
				Type:    parseColumnType(row[1]),
				Length:  row[2],
				NotNull: parseBool(row[3], showLog),
				Unique:  parseBool(row[4], showLog),
				Primary: parseBool(row[5], showLog),
			}
			if len(row) > 6 {
				itemColumns.Default = row[6]
			}
			itemTable.Columns = append(itemTable.Columns, itemColumns)
			if showLog {
				// 遍历行中的单元格
				for _, colCell := range row {
					LogInfo("%s", colCell)
				}
			}
		}
		createTable = append(createTable, itemTable)
	}
	return createTable
}

// 创建数据库表
func createDBTable(db *sql.DB, table Table) {
	success, createTableSQL := CreateTable(table)
	if success {
		LogInfo("sql=%s", createTableSQL)
		_, err := db.Exec(createTableSQL)
		if err != nil {
			LogError("创建%s数据表失败,原因: %v", table.Name, err)
		} else {
			LogSuccess("创建%s数据表成功", table.Name)
		}
	} else {
		LogError("创建%s数据表失败,原因: %s", table.Name, createTableSQL)
	}
}

// ColumnItemType 定义列类型的自定义类型
type ColumnItemType string

// 支持的列类型常量
const (
	VARCHAR   ColumnItemType = "VARCHAR"
	TIMESTAMP ColumnItemType = "TIMESTAMP"
	SERIAL    ColumnItemType = "SERIAL"
	TEXT      ColumnItemType = "TEXT"
	DECIMAL   ColumnItemType = "DECIMAL"
	INT       ColumnItemType = "INT"
)

func parseColumnType(typeStr string) ColumnItemType {
	switch strings.ToUpper(typeStr) {
	case "VARCHAR":
		return VARCHAR
	case "TIMESTAMP":
		return TIMESTAMP
	case "SERIAL":
		return SERIAL
	case "TEXT":
		return TEXT
	case "DECIMAL":
		return DECIMAL
	case "INT":
		return INT
	// 处理其他可能的类型
	default:
		return ColumnItemType(typeStr) // 如果类型不在预定义的范围内,可以返回原字符串或默认值
	}
}

// 辅助函数,将字符串转换为整数,如果转换失败则打印错误信息并返回0
func mustAtoi(s string, showLog bool) int {
	i, err := strconv.Atoi(s)
	if err != nil {
		if showLog {
			LogWarning("无法解析int值,原因: %v", err)
		}
		return 0 // 或者你可以选择返回一个默认值,或者根据错误处理逻辑来决定
	}
	return i
}

// 添加一个函数来将字符串转换为布尔值
func parseBool(str string, showLog bool) bool {
	switch str {
	case "true", "1", "TRUE", "T", "Y", "YES":
		return true
	case "false", "0", "FALSE", "F", "N", "NO":
		return false
	default:
		// 你可以根据需要处理默认情况,比如记录日志或者返回一个默认值
		if showLog {
			LogWarning("无法解析布尔值,设置默认值=false, 原始值=%s", str)
		}
		return false
	}
}

// Column 定义字段结构
type Column struct {
	Name    string
	Type    ColumnItemType
	Length  string // 长度,仅对 VARCHAR、DECIMAL 等有效
	NotNull bool
	Unique  bool
	Primary bool
	Default string
}

// Table 定义表结构
type Table struct {
	Name    string
	Columns []Column
}

// CreateTable 生成 CREATE TABLE SQL 语句
func CreateTable(table Table) (bool, string) {
	if table.Name == "" {
		LogError("表名不能为空")
		return false, ""
	}
	if len(table.Columns) == 0 {
		LogError("字段列表不能为空")
		return false, ""
	}

	var fieldDefs []string

	for _, col := range table.Columns {
		if col.Name == "" {
			LogError("字段名不能为空")
			return false, ""
		}

		def := col.Name + " " + string(col.Type) // 注意:col.Type 是 ColumnItemType,需转为 string

		// 处理长度(仅对支持长度的类型)
		if len(col.Length) > 0 && (col.Type == VARCHAR || col.Type == DECIMAL) {
			// 可以根据 Type 判断是否支持 Length,例如只对 VARCHAR 和 DECIMAL 生效
			def += "(" + col.Length + ")"
		}

		// 添加约束
		if col.NotNull {
			def += " NOT NULL"
		}
		if col.Unique {
			def += " UNIQUE"
		}
		if col.Primary {
			def += " PRIMARY KEY"
		}
		if col.Default != "" {
			// 判断是否需要为 DEFAULT 值加引号
			if col.Type == TEXT || col.Type == VARCHAR {
				def += fmt.Sprintf(" DEFAULT '%s'", EscapeString(col.Default))
			} else {
				def += " DEFAULT " + col.Default
			}
		}
		fieldDefs = append(fieldDefs, def)
	}

	// 拼接完整 SQL
	sql := fmt.Sprintf(
		"CREATE TABLE IF NOT EXISTS %s (%s);",
		table.Name,
		strings.Join(fieldDefs, ", "),
	)

	return true, sql
}

// EscapeString 是一个假设的函数,用于转义SQL字符串中的特殊字符
func EscapeString(s string) string {
	// 实现对字符串s中单引号等特殊字符的转义
	return strings.ReplaceAll(s, "'", "''") // 示例:转义单引号
}

// 辅助函数:判断 DEFAULT 是否需要加引号
func isStringDefault(defaultValue string) bool {
	// 尝试将 defaultValue 转换为数字或 NULL
	_, err1 := strconv.ParseFloat(defaultValue, 64)
	_, err2 := strconv.ParseBool(defaultValue)

	// 如果 defaultValue 可以转换为数字或布尔值,或者它是 "NULL",则不需要加引号
	return !(err1 == nil || err2 == nil || strings.ToUpper(defaultValue) == "NULL")
}

// ==================================封装打印log========================================
const (
	Red    = "31"
	Green  = "32"
	Yellow = "33"
	Blue   = "34"
	Purple = "35"
	Cyan   = "36"
	White  = "37"
)

// PrintColor 打印指定颜色的文本
// colorCode: ANSI 颜色码
// format: 格式化字符串,如 "创建%s表成功"
// args: 格式化参数
func LogColor(colorCode string, format string, args ...interface{}) {
	// \033[颜色码m + 文本 + \033[0m(重置)
	colored := fmt.Sprintf("\033[%sm%s\033[0m", colorCode, fmt.Sprintf(format, args...))
	fmt.Println(colored)
}

func LogError(format string, args ...interface{}) {
	LogColor(Red, format, args...)
}
func LogInfo(format string, args ...interface{}) {
	LogColor(White, format, args...)
}
func LogWarning(format string, args ...interface{}) {
	LogColor(Yellow, format, args...)
}
func LogSuccess(format string, args ...interface{}) {
	LogColor(Green, format, args...)
}

//==================================封装打印log END========================================

二、遇到的问题
试着把数据库连接和打印日志的方法提取出来了,使用多文件调用的方式,启动程序的命令就改变了
从之前的 go run main.go 变为了 go run .


网站公告

今日签到

点亮在社区的每一天
去签到