中间件概念:中间件是位于HTTP请求和最终处理函数之间的组件,可以拦截、处理和增强请求。类似于安检流程,每个请求都要经过多个检查点。
// 中间件基本结构
func middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 前置处理(请求到达时执行)
fmt.Println("请求开始:", r.URL.Path)
// 调用下一个处理程序
next.ServeHTTP(w, r)
// 后置处理(响应返回时执行)
fmt.Println("请求结束:", r.URL.Path)
})
}
// 使用中间件
http.Handle("/api", middleware(finalHandler))
中间件链:多个中间件可以串联,形成处理流水线:
请求 → 中间件1 → 中间件2 → ... → 业务处理 → 响应
身份认证中间件:
日志记录中间件:
CORS跨域中间件:
请求限流中间件:
学生成绩管理系统中间件实现:
package main
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
)
// 学生成绩数据结构
type Student struct {
ID int `json:"id"`
Name string `json:"name"`
Score int `json:"score"`
}
var students = []Student{
{1, "张三", 85},
{2, "李四", 92},
{3, "王五", 78},
}
// 1. 日志记录中间件
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 调用下一个处理器
next.ServeHTTP(w, r)
// 记录请求日志
log.Printf("%s %s %s 处理时间: %v",
r.Method, r.URL.Path, r.RemoteAddr, time.Since(start))
})
}
// 2. 身份认证中间件(简单的API Key验证)
func authMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-API-Key")
// 简单的API Key检查(实际项目中应该更复杂)
if apiKey != "student-system-2024" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{
"error": "未授权的访问,请提供有效的API Key",
})
return
}
next.ServeHTTP(w, r)
})
}
// 3. CORS跨域中间件
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 设置CORS头部
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key")
// 处理预检请求
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}
// 4. 请求限流中间件(简单的令牌桶算法)
type rateLimiter struct {
tokens chan struct{}
resetTime time.Duration
}
func newRateLimiter(limit int, resetTime time.Duration) *rateLimiter {
rl := &rateLimiter{
tokens: make(chan struct{}, limit),
resetTime: resetTime,
}
// 初始化令牌
for i := 0; i < limit; i++ {
rl.tokens <- struct{}{}
}
// 定时补充令牌
go func() {
for range time.Tick(resetTime) {
select {
case rl.tokens <- struct{}{}:
default:
// 令牌桶已满,跳过
}
}
}()
return rl
}
func (rl *rateLimiter) limitMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-rl.tokens:
// 获取到令牌,继续处理
next.ServeHTTP(w, r)
default:
// 令牌不足,返回限流错误
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
json.NewEncoder(w).Encode(map[string]string{
"error": "请求过于频繁,请稍后重试",
})
}
})
}
// 业务处理函数
func getStudents(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(students)
}
func addStudent(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "方法不允许", http.StatusMethodNotAllowed)
return
}
var newStudent Student
if err := json.NewDecoder(r.Body).Decode(&newStudent); err != nil {
http.Error(w, "无效的请求数据", http.StatusBadRequest)
return
}
// 简单的ID生成(实际项目应该用数据库自增ID)
newStudent.ID = len(students) + 1
students = append(students, newStudent)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(newStudent)
}
func main() {
// 初始化限流器:每分钟10个请求
limiter := newRateLimiter(10, time.Minute/10)
// 创建路由器
mux := http.NewServeMux()
// 注册路由(应用中间件链)
mux.Handle("/api/students",
loggingMiddleware(
corsMiddleware(
authMiddleware(
limiter.limitMiddleware(
http.HandlerFunc(getStudents))))))
mux.Handle("/api/students/add",
loggingMiddleware(
corsMiddleware(
authMiddleware(
limiter.limitMiddleware(
http.HandlerFunc(addStudents))))))
// 启动服务器
fmt.Println("学生成绩管理系统启动在 http://localhost:8080")
fmt.Println("测试时请在Header中添加: X-API-Key: student-system-2024")
log.Fatal(http.ListenAndServe(":8080", mux))
}
中间件组合和链式调用:
package main
import (
"net/http"
"strings"
)
// 中间件组合工具
type Middleware func(http.Handler) http.Handler
// 将多个中间件组合成一个
func applyMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handler {
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler)
}
return handler
}
// 实用的中间件示例
// 1. 请求ID中间件(用于请求追踪)
func requestIDMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 生成或获取请求ID(简化版)
requestID := fmt.Sprintf("%d", time.Now().UnixNano())
w.Header().Set("X-Request-ID", requestID)
next.ServeHTTP(w, r)
})
}
// 2. 超时控制中间件
func timeoutMiddleware(timeout time.Duration) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 创建超时上下文
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
// 使用带超时的上下文
r = r.WithContext(ctx)
// 创建响应包装器来检测超时
done := make(chan bool, 1)
go func() {
next.ServeHTTP(w, r)
done <- true
}()
select {
case <-done:
// 正常完成
case <-ctx.Done():
// 超时
w.WriteHeader(http.StatusRequestTimeout)
w.Write([]byte("请求超时"))
}
})
}
}
// 使用组合中间件
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, 中间件!"))
})
// 组合多个中间件
handler := applyMiddlewares(
mux,
requestIDMiddleware,
timeoutMiddleware(30*time.Second),
)
http.ListenAndServe(":8080", handler)
}