// Package stats collects per-request benchmark metrics and aggregates them
// into summary statistics (latency/TTFT min/max/avg/percentiles, QPS, tokens/s).
package stats

import (
	"log"
	"sort"
	"sync"
	"time"
	// "gonum.org/v1/gonum/stat" // Keep commented until needed for percentiles
)

// RequestResult holds the metrics for a single request.
type RequestResult struct {
	IsSuccess        bool
	Latency          time.Duration
	TimeToFirstToken time.Duration   // Relevant for streaming and non-streaming (time to first byte of response body)
	TotalTokens      int             // Tokens in the response content
	StreamLatency    []time.Duration // Latency for each chunk in a stream (if streaming)
	// TODO: Add Error details? Model name?
}

// FinalStats holds the aggregated results of the benchmark run.
type FinalStats struct {
	TotalRequests      int64
	SuccessfulRequests int64
	FailedRequests     int64
	TotalDuration      time.Duration

	// Latency statistics, computed over successful requests only.
	AvgLatency time.Duration // Average latency of successful requests
	MinLatency time.Duration
	MaxLatency time.Duration
	P90Latency time.Duration
	P95Latency time.Duration
	P99Latency time.Duration

	// Time-to-first-token statistics, computed over successful requests
	// that reported a positive TTFT.
	AvgTimeToFirstToken time.Duration // Average TTFT of successful requests
	MinTimeToFirstToken time.Duration
	MaxTimeToFirstToken time.Duration
	P90TimeToFirstToken time.Duration
	P95TimeToFirstToken time.Duration
	P99TimeToFirstToken time.Duration

	AvgQPS             float64 // Average Queries Per Second (TotalRequests / TotalDuration)
	AvgTokensPerSecond float64 // Average tokens per second (TotalTokens / Total Successful Latency)

	// Raw data points for histograms; excluded from default JSON output.
	LatencyData []time.Duration `json:"-"` // Raw latency data points (exclude from default JSON)
	TTFTData    []time.Duration `json:"-"` // Raw TTFT data points (exclude from default JSON)
	// TODO: Add MaxQPS? Token Rate per model?
}

