go中间件学习

发布于:2025-03-14 ⋅ 阅读:(11) ⋅ 点赞:(0)

本博文源于笔者学习 Go 中间件的过程,罗列了较为常用的中间件,例如日志记录、认证授权、跨域资源共享、请求体解析、静态文件处理、错误处理、性能分析、缓存、速率限制、session、请求数据验证以及压缩。

1、日志记录中间件

在处理请求之前追加日志输出,例如打印请求的方法和 URL。

// logMiddleware wraps next with a handler that prints the request method and
// URL to standard output and then passes the request on unchanged.
func logMiddleware(next http.Handler) http.Handler {
	logged := func(rw http.ResponseWriter, req *http.Request) {
		fmt.Printf("Request: %s %s\n", req.Method, req.URL)
		next.ServeHTTP(rw, req)
	}
	return http.HandlerFunc(logged)
}

2、认证和授权中间件

对请求进行认证。

// authMiddleware guards next with a static API-key check: requests whose
// "API-Key" header equals the expected key are forwarded, all others receive
// 401 Unauthorized.
func authMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		if key := req.Header.Get("API-Key"); key == "my-secret-key" {
			next.ServeHTTP(rw, req)
			return
		}
		http.Error(rw, "Unauthorized", http.StatusUnauthorized)
	})
}

3、跨域资源共享中间件

import "github.com/rs/cors"

// handler answers every request with the plain body "Hello World".
func handler(w http.ResponseWriter, r *http.Request) {
	greeting := []byte("Hello World")
	w.Write(greeting)
}
// main prints the demo URL, registers handler on "/" and serves it on port
// 8082 behind a CORS layer that allows any origin, the GET/POST/PUT/DELETE
// methods and the Content-Type request header.
func main() {
	fmt.Println("http://127.0.0.1:8082/")

	router := http.NewServeMux()
	router.HandleFunc("/", handler)

	opts := cors.Options{
		AllowedOrigins: []string{"*"},
		AllowedMethods: []string{"GET", "POST", "PUT", "DELETE"},
		AllowedHeaders: []string{"Content-Type"},
	}
	withCORS := cors.New(opts).Handler(router)

	if err := http.ListenAndServe(":8082", withCORS); err != nil {
		log.Fatal(err)
	}
}

4、请求体解析中间件

对请求体进行解析

package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// Data is the JSON payload accepted by handler.
type Data struct {
	Name  string `json:"name"`
	Value string `json:"value"`
}

// handler decodes a JSON Data object from the request body and replies with a
// JSON object containing a greeting ("message") and the decoded payload
// ("data"). A body that cannot be decoded, or that exceeds the size cap,
// yields 400 Bad Request with the decoder's error text.
func handler(w http.ResponseWriter, r *http.Request) {
	// Cap the body so an oversized upload cannot exhaust server memory.
	const maxBodyBytes = 1 << 20
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	var data Data
	if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	response := map[string]interface{}{
		"message": "Hello " + data.Name + "!",
		"data":    data,
	}
	w.Header().Set("Content-Type", "application/json")
	// The status line may already be on the wire here, so a failed write can
	// only be reported in the server log.
	if err := json.NewEncoder(w).Encode(response); err != nil {
		log.Printf("encoding response: %v", err)
	}
}

// main prints the demo URL, registers handler for /api/data on the default
// mux and serves it on port 3000.
func main() {
	fmt.Println("http://127.0.0.1:3000/api/data")

	// handler parses the JSON in the request body with encoding/json and
	// echoes it back inside a JSON response.
	http.HandleFunc("/api/data", handler)

	if err := http.ListenAndServe(":3000", nil); err != nil {
		log.Fatal(err)
	}
}

5、静态文件处理中间件

通过 HTTP 提供服务器上 static 目录中的静态文件。

package main

import (
	"fmt"
	"log"
	"net/http"
)

