一、通过http接口实现简单的增删改查
main.go
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"regexp"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
var db *pgxpool.Pool
// 启动函数
func main() {
// 初始化数据库连接
db = InitDB()
// 注册路由
RegisterRouter()
// 启动 HTTP 服务
StartHTTPSevrver()
}
// 启动 HTTP 服务
func StartHTTPSevrver() {
address := "127.0.0.1:8080" //配置连接ip端口
log.Printf("启动 HTTP 服务,监听端口 %s\n", address)
err := http.ListenAndServe(address, nil)
if err != nil {
log.Fatalf("服务器启动失败:%v", err)
}
}
// 注册路由
func RegisterRouter() {
http.HandleFunc("/", helloHandler) //http://localhost:8080/
http.HandleFunc("/time", timeHandler) //http://localhost:8080/time
//查询
http.HandleFunc("/findTable", findTableNameHandler) //http://localhost:8080/findTable?tableName=name
//添加
http.HandleFunc("/addTable1", addTable1Handler) //http://localhost:8080/addTable1
//删除
http.HandleFunc("/deleteTableValue", deleteTableHandler) //http://localhost:8080/deleteTableValue?tableName=table1&fieldName=test1&fieldValue=123test
//修改
http.HandleFunc("/updateTableValue", updateTableHandler) //http://localhost:8080/updateTableValue?tableName=table1&findFieldName=test1&findFieldValue=hello&setFieldName=test3&setFieldValue=456
}
// 根路径处理器
func helloHandler(w http.ResponseWriter, r *http.Request) {
log.Printf("访问路径:%s,来源:%s\n", r.URL.Path, r.RemoteAddr)
fmt.Fprintf(w, "Hello, World! 👋")
}
// /time 路径处理器
func timeHandler(w http.ResponseWriter, r *http.Request) {
log.Printf("访问路径:%s,来源:%s\n", r.URL.Path, r.RemoteAddr)
currentTime := time.Now().Format("2006-01-02 15:04:05")
fmt.Fprintf(w, "当前服务器时间:%s", currentTime)
}
// 修改指定表名中,find字段名等于指定值的set字段名的数据
func updateTableHandler(w http.ResponseWriter, r *http.Request) {
// 解析请求参数
tableName := r.URL.Query().Get("tableName")
findFieldName := r.URL.Query().Get("findFieldName")
findFieldValue := r.URL.Query().Get("findFieldValue")
setFieldName := r.URL.Query().Get("setFieldName")
setFieldValue := r.URL.Query().Get("setFieldValue")
// 完整的参数验证
if tableName == "" || findFieldName == "" || setFieldName == "" {
http.Error(w, "缺少必要参数", http.StatusBadRequest)
return
}
// 🔐 白名单验证 - 只允许预定义的表和字段
allowedTables := map[string]bool{"table1": true, "table2": true}
allowedFields := map[string]bool{
"test1": true, "test2": true, "test3": true,
"test4": true, "test5": true, "test6": true, "test7": true,
}
if !allowedTables[tableName] {
http.Error(w, "不允许的表名", http.StatusBadRequest)
return
}
if !allowedFields[findFieldName] || !allowedFields[setFieldName] {
http.Error(w, "不允许的字段名", http.StatusBadRequest)
return
}
// ✅ 使用参数化查询,表名和字段名通过白名单验证后拼接
query := fmt.Sprintf(
"UPDATE %s SET %s = $1 WHERE %s = $2",
tableName, setFieldName, findFieldName,
)
result, err := db.Exec(context.Background(), query, setFieldValue, findFieldValue)
if err != nil {
http.Error(w, "更新数据失败: "+err.Error(), http.StatusInternalServerError)
return
}
// 检查是否实际更新了数据
rowsAffected := result.RowsAffected()
if rowsAffected == 0 {
http.Error(w, "未找到匹配的数据进行更新", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"message": "success",
"updated": fmt.Sprintf("%d 行已更新", rowsAffected),
})
}
// 删除指定表名中,指定字段名等于指定值的数据
func deleteTableHandler(w http.ResponseWriter, r *http.Request) {
// 解析请求参数
tableName := r.URL.Query().Get("tableName")
fieldName := r.URL.Query().Get("fieldName")
fieldValue := r.URL.Query().Get("fieldValue")
if tableName == "" || fieldName == "" || fieldValue == "" {
http.Error(w, "参数错误", http.StatusBadRequest)
return
}
// 执行 SQL 语句,使用参数化查询
query := fmt.Sprintf("DELETE FROM %s WHERE %s = $1", tableName, fieldName)
_, err := db.Exec(context.Background(), query, fieldValue)
if err != nil {
http.Error(w, "删除数据失败: "+err.Error(), http.StatusInternalServerError)
return
}
// 设置响应头
w.Header().Set("Content-Type", "application/json")
// 返回 JSON
json.NewEncoder(w).Encode(map[string]string{"message": "success"})
}
// 向table1表中添加数据,字段名=test1,test2,test3,test4,test5,test6,test7
func addTable1Handler(w http.ResponseWriter, r *http.Request) {
// 定义需要插入的数据结构
type requestData struct {
Test1 string `json:"test1"`
Test2 time.Time `json:"test2"`
Test3 uint32 `json:"test3"`
Test4 string `json:"test4"`
Test5 float64 `json:"test5"`
Test6 int32 `json:"test6"`
Test7 float64 `json:"test7"`
}
// 解析请求参数
var data requestData
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
http.Error(w, "解析请求参数失败: "+err.Error(), http.StatusBadRequest)
return
}
// 执行 SQL 语句,使用参数化查询
query := "INSERT INTO table1 (test1, test2, test3, test4, test5, test6, test7) VALUES ($1, $2, $3, $4, $5, $6, $7)"
_, err = db.Exec(context.Background(), query, data.Test1, data.Test2, data.Test3, data.Test4, data.Test5, data.Test6, data.Test7)
if err != nil {
http.Error(w, "插入数据失败: "+err.Error(), http.StatusInternalServerError)
return
}
// 设置响应头
w.Header().Set("Content-Type", "application/json")
// 返回 JSON
json.NewEncoder(w).Encode(map[string]string{"message": "success"})
}
// 查下指定表名的全部数据
func findTableNameHandler(w http.ResponseWriter, r *http.Request) {
tableName := r.URL.Query().Get("tableName")
if tableName == "" {
http.Error(w, "tableName is empty", http.StatusBadRequest)
return
}
// ✅ 安全校验表名(防止 SQL 注入)
if !isValidTableName(tableName) {
http.Error(w, "invalid table name", http.StatusBadRequest)
return
}
// ✅ 使用参数化方式拼接表名(仅限对象名,如表、字段)
query := fmt.Sprintf("SELECT * FROM %s", tableName)
rows, err := db.Query(context.Background(), query)
if err != nil {
http.Error(w, "查询失败: "+err.Error(), http.StatusInternalServerError)
return
}
defer rows.Close()
// ✅ 使用 pgx 内置工具自动转为 []map[string]interface{}
data, err := pgx.CollectRows(rows, pgx.RowToMap)
if err != nil {
http.Error(w, "解析数据失败: "+err.Error(), http.StatusInternalServerError)
return
}
// ✅ 设置响应头
w.Header().Set("Content-Type", "application/json")
// ✅ 返回 JSON
json.NewEncoder(w).Encode(data)
}
// 安全校验表名(防止 SQL 注入)
func isValidTableName(name string) bool {
// 只允许字母、数字、下划线,且不能以数字开头
matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, name)
return matched
}
tools.go
package main
//引用的包
import (
"context"
"database/sql"
"fmt"
"log"
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgxpool" //pgsql数据库组件
"github.com/xuri/excelize/v2" //解析excel文件包
)
// 定义数据库相关配置
const (
host = "localhost" //数据库ip
port = 5432 //数据库端口
user = "postgres" //数据库用户名
password = "postgres" //数据库密码
dbname = "postgresLearning" //数据库名
)
// 初始化数据库连接
func InitDB() *pgxpool.Pool {
// 构建连接字符串
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host, port, user, password, dbname)
// 连接数据库
pool, err := pgxpool.New(context.Background(), psqlInfo)
//db, err := sql.Open("postgres", psqlInfo)
if err != nil {
log.Fatal(err)
}
//defer pool.Close() 这里有一个注意点,这块代码回直接关闭数据库连接
// 检查连接
err = pool.Ping(context.Background())
if err != nil {
log.Fatal(err)
}
LogInfo("Successfully connected to PostgreSQL database!")
return pool
}
// 读取Excel文件
func ReadExcel(path string, showLog bool) []Table {
createTable := []Table{}
// 打开Excel文件
f, err := excelize.OpenFile(path)
if err != nil {
fmt.Println(err)
return createTable
}
// 获取工作表名称列表
sheetNames := f.GetSheetList()
//遍历sheet列表
for _, sheetName := range sheetNames {
if showLog {
LogInfo("开始处理%s工作表", sheetName)
}
itemTable := Table{
Name: sheetName,
}
// 读取指定工作表的所有行
rows, err := f.GetRows(sheetName)
if err != nil {
LogError("读取%s工作表失败,原因: %v", sheetName, err)
continue
}
for _, row := range rows[1:] {
itemColumns := Column{
Name: row[0],
Type: parseColumnType(row[1]),
Length: row[2],
NotNull: parseBool(row[3], showLog),
Unique: parseBool(row[4], showLog),
Primary: parseBool(row[5], showLog),
}
if len(row) > 6 {
itemColumns.Default = row[6]
}
itemTable.Columns = append(itemTable.Columns, itemColumns)
if showLog {
// 遍历行中的单元格
for _, colCell := range row {
LogInfo("%s", colCell)
}
}
}
createTable = append(createTable, itemTable)
}
return createTable
}
// 创建数据库表
func createDBTable(db *sql.DB, table Table) {
success, createTableSQL := CreateTable(table)
if success {
LogInfo("sql=%s", createTableSQL)
_, err := db.Exec(createTableSQL)
if err != nil {
LogError("创建%s数据表失败,原因: %v", table.Name, err)
} else {
LogSuccess("创建%s数据表成功", table.Name)
}
} else {
LogError("创建%s数据表失败,原因: %s", table.Name, createTableSQL)
}
}
// ColumnItemType 定义列类型的自定义类型
type ColumnItemType string
// 支持的列类型常量
const (
VARCHAR ColumnItemType = "VARCHAR"
TIMESTAMP ColumnItemType = "TIMESTAMP"
SERIAL ColumnItemType = "SERIAL"
TEXT ColumnItemType = "TEXT"
DECIMAL ColumnItemType = "DECIMAL"
INT ColumnItemType = "INT"
)
func parseColumnType(typeStr string) ColumnItemType {
switch strings.ToUpper(typeStr) {
case "VARCHAR":
return VARCHAR
case "TIMESTAMP":
return TIMESTAMP
case "SERIAL":
return SERIAL
case "TEXT":
return TEXT
case "DECIMAL":
return DECIMAL
case "INT":
return INT
// 处理其他可能的类型
default:
return ColumnItemType(typeStr) // 如果类型不在预定义的范围内,可以返回原字符串或默认值
}
}
// 辅助函数,将字符串转换为整数,如果转换失败则打印错误信息并返回0
func mustAtoi(s string, showLog bool) int {
i, err := strconv.Atoi(s)
if err != nil {
if showLog {
LogWarning("无法解析int值,原因: %v", err)
}
return 0 // 或者你可以选择返回一个默认值,或者根据错误处理逻辑来决定
}
return i
}
// 添加一个函数来将字符串转换为布尔值
func parseBool(str string, showLog bool) bool {
switch str {
case "true", "1", "TRUE", "T", "Y", "YES":
return true
case "false", "0", "FALSE", "F", "N", "NO":
return false
default:
// 你可以根据需要处理默认情况,比如记录日志或者返回一个默认值
if showLog {
LogWarning("无法解析布尔值,设置默认值=false, 原始值=%s", str)
}
return false
}
}
// Column 定义字段结构
type Column struct {
Name string
Type ColumnItemType
Length string // 长度,仅对 VARCHAR、DECIMAL 等有效
NotNull bool
Unique bool
Primary bool
Default string
}
// Table 定义表结构
type Table struct {
Name string
Columns []Column
}
// CreateTable 生成 CREATE TABLE SQL 语句
func CreateTable(table Table) (bool, string) {
if table.Name == "" {
LogError("表名不能为空")
return false, ""
}
if len(table.Columns) == 0 {
LogError("字段列表不能为空")
return false, ""
}
var fieldDefs []string
for _, col := range table.Columns {
if col.Name == "" {
LogError("字段名不能为空")
return false, ""
}
def := col.Name + " " + string(col.Type) // 注意:col.Type 是 ColumnItemType,需转为 string
// 处理长度(仅对支持长度的类型)
if len(col.Length) > 0 && (col.Type == VARCHAR || col.Type == DECIMAL) {
// 可以根据 Type 判断是否支持 Length,例如只对 VARCHAR 和 DECIMAL 生效
def += "(" + col.Length + ")"
}
// 添加约束
if col.NotNull {
def += " NOT NULL"
}
if col.Unique {
def += " UNIQUE"
}
if col.Primary {
def += " PRIMARY KEY"
}
if col.Default != "" {
// 判断是否需要为 DEFAULT 值加引号
if col.Type == TEXT || col.Type == VARCHAR {
def += fmt.Sprintf(" DEFAULT '%s'", EscapeString(col.Default))
} else {
def += " DEFAULT " + col.Default
}
}
fieldDefs = append(fieldDefs, def)
}
// 拼接完整 SQL
sql := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s (%s);",
table.Name,
strings.Join(fieldDefs, ", "),
)
return true, sql
}
// EscapeString 是一个假设的函数,用于转义SQL字符串中的特殊字符
func EscapeString(s string) string {
// 实现对字符串s中单引号等特殊字符的转义
return strings.ReplaceAll(s, "'", "''") // 示例:转义单引号
}
// 辅助函数:判断 DEFAULT 是否需要加引号
func isStringDefault(defaultValue string) bool {
// 尝试将 defaultValue 转换为数字或 NULL
_, err1 := strconv.ParseFloat(defaultValue, 64)
_, err2 := strconv.ParseBool(defaultValue)
// 如果 defaultValue 可以转换为数字或布尔值,或者它是 "NULL",则不需要加引号
return !(err1 == nil || err2 == nil || strings.ToUpper(defaultValue) == "NULL")
}
// ==================================封装打印log========================================
const (
Red = "31"
Green = "32"
Yellow = "33"
Blue = "34"
Purple = "35"
Cyan = "36"
White = "37"
)
// PrintColor 打印指定颜色的文本
// colorCode: ANSI 颜色码
// format: 格式化字符串,如 "创建%s表成功"
// args: 格式化参数
func LogColor(colorCode string, format string, args ...interface{}) {
// \033[颜色码m + 文本 + \033[0m(重置)
colored := fmt.Sprintf("\033[%sm%s\033[0m", colorCode, fmt.Sprintf(format, args...))
fmt.Println(colored)
}
func LogError(format string, args ...interface{}) {
LogColor(Red, format, args...)
}
func LogInfo(format string, args ...interface{}) {
LogColor(White, format, args...)
}
func LogWarning(format string, args ...interface{}) {
LogColor(Yellow, format, args...)
}
func LogSuccess(format string, args ...interface{}) {
LogColor(Green, format, args...)
}
//==================================封装打印log END========================================
二、遇到的问题
试着把数据库连接和打印日志的方法提取出来了,使用多文件调用的方式,启动程序的命令就改变了
从之前的 go run main.go 变为了 go run .