Rate limiter with Redis and Golang
I. Introduction
A rate limiter is a mechanism that controls the number of requests or tasks performed within a given period. It helps prevent excessive resource usage and mitigates denial-of-service (DoS/DDoS) attacks. When the number of requests exceeds a set limit, subsequent requests may be rejected or delayed until the limit resets. This keeps the system stable and fair for all users.
Some algorithms commonly used to implement a rate limiter:
- Leaky Bucket
- Fixed Window Counter
- Sliding Window Log
- Sliding Window Counter
Nowadays, many services let you configure a rate limiter for your website out of the box.
In this post, I will walk you through implementing a rate limiter with Redis, Golang, and the Fixed Window Counter algorithm:
- Concept: The Fixed Window Counter algorithm counts the number of requests in a fixed period, called a "window". For example, you can define a window of 1 minute.
- Requests: Every time a request arrives, the system checks the number of requests already made in the current window. The new request is rejected if the number of requests exceeds the specified limit for that window. When the window ends, the counter is reset and a new window begins.
II. Implementation
Now, let's implement it in Golang.
First, initialize the Redis connection:
```go
// initRedis parses a Redis URL, applies pool and timeout settings, and checks
// the connection with PING. It uses the go-redis client (assumed here to be
// github.com/redis/go-redis/v9).
func initRedis(redisUrl string) (*redis.Client, error) {
	opts, err := redis.ParseURL(redisUrl)
	if err != nil {
		return nil, fmt.Errorf("failed to parse redis url: %w", err)
	}
	opts.PoolSize = 30
	opts.ReadTimeout = 5 * time.Second
	opts.WriteTimeout = 5 * time.Second
	opts.Username = ""

	redisClient := redis.NewClient(opts)
	if err := redisClient.Ping(context.Background()).Err(); err != nil {
		_ = redisClient.Close()
		return nil, fmt.Errorf("failed to ping redis: %w", err)
	}
	return redisClient, nil
}
```
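The following function gets the client's IP address from the request. I will limit each client IP to 60 requests per minute. It checks X-Forwarded-For first (skipping loopback addresses), then X-Real-IP, and finally falls back to the connection's remote address. Note that both headers can be set by the client, so they should only be trusted when the service sits behind a proxy that overwrites them.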
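```go
// getIPFromRequest returns the client IP, preferring proxy headers over the
// TCP peer address.
func getIPFromRequest(r *http.Request) string {
	// X-Forwarded-For may hold a comma-separated chain ("client, proxy1, proxy2");
	// return the first non-empty, non-loopback entry.
	ips := r.Header.Get("X-Forwarded-For")
	ipList := strings.Split(ips, ",")
	for _, ip := range ipList {
		if ip = strings.TrimSpace(ip); ip != "" && ip != "::1" && ip != "127.0.0.1" {
			return ip
		}
	}
	// X-Real-IP is typically set by a reverse proxy.
	ip := r.Header.Get("X-Real-IP")
	if ip != "" {
		return ip
	}
	// Fall back to the remote address, stripping the port from "host:port".
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return ip
}
```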
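Next is the middleware that checks the rate limit. It runs a Lua script in Redis; because Redis executes a script atomically, concurrent requests cannot interleave between reading and updating the counter. For each request, the script increments the client's counter (creating it with a 60-second expiry on the first request of a window) and returns "not pass" if the count, including the current request, exceeds the maximum number of requests per minute, or "pass" otherwise. If the request passes, the middleware forwards it to the next handler, which processes it and returns status code 200; otherwise the middleware returns status code 429 (Too Many Requests).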
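```go
// MaxRequestOneMinute is the per-IP limit: 60 requests per minute.
const MaxRequestOneMinute = 60

// HandlerAPI holds the dependencies shared by the handlers.
type HandlerAPI struct {
	RedisClient *redis.Client
}

// rateLimitScript implements a fixed window counter atomically inside Redis.
// KEYS[1] = counter key, ARGV[1] = window length in seconds, ARGV[2] = max requests.
// The KEEPTTL option requires Redis 6.0 or later.
const rateLimitScript = `
local currentCount = tonumber(redis.call('GET', KEYS[1]) or '0')
if currentCount == 0 then
	-- first request of the window: create the key with an expiry
	redis.call('SET', KEYS[1], 0, 'EX', ARGV[1])
end
redis.call('SET', KEYS[1], currentCount + 1, 'KEEPTTL')
-- currentCount excludes this request, so >= allows exactly ARGV[2] requests
if currentCount >= tonumber(ARGV[2]) then
	return "not pass"
end
return "pass"`

func (h *HandlerAPI) RateLimiter(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// CORS headers (not part of rate limiting, but set here for every response)
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Origin, Authorization")
		w.Header().Set("Access-Control-Allow-Credentials", "true")
		// Handle preflight requests without counting them against the limit
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		// Rate limit per client IP; the prefix keeps the key from clashing with other data.
		key := "rate_limit:" + getIPFromRequest(r)
		result, err := h.RedisClient.Eval(r.Context(), rateLimitScript, []string{key}, 60, MaxRequestOneMinute).Text()
		if err != nil {
			logrus.Warnf("Running Lua Script is failed with err: %v", err)
			// Respond explicitly instead of silently returning an empty 200.
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}
		if result == "not pass" {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
```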
III. Test
The test handler returns "Hello Viet Nam" with status code 200:
```go
func (h *HandlerAPI) testHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, err := fmt.Fprintf(w, "Hello Viet Nam")
	if err != nil {
		return
	}
}
```
The main function:
```go
func main() {
	redisClient, err := initRedis("redis://default:@localhost:6379")
	if err != nil {
		log.Fatalf("failed to init redis: %v", err)
	}
	handler := HandlerAPI{
		RedisClient: redisClient,
	}
	mux := http.NewServeMux()
	mux.Handle("/test", handler.RateLimiter(http.HandlerFunc(handler.testHandler)))
	// Start the server
	log.Fatal(http.ListenAndServe(":3000", mux))
}
```
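Next, I write a test for the rate limiter that sends 100 concurrent requests to the running server. Since the limit is 60 requests per minute, at most 60 of them should receive status code 200 and the rest 429.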
```go
func TestHandlerAPI_testHandler(t *testing.T) {
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			callAPI()
		}()
	}
	wg.Wait()
	fmt.Println("finish")
}

func callAPI() {
	url := "http://localhost:3000/test"
	method := "GET"
	client := &http.Client{}
	req, err := http.NewRequest(method, url, nil)
	if err != nil {
		fmt.Println(err)
		return
	}
	res, err := client.Do(req)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer func(Body io.ReadCloser) {
		err := Body.Close()
		if err != nil {
			return
		}
	}(res.Body)
	fmt.Println(res.StatusCode)
}
```
IV. Result
V. Reference
All rights reserved