// cmd/load-test/main.go
package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
)

// Stats is the set of process-wide counters shared by every client
// goroutine. Every field is an int64 so it can be updated with
// atomic.AddInt64 and read with atomic.LoadInt64 while clients run.
type Stats struct {
	// Connected counts successful dials; Disconnected counts client
	// sessions that have ended after connecting.
	Connected, Disconnected int64

	// MessagesRecv counts messages read by all clients; Errors counts
	// failed dials.
	MessagesRecv, Errors int64

	// NOTE(review): TotalLatency and MessageCount are never written in this
	// file — presumably reserved for latency tracking; confirm before use.
	TotalLatency, MessageCount int64
}

// stats is the single shared instance; it starts zeroed.
var stats Stats

// connectClient runs one load-test client. It dials url, keeps the
// connection open for duration (sending a ping every 30 seconds), and
// records its lifecycle in the package-level stats counters. It calls
// wg.Done when it returns.
//
// A failed dial increments stats.Errors. A successful dial increments
// stats.Connected. stats.Disconnected is then incremented exactly once when
// the session ends, whichever way that happens: the server closes it, a read
// or ping fails, or the test duration elapses. This keeps
// Connected-Disconnected equal to the number of live clients.
func connectClient(id int, url string, duration time.Duration, wg *sync.WaitGroup) {
	defer wg.Done()

	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		atomic.AddInt64(&stats.Errors, 1)
		log.Printf("Client %d: Connection failed: %v", id, err)
		return
	}

	atomic.AddInt64(&stats.Connected, 1)
	log.Printf("Client %d: Connected", id)

	// done is closed by the reader goroutine once ReadMessage fails.
	done := make(chan struct{})

	// Read messages until the connection fails or is closed.
	go func() {
		defer close(done)
		messageCount := 0
		for {
			_, message, err := conn.ReadMessage()
			if err != nil {
				// Normal, going-away and abnormal closures are expected when a
				// session ends; anything else is worth reporting.
				if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
					log.Printf("Client %d: Read error: %v", id, err)
				}
				return
			}

			atomic.AddInt64(&stats.MessagesRecv, 1)
			messageCount++

			// Progress is logged only for payloads that decode into a map
			// (a JSON object or null).
			var data map[string]interface{}
			if err := json.Unmarshal(message, &data); err == nil && messageCount%10 == 0 {
				log.Printf("Client %d: Received %d messages", id, messageCount)
			}
		}
	}()

	// This is the single exit path for every way the session can end.
	// Closing conn unblocks the reader. Waiting on done guarantees the
	// reader has stopped touching stats before wg.Done runs. The
	// disconnect is counted exactly once, including after a ping failure.
	defer func() {
		conn.Close() // error ignored: the session is over either way
		<-done
		atomic.AddInt64(&stats.Disconnected, 1)
	}()

	// Keep connection alive for duration.
	timer := time.NewTimer(duration)
	defer timer.Stop()
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			return
		case <-timer.C:
			log.Printf("Client %d: Test duration completed, closing", id)
			closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
			if err := conn.WriteMessage(websocket.CloseMessage, closeMsg); err != nil {
				log.Printf("Client %d: Close message failed: %v", id, err)
			}
			return
		case <-ticker.C:
			// Send ping to keep alive.
			if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				log.Printf("Client %d: Ping failed: %v", id, err)
				return
			}
		}
	}
}

// printStats prints a snapshot of the shared stats counters to stdout once
// every interval, until stopChan is closed. The message rate is averaged
// over the wall-clock time since printStats started.
func printStats(interval time.Duration, stopChan chan struct{}) {
	const report = "\n=== Stats (%.0fs) ===\n" +
		"Connected:    %d\n" +
		"Disconnected: %d\n" +
		"Active:       %d\n" +
		"Messages:     %d (%.2f msg/s)\n" +
		"Errors:       %d\n" +
		"==================\n"

	tick := time.NewTicker(interval)
	defer tick.Stop()

	began := time.Now()

	for {
		select {
		case <-stopChan:
			return
		case <-tick.C:
		}

		secs := time.Since(began).Seconds()
		up := atomic.LoadInt64(&stats.Connected)
		down := atomic.LoadInt64(&stats.Disconnected)
		recv := atomic.LoadInt64(&stats.MessagesRecv)
		failed := atomic.LoadInt64(&stats.Errors)

		fmt.Printf(report, secs, up, down, up-down, recv, float64(recv)/secs, failed)
	}
}

// main runs the load test. It parses flags and starts a stats reporter that
// prints every 5 seconds. It then launches -clients clients, spread evenly
// over the -rampup window, waits for every client to finish, and prints the
// final totals.
func main() {
	clients := flag.Int("clients", 100, "Number of concurrent clients")
	duration := flag.Duration("duration", 60*time.Second, "Test duration")
	url := flag.String("url", "ws://localhost:8080/ws?limit=20", "WebSocket URL")
	rampUp := flag.Duration("rampup", 10*time.Second, "Ramp-up time to start all clients")
	flag.Parse()

	// The ramp-up delay below divides by the client count. Zero would panic
	// with an integer divide by zero, and a negative count is meaningless.
	if *clients <= 0 {
		log.Fatalf("-clients must be positive, got %d", *clients)
	}

	log.Printf("Starting load test:")
	log.Printf("  Clients: %d", *clients)
	log.Printf("  Duration: %s", *duration)
	log.Printf("  URL: %s", *url)
	log.Printf("  Ramp-up: %s", *rampUp)

	var wg sync.WaitGroup
	stopStats := make(chan struct{})

	// Start the stats printer. statsDone lets main wait for it to exit, so a
	// last periodic report cannot interleave with the final results.
	statsDone := make(chan struct{})
	go func() {
		defer close(statsDone)
		printStats(5*time.Second, stopStats)
	}()

	// Calculate delay between client starts.
	delayBetweenClients := *rampUp / time.Duration(*clients)

	// Start clients with ramp-up.
	startTime := time.Now()
	for i := 1; i <= *clients; i++ {
		wg.Add(1)
		go connectClient(i, *url, *duration, &wg)

		if i < *clients {
			time.Sleep(delayBetweenClients)
		}
	}

	log.Printf("All %d clients started in %s", *clients, time.Since(startTime))

	// Wait for all clients to finish, then for the stats printer to stop.
	wg.Wait()
	close(stopStats)
	<-statsDone

	// Final stats.
	fmt.Println("\n=== Final Results ===")
	fmt.Printf("Total Connected:    %d\n", atomic.LoadInt64(&stats.Connected))
	fmt.Printf("Total Disconnected: %d\n", atomic.LoadInt64(&stats.Disconnected))
	fmt.Printf("Total Messages:     %d\n", atomic.LoadInt64(&stats.MessagesRecv))
	fmt.Printf("Total Errors:       %d\n", atomic.LoadInt64(&stats.Errors))
	fmt.Println("====================")
}
