flyfei

中间件(身份认证、日志记录、CORS跨域、请求限流)

1. 语法讲解

中间件概念:中间件是位于HTTP请求和最终处理函数之间的组件,可以拦截、处理和增强请求。类似于安检流程,每个请求都要经过多个检查点。

// 中间件基本结构
func middleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // 前置处理(请求到达时执行)
        fmt.Println("请求开始:", r.URL.Path)
        
        // 调用下一个处理程序
        next.ServeHTTP(w, r)
        
        // 后置处理(响应返回时执行)
        fmt.Println("请求结束:", r.URL.Path)
    })
}

// 使用中间件
http.Handle("/api", middleware(finalHandler))

中间件链:多个中间件可以串联,形成处理流水线:

请求 → 中间件1 → 中间件2 → ... → 业务处理 → 响应  

2. 应用场景

身份认证中间件

日志记录中间件

CORS跨域中间件

请求限流中间件

3. 编程实例

学生成绩管理系统中间件实现

package main

import (
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "strings"
    "time"
)

// 学生成绩数据结构
type Student struct {
    ID    int    `json:"id"`
    Name  string `json:"name"`
    Score int    `json:"score"`
}

var students = []Student{
    {1, "张三", 85},
    {2, "李四", 92},
    {3, "王五", 78},
}

// 1. 日志记录中间件
func loggingMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        start := time.Now()
        
        // 调用下一个处理器
        next.ServeHTTP(w, r)
        
        // 记录请求日志
        log.Printf("%s %s %s 处理时间: %v", 
            r.Method, r.URL.Path, r.RemoteAddr, time.Since(start))
    })
}

// 2. 身份认证中间件(简单的API Key验证)
func authMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        apiKey := r.Header.Get("X-API-Key")
        
        // 简单的API Key检查(实际项目中应该更复杂)
        if apiKey != "student-system-2024" {
            w.Header().Set("Content-Type", "application/json")
            w.WriteHeader(http.StatusUnauthorized)
            json.NewEncoder(w).Encode(map[string]string{
                "error": "未授权的访问,请提供有效的API Key",
            })
            return
        }
        
        next.ServeHTTP(w, r)
    })
}

// 3. CORS跨域中间件
func corsMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // 设置CORS头部
        w.Header().Set("Access-Control-Allow-Origin", "*")
        w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
        w.Header().Set("Access-Control-Allow-Headers", "Content-Type, X-API-Key")
        
        // 处理预检请求
        if r.Method == "OPTIONS" {
            w.WriteHeader(http.StatusOK)
            return
        }
        
        next.ServeHTTP(w, r)
    })
}

// 4. 请求限流中间件(简单的令牌桶算法)
type rateLimiter struct {
    tokens    chan struct{}
    resetTime time.Duration
}

func newRateLimiter(limit int, resetTime time.Duration) *rateLimiter {
    rl := &rateLimiter{
        tokens:    make(chan struct{}, limit),
        resetTime: resetTime,
    }
    
    // 初始化令牌
    for i := 0; i < limit; i++ {
        rl.tokens <- struct{}{}
    }
    
    // 定时补充令牌
    go func() {
        for range time.Tick(resetTime) {
            select {
            case rl.tokens <- struct{}{}:
            default:
                // 令牌桶已满,跳过
            }
        }
    }()
    
    return rl
}

func (rl *rateLimiter) limitMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        select {
        case <-rl.tokens:
            // 获取到令牌,继续处理
            next.ServeHTTP(w, r)
        default:
            // 令牌不足,返回限流错误
            w.Header().Set("Content-Type", "application/json")
            w.WriteHeader(http.StatusTooManyRequests)
            json.NewEncoder(w).Encode(map[string]string{
                "error": "请求过于频繁,请稍后重试",
            })
        }
    })
}

// 业务处理函数
func getStudents(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(students)
}

func addStudent(w http.ResponseWriter, r *http.Request) {
    if r.Method != "POST" {
        http.Error(w, "方法不允许", http.StatusMethodNotAllowed)
        return
    }
    
    var newStudent Student
    if err := json.NewDecoder(r.Body).Decode(&newStudent); err != nil {
        http.Error(w, "无效的请求数据", http.StatusBadRequest)
        return
    }
    
    // 简单的ID生成(实际项目应该用数据库自增ID)
    newStudent.ID = len(students) + 1
    students = append(students, newStudent)
    
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(http.StatusCreated)
    json.NewEncoder(w).Encode(newStudent)
}

func main() {
    // 初始化限流器:每分钟10个请求
    limiter := newRateLimiter(10, time.Minute/10)
    
    // 创建路由器
    mux := http.NewServeMux()
    
    // 注册路由(应用中间件链)
    mux.Handle("/api/students", 
        loggingMiddleware(
            corsMiddleware(
                authMiddleware(
                    limiter.limitMiddleware(
                        http.HandlerFunc(getStudents))))))
    
    mux.Handle("/api/students/add", 
        loggingMiddleware(
            corsMiddleware(
                authMiddleware(
                    limiter.limitMiddleware(
                        http.HandlerFunc(addStudents))))))
    
    // 启动服务器
    fmt.Println("学生成绩管理系统启动在 http://localhost:8080")
    fmt.Println("测试时请在Header中添加: X-API-Key: student-system-2024")
    log.Fatal(http.ListenAndServe(":8080", mux))
}

4. 其他用法

中间件组合和链式调用

package main

import (
    "net/http"
    "strings"
)

// 中间件组合工具
type Middleware func(http.Handler) http.Handler

// 将多个中间件组合成一个
func applyMiddlewares(handler http.Handler, middlewares ...Middleware) http.Handler {
    for i := len(middlewares) - 1; i >= 0; i-- {
        handler = middlewares[i](handler)
    }
    return handler
}

// 实用的中间件示例

// 1. 请求ID中间件(用于请求追踪)
func requestIDMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // 生成或获取请求ID(简化版)
        requestID := fmt.Sprintf("%d", time.Now().UnixNano())
        w.Header().Set("X-Request-ID", requestID)
        next.ServeHTTP(w, r)
    })
}


// 2. 超时控制中间件
func timeoutMiddleware(timeout time.Duration) Middleware {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            // 创建超时上下文
            ctx, cancel := context.WithTimeout(r.Context(), timeout)
            defer cancel()
            
            // 使用带超时的上下文
            r = r.WithContext(ctx)
            
            // 创建响应包装器来检测超时
            done := make(chan bool, 1)
            go func() {
                next.ServeHTTP(w, r)
                done <- true
            }()
            
            select {
            case <-done:
                // 正常完成
            case <-ctx.Done():
                // 超时
                w.WriteHeader(http.StatusRequestTimeout)
                w.Write([]byte("请求超时"))
            }
        })
    }
}

// 使用组合中间件
func main() {
    mux := http.NewServeMux()
    mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte("Hello, 中间件!"))
    })
    
    // 组合多个中间件
    handler := applyMiddlewares(
        mux,
        requestIDMiddleware,
        timeoutMiddleware(30*time.Second),
    )
    
    http.ListenAndServe(":8080", handler)
}

5. 课时总结