// main serves the files in the local "static" directory under the /static/
// URL prefix on port 3001.
func main() {
	// Files in the static folder can be fetched through URLs like this one.
	fmt.Println("http://127.0.0.1:3001/static/demo.png")

	const prefix = "/static/"
	files := http.FileServer(http.Dir("static"))
	http.Handle(prefix, http.StripPrefix(prefix, files))

	if err := http.ListenAndServe(":3001", nil); err != nil {
		log.Fatal(err)
	}
}

6、错误处理中间件

通过 defer 和 recover 捕获处理请求时发生的 panic,并返回 500 错误。

package main

import (
	"fmt"
	"log"
	"net/http"
)

// errorHandlingMiddleware recovers from panics raised by next and turns them
// into a 500 Internal Server Error response that includes the panic value.
//
// http.ErrAbortHandler is re-panicked untouched: it is net/http's sentinel for
// deliberately aborting a response and must reach the server rather than be
// converted into an error page.
func errorHandlingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			err := recover()
			if err == nil {
				return
			}
			if err == http.ErrAbortHandler {
				panic(err)
			}
			// Recovering hides the panic from the server's own logging, so
			// record it here before answering the client.
			log.Printf("panic serving %s %s: %v", r.Method, r.URL.Path, err)
			http.Error(w, fmt.Sprintf("Internal Server Error: %v", err), http.StatusInternalServerError)
		}()
		next.ServeHTTP(w, r)
	})
}

// handler always panics, simulating a request handler that fails.
func handler(w http.ResponseWriter, r *http.Request) {
	const reason = "Something went wrong!"
	panic(reason)
}

// main mounts the panicking handler at /panic, wraps the mux in
// errorHandlingMiddleware and serves it through the default mux on port 3002.
func main() {
	// Errors are caught with the panic/recover pair.
	fmt.Println("http://127.0.0.1:3002/panic")

	routes := http.NewServeMux()
	routes.HandleFunc("/panic", handler)
	http.Handle("/", errorHandlingMiddleware(routes))

	if err := http.ListenAndServe(":3002", nil); err != nil {
		log.Fatal(err)
	}
}

7、性能分析中间件

统计并打印每个请求处理耗时的中间件。

package main

import (
	"fmt"
	"log"
	"net/http"
	"time"
)

// performanceMonitoringMiddleware wraps next and, once next has returned,
// prints how long the request took to standard output.
func performanceMonitoringMiddleware(next http.Handler) http.Handler {
	timed := func(rw http.ResponseWriter, req *http.Request) {
		began := time.Now()
		next.ServeHTTP(rw, req)
		fmt.Printf("Request took %v\n", time.Since(began))
	}
	return http.HandlerFunc(timed)
}

// handler sleeps for two seconds and then writes "Hello World", giving the
// timing middleware something measurable.
func handler(w http.ResponseWriter, r *http.Request) {
	const delay = 2 * time.Second
	time.Sleep(delay)
	w.Write([]byte("Hello World"))
}
// main serves handler on port 3003 through the default mux, wrapped in
// performanceMonitoringMiddleware.
func main() {
	// The time spent on every request is recorded and printed to the console.
	fmt.Println("http://127.0.0.1:3003")

	routes := http.NewServeMux()
	routes.HandleFunc("/", handler)
	http.Handle("/", performanceMonitoringMiddleware(routes))

	if err := http.ListenAndServe(":3003", nil); err != nil {
		log.Fatal(err)
	}
}

8、缓存中间件

使用 Redis 缓存响应:命中缓存时直接返回缓存内容,未命中时调用后续处理器并写入缓存。

package main

import (
	"context"
	"fmt"
	"github.com/go-redis/redis/v8"
	"log"
	"net/http"
	"time"
)

// rdb is the package-wide Redis client; it is nil until initRedis runs.
var rdb *redis.Client

// initRedis points rdb at the Redis server on localhost:6379.
func initRedis() {
	opts := &redis.Options{Addr: "localhost:6379"}
	rdb = redis.NewClient(opts)
}

