Updated to BunRouter
This commit is contained in:
60
middleware/ratelimit.go
Normal file
60
middleware/ratelimit.go
Normal file
@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/uptrace/bunrouter"
|
||||
)
|
||||
|
||||
type rateLimiter struct {
|
||||
visitors map[string]*visitor
|
||||
mu sync.Mutex
|
||||
rps int
|
||||
}
|
||||
|
||||
type visitor struct {
|
||||
firstSeen time.Time
|
||||
requests int
|
||||
}
|
||||
|
||||
func NewRateLimiter(rps int) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
visitors: make(map[string]*visitor),
|
||||
rps: rps,
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) RateLimit(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
ip, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
// handle error, e.g., return an HTTP 500 error
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
|
||||
v, exists := rl.visitors[ip]
|
||||
if !exists || time.Since(v.firstSeen) > 1*time.Minute {
|
||||
v = &visitor{
|
||||
firstSeen: time.Now(),
|
||||
}
|
||||
rl.visitors[ip] = v
|
||||
}
|
||||
|
||||
v.requests++
|
||||
|
||||
// Limit each IP to rps requests per minute
|
||||
if v.requests > rl.rps {
|
||||
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||
return nil
|
||||
}
|
||||
|
||||
return next(w, req)
|
||||
}
|
||||
}
|
41
middleware/requestid.go
Normal file
41
middleware/requestid.go
Normal file
@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/uptrace/bunrouter"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return string(c)
|
||||
}
|
||||
|
||||
var (
|
||||
HeaderXRequestID = "X-Request-ID"
|
||||
requestIDKey = contextKey("requestID")
|
||||
)
|
||||
|
||||
func RequestID(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
reqID := req.Header.Get(HeaderXRequestID)
|
||||
if reqID == "" {
|
||||
reqID = uuid.New().String()
|
||||
}
|
||||
|
||||
ctx := context.WithValue(req.Context(), requestIDKey, reqID)
|
||||
|
||||
w.Header().Set(HeaderXRequestID, reqID)
|
||||
return next(w, req.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
if reqID, ok := ctx.Value(requestIDKey).(string); ok {
|
||||
return reqID
|
||||
}
|
||||
return ""
|
||||
}
|
21
middleware/secure.go
Normal file
21
middleware/secure.go
Normal file
@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/unrolled/secure"
|
||||
"github.com/uptrace/bunrouter"
|
||||
)
|
||||
|
||||
func SecureHeaders(next bunrouter.HandlerFunc) bunrouter.HandlerFunc {
|
||||
secureMiddleware := secure.New(secure.Options{
|
||||
FrameDeny: true,
|
||||
ContentTypeNosniff: true,
|
||||
BrowserXssFilter: true,
|
||||
})
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
secureMiddleware.HandlerFuncWithNext(w, req.Request, nil)
|
||||
return next(w, req)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user