// stats.go — per-request metric collection and final aggregation for benchmark runs.

package stats
import (
"log"
"sort"
"sync"
"time"
// "gonum.org/v1/gonum/stat" // Keep commented until needed for percentiles
)
// RequestResult holds the metrics for a single request.
type RequestResult struct {
IsSuccess bool // True when the request completed successfully.
Latency time.Duration // End-to-end latency of the request.
TimeToFirstToken time.Duration // Relevant for streaming and non-streaming (time to first byte of response body)
TotalTokens int // Tokens in the response content
StreamLatency []time.Duration // Latency for each chunk in a stream (if streaming)
// TODO: Add Error details? Model name?
}
// FinalStats holds the aggregated results of the benchmark run.
type FinalStats struct {
TotalRequests int64 // Successful + failed requests.
SuccessfulRequests int64 // Count of requests with IsSuccess == true.
FailedRequests int64 // Count of requests with IsSuccess == false.
TotalDuration time.Duration // Wall-clock duration of the run (supplied by the caller of CalculateStats).
AvgLatency time.Duration // Average latency of successful requests
MinLatency time.Duration // Minimum latency among successful requests (0 if none).
MaxLatency time.Duration // Maximum latency among successful requests (0 if none).
P90Latency time.Duration // 90th-percentile latency of successful requests.
P95Latency time.Duration // 95th-percentile latency of successful requests.
P99Latency time.Duration // 99th-percentile latency of successful requests.
AvgTimeToFirstToken time.Duration // Average TTFT of successful requests
MinTimeToFirstToken time.Duration // Minimum TTFT among successful requests with TTFT > 0.
MaxTimeToFirstToken time.Duration // Maximum TTFT among successful requests with TTFT > 0.
P90TimeToFirstToken time.Duration // 90th-percentile TTFT of successful requests.
P95TimeToFirstToken time.Duration // 95th-percentile TTFT of successful requests.
P99TimeToFirstToken time.Duration // 99th-percentile TTFT of successful requests.
AvgQPS float64 // Average Queries Per Second (TotalRequests / TotalDuration)
AvgTokensPerSecond float64 // Average tokens per second (TotalTokens / Total Successful Latency)
// Data for histograms
LatencyData []time.Duration `json:"-"` // Raw latency data points (exclude from default JSON)
TTFTData []time.Duration `json:"-"` // Raw TTFT data points (exclude from default JSON)
// TODO: Add MaxQPS? Token Rate per model?
}
// StatsCollector collects and aggregates statistics during the benchmark.
// It needs to be thread-safe: all fields below are guarded by mutex.
// Use NewStatsCollector to construct; the zero value is unusable
// (initialized stays false and results are dropped).
type StatsCollector struct {
mutex sync.Mutex // Guards every field below.
results []RequestResult // Store individual results for percentile calculation
startTime time.Time // Set once at construction; not read by CalculateStats in this file.
totalTokens int64 // Sum of TotalTokens over successful requests.
totalReqTime time.Duration // Sum of latencies for successful requests
successCount int64 // Number of successful requests recorded.
failCount int64 // Number of failed requests recorded.
minLatency time.Duration // Minimum successful latency; starts at max duration sentinel.
maxLatency time.Duration // Maximum successful latency; starts at 0.
minTTFT time.Duration // Minimum positive TTFT; starts at max duration sentinel.
maxTTFT time.Duration // Maximum positive TTFT; starts at 0.
initialized bool // True only for collectors built via NewStatsCollector.
// Raw data storage
latencies []time.Duration // Store raw latencies for histogram
ttfts []time.Duration // Store raw TTFTs for histogram
}
// NewStatsCollector creates a new thread-safe StatsCollector ready to record
// request results. Minimum trackers start at the largest representable
// duration so the first recorded sample always replaces them.
func NewStatsCollector() *StatsCollector {
	const maxDuration = time.Duration(1<<63 - 1) // sentinel: any real sample is smaller

	sc := &StatsCollector{
		startTime:   time.Now(),
		minLatency:  maxDuration,
		minTTFT:     maxDuration,
		initialized: true,
	}
	// Pre-allocate capacity to reduce append reallocations early in a run.
	sc.results = make([]RequestResult, 0, 1000)
	sc.latencies = make([]time.Duration, 0, 1000)
	sc.ttfts = make([]time.Duration, 0, 1000)
	return sc
}
// RecordResult records the outcome of a single request. Thread-safe.
// Failed requests only bump the failure counter; latency, token, and TTFT
// aggregates are maintained for successful requests alone, and TTFT only
// when the request reported a positive value.
func (sc *StatsCollector) RecordResult(result RequestResult) {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()

	// Defensive guard: a zero-value collector (not built via
	// NewStatsCollector) silently drops results.
	if !sc.initialized {
		return
	}

	log.Printf("[StatsCollector] Recording result: Success=%t, Latency=%s, TTFT=%s, Tokens=%d", result.IsSuccess, result.Latency, result.TimeToFirstToken, result.TotalTokens)

	// Every result is retained so percentiles can be computed at the end.
	sc.results = append(sc.results, result)

	if !result.IsSuccess {
		sc.failCount++
		return
	}

	sc.successCount++
	sc.totalTokens += int64(result.TotalTokens)
	sc.totalReqTime += result.Latency

	lat := result.Latency
	if lat < sc.minLatency {
		sc.minLatency = lat
	}
	if lat > sc.maxLatency {
		sc.maxLatency = lat
	}
	sc.latencies = append(sc.latencies, lat)

	// TTFT aggregates are tracked only for successful requests that
	// reported a positive (i.e. measured) time-to-first-token.
	if ttft := result.TimeToFirstToken; ttft > 0 {
		if ttft < sc.minTTFT {
			sc.minTTFT = ttft
		}
		if ttft > sc.maxTTFT {
			sc.maxTTFT = ttft
		}
		sc.ttfts = append(sc.ttfts, ttft)
	}
}
// CalculateStats computes the final statistics based on recorded results.
// totalDurationOverride allows specifying the exact benchmark duration used
// for QPS calculation. Thread-safe.
//
// Returns zeroed stats (with only TotalDuration set) when nothing was
// recorded. Min/Max/percentile latency fields stay zero when there were no
// successful requests.
func (sc *StatsCollector) CalculateStats(totalDurationOverride time.Duration) FinalStats {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()

	stats := FinalStats{
		TotalDuration: totalDurationOverride,
	}
	if !sc.initialized || len(sc.results) == 0 {
		return stats // Return empty stats if no results
	}

	stats.SuccessfulRequests = sc.successCount
	stats.FailedRequests = sc.failCount
	stats.TotalRequests = sc.successCount + sc.failCount
	if totalDurationOverride.Seconds() > 0 {
		// QPS counts failed requests too: it measures offered load.
		stats.AvgQPS = float64(stats.TotalRequests) / totalDurationOverride.Seconds()
	}

	if sc.successCount > 0 {
		stats.AvgLatency = sc.totalReqTime / time.Duration(sc.successCount)
		stats.MinLatency = sc.minLatency
		stats.MaxLatency = sc.maxLatency

		// RecordResult already keeps exactly the successful latency samples
		// in sc.latencies and the positive TTFT samples in sc.ttfts, in
		// arrival order — no need to rescan sc.results here. Convert to
		// float64 copies so sorting for percentiles cannot reorder the raw
		// histogram data handed out below.
		successfulLatencies := make([]float64, len(sc.latencies))
		for i, d := range sc.latencies {
			successfulLatencies[i] = float64(d.Nanoseconds())
		}
		successfulTTFTs := make([]float64, len(sc.ttfts))
		var totalTTFT time.Duration
		for i, d := range sc.ttfts {
			successfulTTFTs[i] = float64(d.Nanoseconds())
			totalTTFT += d
		}

		// Average/min/max TTFT only when at least one request measured it.
		if len(successfulTTFTs) > 0 {
			stats.AvgTimeToFirstToken = totalTTFT / time.Duration(len(successfulTTFTs))
			stats.MinTimeToFirstToken = sc.minTTFT
			stats.MaxTimeToFirstToken = sc.maxTTFT
		}

		// Token throughput over the cumulative latency of successful
		// requests (not wall-clock time).
		if sc.totalReqTime.Seconds() > 0 {
			stats.AvgTokensPerSecond = float64(sc.totalTokens) / sc.totalReqTime.Seconds()
		}

		// --- Percentile calculations (sort + linear interpolation) ---
		sort.Float64s(successfulLatencies)
		if len(successfulLatencies) > 0 {
			stats.P90Latency = time.Duration(percentile(successfulLatencies, 90)) * time.Nanosecond
			stats.P95Latency = time.Duration(percentile(successfulLatencies, 95)) * time.Nanosecond
			stats.P99Latency = time.Duration(percentile(successfulLatencies, 99)) * time.Nanosecond
		}
		sort.Float64s(successfulTTFTs)
		if len(successfulTTFTs) > 0 {
			stats.P90TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 90)) * time.Nanosecond
			stats.P95TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 95)) * time.Nanosecond
			stats.P99TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 99)) * time.Nanosecond
		}
	}
	// With no successful requests the Min/Max fields keep their zero values,
	// matching the explicit zero assignments of the previous implementation.

	// Hand out copies of the raw samples so callers cannot mutate the
	// collector's internal slices through FinalStats.
	stats.LatencyData = make([]time.Duration, len(sc.latencies))
	copy(stats.LatencyData, sc.latencies)
	stats.TTFTData = make([]time.Duration, len(sc.ttfts))
	copy(stats.TTFTData, sc.ttfts)
	return stats
}
// percentile returns the p-th percentile of data using linear interpolation
// between the two nearest ranks. data must already be sorted ascending.
// An empty slice yields 0; p is clamped to [0, 100].
func percentile(data []float64, p float64) float64 {
	n := len(data)
	switch {
	case n == 0:
		return 0
	case p <= 0:
		return data[0]
	case p >= 100:
		return data[n-1]
	}

	rank := (p / 100.0) * float64(n-1)
	lo := int(rank)
	hi := lo + 1
	if hi >= n {
		// Reachable only when rank lands exactly on the final element
		// (e.g. a single-element slice).
		return data[lo]
	}
	w := rank - float64(lo)
	return data[lo]*(1-w) + data[hi]*w
}