// cacheRecorder forwards everything to the wrapped ResponseWriter while
// keeping a copy of the status code and body, so the middleware can store the
// response that was actually sent.
type cacheRecorder struct {
	http.ResponseWriter
	status int
	body   []byte
}

// WriteHeader remembers the status code before passing it on.
func (c *cacheRecorder) WriteHeader(code int) {
	c.status = code
	c.ResponseWriter.WriteHeader(code)
}

// Write passes p on to the client and appends the bytes that were written to
// the recorded body.
func (c *cacheRecorder) Write(p []byte) (int, error) {
	n, err := c.ResponseWriter.Write(p)
	c.body = append(c.body, p[:n]...)
	return n, err
}

// cacheMiddleware caches GET response bodies in Redis for 60 seconds, keyed by
// URL path. On a hit the cached body is written without calling next; on a
// miss next runs and its body is stored if it answered 200 OK. Non-GET
// requests and Redis failures bypass the cache and go straight to next.
//
// NOTE(review): the key ignores the query string, and cached responses are
// replayed without their original headers — confirm this is acceptable.
func cacheMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			next.ServeHTTP(w, r)
			return
		}

		ctx := context.Background()
		key := r.URL.Path

		cachedValue, err := rdb.Get(ctx, key).Result()
		if err == nil {
			fmt.Println("Cache hit!")
			w.Write([]byte(cachedValue))
			return
		}
		if err != redis.Nil {
			// Redis itself failed (not just a missing key): keep serving,
			// but do not try to populate the cache.
			log.Printf("cache lookup for %q: %v", key, err)
			next.ServeHTTP(w, r)
			return
		}

		fmt.Println("Cache miss.")
		rec := &cacheRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(rec, r)

		// Only successful responses are worth replaying to later clients.
		if rec.status != http.StatusOK {
			return
		}
		if err := rdb.Set(ctx, key, string(rec.body), 60*time.Second).Err(); err != nil {
			log.Printf("cache store for %q: %v", key, err)
		}
	})
}

// handler simulates a slow backend: it waits one second and then writes
// "hello from the server".
func handler(w http.ResponseWriter, r *http.Request) {
	const delay = 1 * time.Second
	time.Sleep(delay)
	w.Write([]byte("hello from the server"))
}

// main initialises the Redis client and serves handler on port 3005 through
// the default mux, wrapped in cacheMiddleware.
func main() {
	initRedis()
	fmt.Println("http://127.0.0.1:3005")

	routes := http.NewServeMux()
	routes.HandleFunc("/", handler)
	http.Handle("/", cacheMiddleware(routes))

	if err := http.ListenAndServe(":3005", nil); err != nil {
		log.Fatal(err)
	}
}

9、速率限制中间件

速率限制:用 Redis 统计每个客户端的请求次数,超过上限时返回 429(Too Many Requests)。

package main

import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"time"

	"github.com/go-redis/redis/v8"
)

// rdb is the Redis client shared by the rate limiter; it is nil until
// initRedis has been called.
var rdb *redis.Client

// initRedis creates the client for the Redis server at localhost:6379 and
// stores it in rdb.
func initRedis() {
	options := redis.Options{Addr: "localhost:6379"}
	rdb = redis.NewClient(&options)
}