// StatsCollector collects and aggregates statistics during the benchmark.
// It needs to be thread-safe.
type StatsCollector struct { mutex sync.Mutex results []RequestResult // Store individual results for percentile calculation startTime time.Time totalTokens int64 totalReqTime time.Duration // Sum of latencies for successful requests successCount int64 failCount int64 minLatency time.Duration maxLatency time.Duration minTTFT time.Duration maxTTFT time.Duration initialized bool // Raw data storage latencies []time.Duration // Store raw latencies for histogram ttfts []time.Duration // Store raw TTFTs for histogram } // NewStatsCollector creates a new thread-safe StatsCollector. func NewStatsCollector() *StatsCollector { return &StatsCollector{ results: make([]RequestResult, 0, 1000), // Pre-allocate some capacity latencies: make([]time.Duration, 0, 1000), // Initialize ttfts: make([]time.Duration, 0, 1000), // Initialize startTime: time.Now(), minLatency: time.Duration(1<<63 - 1), // Max duration as initial min minTTFT: time.Duration(1<<63 - 1), // Max duration as initial min initialized: true, } } // RecordResult records the outcome of a single request. Thread-safe. 
func (sc *StatsCollector) RecordResult(result RequestResult) { sc.mutex.Lock() defer sc.mutex.Unlock() if !sc.initialized { // This shouldn't happen with NewStatsCollector, but defensive check return } log.Printf("[StatsCollector] Recording result: Success=%t, Latency=%s, TTFT=%s, Tokens=%d", result.IsSuccess, result.Latency, result.TimeToFirstToken, result.TotalTokens) sc.results = append(sc.results, result) // Append for percentile calculation if result.IsSuccess { sc.successCount++ sc.totalTokens += int64(result.TotalTokens) sc.totalReqTime += result.Latency if result.Latency < sc.minLatency { sc.minLatency = result.Latency } if result.Latency > sc.maxLatency { sc.maxLatency = result.Latency } sc.latencies = append(sc.latencies, result.Latency) // TTFT calculations only for successful requests if result.TimeToFirstToken > 0 { // Ensure valid TTFT if result.TimeToFirstToken < sc.minTTFT { sc.minTTFT = result.TimeToFirstToken } if result.TimeToFirstToken > sc.maxTTFT { sc.maxTTFT = result.TimeToFirstToken } sc.ttfts = append(sc.ttfts, result.TimeToFirstToken) } } else { sc.failCount++ } } // CalculateStats computes the final statistics based on recorded results. // totalDurationOverride allows specifying the exact benchmark duration used for QPS calculation. 
func (sc *StatsCollector) CalculateStats(totalDurationOverride time.Duration) FinalStats { sc.mutex.Lock() defer sc.mutex.Unlock() stats := FinalStats{ TotalDuration: totalDurationOverride, } if !sc.initialized || len(sc.results) == 0 { return stats // Return empty stats if no results } stats.SuccessfulRequests = sc.successCount stats.FailedRequests = sc.failCount stats.TotalRequests = sc.successCount + sc.failCount if totalDurationOverride.Seconds() > 0 { stats.AvgQPS = float64(stats.TotalRequests) / totalDurationOverride.Seconds() } if sc.successCount > 0 { stats.AvgLatency = sc.totalReqTime / time.Duration(sc.successCount) stats.MinLatency = sc.minLatency stats.MaxLatency = sc.maxLatency // Filter successful results for latency/TTFT percentile calculations successfulLatencies := make([]float64, 0, sc.successCount) successfulTTFTs := make([]float64, 0, sc.successCount) var totalTTFT time.Duration for _, res := range sc.results { if res.IsSuccess { successfulLatencies = append(successfulLatencies, float64(res.Latency.Nanoseconds())) if res.TimeToFirstToken > 0 { successfulTTFTs = append(successfulTTFTs, float64(res.TimeToFirstToken.Nanoseconds())) totalTTFT += res.TimeToFirstToken } } } // Calculate Avg TTFT if len(successfulTTFTs) > 0 { stats.AvgTimeToFirstToken = totalTTFT / time.Duration(len(successfulTTFTs)) stats.MinTimeToFirstToken = sc.minTTFT stats.MaxTimeToFirstToken = sc.maxTTFT } // Calculate average tokens/sec based on successful requests' total time if sc.totalReqTime.Seconds() > 0 { stats.AvgTokensPerSecond = float64(sc.totalTokens) / sc.totalReqTime.Seconds() } // --- Percentile Calculations (using basic sort for now) --- sort.Float64s(successfulLatencies) if len(successfulLatencies) > 0 { stats.P90Latency = time.Duration(percentile(successfulLatencies, 90)) * time.Nanosecond stats.P95Latency = time.Duration(percentile(successfulLatencies, 95)) * time.Nanosecond stats.P99Latency = time.Duration(percentile(successfulLatencies, 99)) * time.Nanosecond 
} sort.Float64s(successfulTTFTs) if len(successfulTTFTs) > 0 { stats.P90TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 90)) * time.Nanosecond stats.P95TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 95)) * time.Nanosecond stats.P99TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 99)) * time.Nanosecond } } else { // Handle cases where there are no successful requests stats.MinLatency = 0 stats.MaxLatency = 0 stats.MinTimeToFirstToken = 0 stats.MaxTimeToFirstToken = 0 } // Assign raw data for histograms stats.LatencyData = make([]time.Duration, len(sc.latencies)) copy(stats.LatencyData, sc.latencies) stats.TTFTData = make([]time.Duration, len(sc.ttfts)) copy(stats.TTFTData, sc.ttfts) return stats } // percentile calculates the p-th percentile of a sorted float64 slice. // Basic implementation, assumes data is already sorted. func percentile(data []float64, p float64) float64 { if len(data) == 0 { return 0 } if p <= 0 { return data[0] } if p >= 100 { return data[len(data)-1] } index := (p / 100.0) * float64(len(data)-1) lower := int(index) upper := lower + 1 if upper >= len(data) { return data[lower] // Should only happen if index is exactly len(data)-1 } weight := index - float64(lower) return data[lower]*(1-weight) + data[upper]*weight }