// rateLimitMiddleware allows each client IP at most 5 requests per one-minute
// window, counted in Redis under "rate_limit:<ip>". Requests over the limit
// get 429 Too Many Requests; a Redis failure yields 500.
//
// The counter is bumped with a single INCR so concurrent requests cannot slip
// past the limit, and the expiry is set only when the key is created so the
// window is fixed rather than being pushed back by every request.
func rateLimitMiddleware(next http.Handler) http.Handler {
	const (
		limit  = 5
		window = 1 * time.Minute
	)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// RemoteAddr is "host:port"; the port changes with every new
		// connection, so only the host part identifies the client.
		ip, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			ip = r.RemoteAddr
		}
		key := "rate_limit:" + ip
		ctx := context.Background()

		count, err := rdb.Incr(ctx, key).Result()
		if err != nil {
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		if count == 1 {
			// First request of a new window: start the clock.
			if err := rdb.Expire(ctx, key, window).Err(); err != nil {
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
		}
		if count > limit {
			http.Error(w, "Too many requests", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// handler replies to every request with "hello from the server".
func handler(w http.ResponseWriter, r *http.Request) {
	reply := []byte("hello from the server")
	w.Write(reply)
}

// main initialises the Redis client and serves handler on port 3005 through
// the default mux, wrapped in rateLimitMiddleware.
func main() {
	initRedis()
	fmt.Println("http://127.0.0.1:3005")

	routes := http.NewServeMux()
	routes.HandleFunc("/", handler)
	http.Handle("/", rateLimitMiddleware(routes))

	if err := http.ListenAndServe(":3005", nil); err != nil {
		log.Fatal(err)
	}
}

10、session中间件

通过第三方包来管理用户会话,中间件检查用户是否已通过身份验证,只有通过验证的用户才可以访问保护的页面

package main

import (
	"fmt"
	"log"
	"net/http"
)

import "github.com/gorilla/sessions"

// Sessions are managed with a third-party package; the middleware checks
// whether the user has authenticated, and only authenticated users may reach
// the protected page.

// store is the cookie-backed session store created by
// sessions.NewCookieStore with a key that is hard-coded in this file.
// NOTE(review): presumably a real deployment would load the key from
// configuration rather than from source — confirm.
var store = sessions.NewCookieStore([]byte("my-secret-key"))

// sessionMiddleware lets a request through to next only when the
// "session-name" session carries authenticated == true; every other request
// is answered with 403 Forbidden.
func sessionMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		sess, _ := store.Get(req, "session-name")
		// Reject unless the session has been marked as authenticated.
		if sess.Values["authenticated"] != true {
			http.Error(rw, "Forbidden", http.StatusForbidden)
			return
		}
		next.ServeHTTP(rw, req)
	})
}

// loginHandler marks the "session-name" session as authenticated, saves it and
// confirms the login. If the session cannot be saved the client receives
// 500 Internal Server Error instead of a false success message.
//
// NOTE(review): no credentials are checked — every caller is logged in.
func loginHandler(w http.ResponseWriter, r *http.Request) {
	// The Get error is deliberately ignored: gorilla/sessions documents that
	// Get still returns a new session when an existing cookie cannot be
	// decoded, which is an acceptable starting point for a login.
	session, _ := store.Get(r, "session-name")
	session.Values["authenticated"] = true

	// Save must succeed before anything is written to the body, otherwise
	// the client is told it is logged in without receiving a session cookie.
	if err := session.Save(r, w); err != nil {
		log.Printf("saving session: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	fmt.Fprintln(w, "You are logged in!")
}

// main serves /login and the session-protected /protected page on port 8082.
func main() {
	fmt.Println("http://127.0.0.1:8082/login")
	fmt.Println("http://127.0.0.1:8082/protected")

	protected := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "You have access to this page")
	})

	router := http.NewServeMux()
	router.HandleFunc("/login", loginHandler)
	router.Handle("/protected", sessionMiddleware(protected))

	if err := http.ListenAndServe(":8082", router); err != nil {
		log.Fatal(err)
	}
}

11、请求数据验证中间件

对数据进行验证

package main

import (
	"encoding/json"
	"fmt"
	"github.com/go-playground/validator/v10"
	"log"
	"net/http"
)

// User is the request payload checked by the validation middleware. The
// validate tags mark both fields as required and Email as an email address.
type User struct {
	Name  string `json:"name" validate:"required"`
	Email string `json:"email" validate:"required,email"`
}

// 数据验证中间件
func validationMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// 解析请求体中的 JSON 数据
		var user User
		decoder := json.NewDecoder(r.Body)
		if err := decoder.Decode(&user); err != nil {
			http.Error(w, "Invalid request body", http.StatusBadRequest)
			return
		}

		// 验证数据
		validate := validator.New()
		if err := validate.Struct(user); err != nil {
			http.Error(w, fmt.Sprintf("Validation failed: %v", err), http.StatusBadRequest)
			return
		}

		// 验证通过,继续处理请求
		next.ServeHTTP(w, r)
	})
}

// handler confirms that the request passed validation.
func handler(w http.ResponseWriter, r *http.Request) {
	confirmation := []byte("User data is valid!")
	w.Write(confirmation)
}

// main serves handler at /user on port 3000 through the default mux, with the
// validation middleware in front of it.
func main() {
	routes := http.NewServeMux()
	routes.HandleFunc("/user", handler)

	// Wrap the request handler with the data-validation middleware.
	validated := validationMiddleware(routes)
	http.Handle("/user", validated)

	if err := http.ListenAndServe(":3000", nil); err != nil {
		log.Fatal(err)
	}
}

12、压缩中间件

检查请求头 Accept-Encoding,判断客户端是否支持 gzip;支持时对响应进行压缩。

package main

import (
	"compress/gzip"
	"fmt"
	"log"
	"net/http"
	"strings"
)

// compressionMiddleware gzip-compresses the response written by next when the
// request's Accept-Encoding header mentions gzip; otherwise next writes to the
// client directly. Every response is marked "Vary: Accept-Encoding" because
// its body depends on that request header.
func compressionMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("Vary", "Accept-Encoding")

		if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
			next.ServeHTTP(w, r)
			return
		}

		w.Header().Set("Content-Encoding", "gzip")
		gz := gzip.NewWriter(w)
		// Close flushes the remaining compressed data and the gzip trailer.
		defer gz.Close()
		next.ServeHTTP(&gzipResponseWriter{Writer: gz, ResponseWriter: w}, r)
	})
}

// gzipResponseWriter is an http.ResponseWriter whose body goes through a
// gzip.Writer; headers and status are handled by the embedded ResponseWriter.
type gzipResponseWriter struct {
	http.ResponseWriter
	Writer *gzip.Writer
}

// WriteHeader drops any Content-Length set by the handler, since it describes
// the uncompressed body, and then forwards the status code.
func (g *gzipResponseWriter) WriteHeader(code int) {
	g.Header().Del("Content-Length")
	g.ResponseWriter.WriteHeader(code)
}

// Write compresses b into the response body.
//
// If the handler has not set a Content-Type, it is detected from the
// uncompressed bytes here. Otherwise net/http would sniff the compressed
// stream and label the response application/x-gzip.
func (g *gzipResponseWriter) Write(b []byte) (int, error) {
	h := g.Header()
	if _, ok := h["Content-Type"]; !ok {
		h.Set("Content-Type", http.DetectContentType(b))
	}
	// Covers handlers that set Content-Length without calling WriteHeader.
	h.Del("Content-Length")
	return g.Writer.Write(b)
}

// handler writes a sample sentence repeated 32 times (the result of doubling
// it five times), giving the compressor a body with plenty of redundancy.
func handler(w http.ResponseWriter, r *http.Request) {
	const line = "This is a test response that will be compressed if the client supports Gzip.\n"
	payload := strings.Repeat(line, 32)
	w.Write([]byte(payload))
}

// main serves handler on port 8085 through the default mux, wrapped in
// compressionMiddleware.
func main() {
	fmt.Println("http://127.0.0.1:8085/")

	routes := http.NewServeMux()
	routes.HandleFunc("/", handler)
	http.Handle("/", compressionMiddleware(routes))

	if err := http.ListenAndServe(":8085", nil); err != nil {
		log.Fatal(err)
	}
}