228 lines
7.7 KiB
Go
228 lines
7.7 KiB
Go
package stats
|
|
|
|
import (
|
|
"log"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
// "gonum.org/v1/gonum/stat" // Keep commented until needed for percentiles
|
|
)
|
|
|
|
// RequestResult holds the metrics for a single request.
type RequestResult struct {
	// IsSuccess reports whether the request completed successfully.
	IsSuccess bool
	// Latency is the total end-to-end duration of the request.
	Latency time.Duration
	// TimeToFirstToken is the time until the first byte of the response
	// body arrived. Relevant for streaming and non-streaming requests.
	TimeToFirstToken time.Duration
	// TotalTokens is the number of tokens in the response content.
	TotalTokens int
	// StreamLatency holds the latency of each chunk when streaming.
	StreamLatency []time.Duration
	// TODO: Add Error details? Model name?
}
|
|
|
|
// FinalStats holds the aggregated results of the benchmark run.
// Latency and TTFT aggregates are computed over successful requests only
// (see StatsCollector.CalculateStats).
type FinalStats struct {
	TotalRequests      int64 // Successful + failed requests recorded.
	SuccessfulRequests int64
	FailedRequests     int64
	TotalDuration      time.Duration // Wall-clock duration of the benchmark run.

	// End-to-end latency aggregates (successful requests only).
	AvgLatency time.Duration // Average latency of successful requests
	MinLatency time.Duration
	MaxLatency time.Duration
	P90Latency time.Duration
	P95Latency time.Duration
	P99Latency time.Duration

	// Time-to-first-token aggregates (successful requests with TTFT > 0 only).
	AvgTimeToFirstToken time.Duration // Average TTFT of successful requests
	MinTimeToFirstToken time.Duration
	MaxTimeToFirstToken time.Duration
	P90TimeToFirstToken time.Duration
	P95TimeToFirstToken time.Duration
	P99TimeToFirstToken time.Duration

	AvgQPS             float64 // Average Queries Per Second (TotalRequests / TotalDuration)
	AvgTokensPerSecond float64 // Average tokens per second (TotalTokens / Total Successful Latency)

	// Raw data points for histogram rendering; excluded from default JSON
	// output because they can be large.
	LatencyData []time.Duration `json:"-"` // Raw latency data points (exclude from default JSON)
	TTFTData    []time.Duration `json:"-"` // Raw TTFT data points (exclude from default JSON)
	// TODO: Add MaxQPS? Token Rate per model?
}
|
|
|
|
// StatsCollector collects and aggregates statistics during the benchmark.
// All exported methods are thread-safe: every field below is guarded by
// mutex. Construct with NewStatsCollector; the zero value is unusable
// (initialized stays false and results are dropped).
type StatsCollector struct {
	mutex        sync.Mutex
	results      []RequestResult // Store individual results for percentile calculation
	startTime    time.Time       // Set at construction; candidate fallback for duration (currently unused).
	totalTokens  int64           // Sum of TotalTokens over successful requests.
	totalReqTime time.Duration   // Sum of latencies for successful requests
	successCount int64
	failCount    int64
	// Running min/max trackers. The mins start at the maximum possible
	// duration (set in NewStatsCollector) so the first sample always wins.
	minLatency  time.Duration
	maxLatency  time.Duration
	minTTFT     time.Duration
	maxTTFT     time.Duration
	initialized bool // True only when built via NewStatsCollector.
	// Raw data storage (successful requests only).
	latencies []time.Duration // Store raw latencies for histogram
	ttfts     []time.Duration // Store raw TTFTs for histogram
}
|
|
|
|
// NewStatsCollector creates a new thread-safe StatsCollector.
|
|
func NewStatsCollector() *StatsCollector {
|
|
return &StatsCollector{
|
|
results: make([]RequestResult, 0, 1000), // Pre-allocate some capacity
|
|
latencies: make([]time.Duration, 0, 1000), // Initialize
|
|
ttfts: make([]time.Duration, 0, 1000), // Initialize
|
|
startTime: time.Now(),
|
|
minLatency: time.Duration(1<<63 - 1), // Max duration as initial min
|
|
minTTFT: time.Duration(1<<63 - 1), // Max duration as initial min
|
|
initialized: true,
|
|
}
|
|
}
|
|
|
|
// RecordResult records the outcome of a single request. Thread-safe.
|
|
func (sc *StatsCollector) RecordResult(result RequestResult) {
|
|
sc.mutex.Lock()
|
|
defer sc.mutex.Unlock()
|
|
|
|
if !sc.initialized {
|
|
// This shouldn't happen with NewStatsCollector, but defensive check
|
|
return
|
|
}
|
|
|
|
log.Printf("[StatsCollector] Recording result: Success=%t, Latency=%s, TTFT=%s, Tokens=%d", result.IsSuccess, result.Latency, result.TimeToFirstToken, result.TotalTokens)
|
|
sc.results = append(sc.results, result) // Append for percentile calculation
|
|
|
|
if result.IsSuccess {
|
|
sc.successCount++
|
|
sc.totalTokens += int64(result.TotalTokens)
|
|
sc.totalReqTime += result.Latency
|
|
|
|
if result.Latency < sc.minLatency {
|
|
sc.minLatency = result.Latency
|
|
}
|
|
if result.Latency > sc.maxLatency {
|
|
sc.maxLatency = result.Latency
|
|
}
|
|
sc.latencies = append(sc.latencies, result.Latency)
|
|
|
|
// TTFT calculations only for successful requests
|
|
if result.TimeToFirstToken > 0 { // Ensure valid TTFT
|
|
if result.TimeToFirstToken < sc.minTTFT {
|
|
sc.minTTFT = result.TimeToFirstToken
|
|
}
|
|
if result.TimeToFirstToken > sc.maxTTFT {
|
|
sc.maxTTFT = result.TimeToFirstToken
|
|
}
|
|
sc.ttfts = append(sc.ttfts, result.TimeToFirstToken)
|
|
}
|
|
} else {
|
|
sc.failCount++
|
|
}
|
|
}
|
|
|
|
// CalculateStats computes the final statistics based on recorded results.
|
|
// totalDurationOverride allows specifying the exact benchmark duration used for QPS calculation.
|
|
func (sc *StatsCollector) CalculateStats(totalDurationOverride time.Duration) FinalStats {
|
|
sc.mutex.Lock()
|
|
defer sc.mutex.Unlock()
|
|
|
|
stats := FinalStats{
|
|
TotalDuration: totalDurationOverride,
|
|
}
|
|
|
|
if !sc.initialized || len(sc.results) == 0 {
|
|
return stats // Return empty stats if no results
|
|
}
|
|
|
|
stats.SuccessfulRequests = sc.successCount
|
|
stats.FailedRequests = sc.failCount
|
|
stats.TotalRequests = sc.successCount + sc.failCount
|
|
|
|
if totalDurationOverride.Seconds() > 0 {
|
|
stats.AvgQPS = float64(stats.TotalRequests) / totalDurationOverride.Seconds()
|
|
}
|
|
|
|
if sc.successCount > 0 {
|
|
stats.AvgLatency = sc.totalReqTime / time.Duration(sc.successCount)
|
|
stats.MinLatency = sc.minLatency
|
|
stats.MaxLatency = sc.maxLatency
|
|
|
|
// Filter successful results for latency/TTFT percentile calculations
|
|
successfulLatencies := make([]float64, 0, sc.successCount)
|
|
successfulTTFTs := make([]float64, 0, sc.successCount)
|
|
var totalTTFT time.Duration
|
|
|
|
for _, res := range sc.results {
|
|
if res.IsSuccess {
|
|
successfulLatencies = append(successfulLatencies, float64(res.Latency.Nanoseconds()))
|
|
if res.TimeToFirstToken > 0 {
|
|
successfulTTFTs = append(successfulTTFTs, float64(res.TimeToFirstToken.Nanoseconds()))
|
|
totalTTFT += res.TimeToFirstToken
|
|
}
|
|
}
|
|
}
|
|
|
|
// Calculate Avg TTFT
|
|
if len(successfulTTFTs) > 0 {
|
|
stats.AvgTimeToFirstToken = totalTTFT / time.Duration(len(successfulTTFTs))
|
|
stats.MinTimeToFirstToken = sc.minTTFT
|
|
stats.MaxTimeToFirstToken = sc.maxTTFT
|
|
}
|
|
|
|
// Calculate average tokens/sec based on successful requests' total time
|
|
if sc.totalReqTime.Seconds() > 0 {
|
|
stats.AvgTokensPerSecond = float64(sc.totalTokens) / sc.totalReqTime.Seconds()
|
|
}
|
|
|
|
// --- Percentile Calculations (using basic sort for now) ---
|
|
sort.Float64s(successfulLatencies)
|
|
if len(successfulLatencies) > 0 {
|
|
stats.P90Latency = time.Duration(percentile(successfulLatencies, 90)) * time.Nanosecond
|
|
stats.P95Latency = time.Duration(percentile(successfulLatencies, 95)) * time.Nanosecond
|
|
stats.P99Latency = time.Duration(percentile(successfulLatencies, 99)) * time.Nanosecond
|
|
}
|
|
|
|
sort.Float64s(successfulTTFTs)
|
|
if len(successfulTTFTs) > 0 {
|
|
stats.P90TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 90)) * time.Nanosecond
|
|
stats.P95TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 95)) * time.Nanosecond
|
|
stats.P99TimeToFirstToken = time.Duration(percentile(successfulTTFTs, 99)) * time.Nanosecond
|
|
}
|
|
|
|
} else {
|
|
// Handle cases where there are no successful requests
|
|
stats.MinLatency = 0
|
|
stats.MaxLatency = 0
|
|
stats.MinTimeToFirstToken = 0
|
|
stats.MaxTimeToFirstToken = 0
|
|
}
|
|
|
|
// Assign raw data for histograms
|
|
stats.LatencyData = make([]time.Duration, len(sc.latencies))
|
|
copy(stats.LatencyData, sc.latencies)
|
|
stats.TTFTData = make([]time.Duration, len(sc.ttfts))
|
|
copy(stats.TTFTData, sc.ttfts)
|
|
|
|
return stats
|
|
}
|
|
|
|
// percentile returns the p-th percentile of data using linear interpolation
// between the two nearest ranks. data must already be sorted ascending;
// an empty slice yields 0, and p is clamped to [0, 100].
func percentile(data []float64, p float64) float64 {
	n := len(data)
	if n == 0 {
		return 0
	}
	switch {
	case p <= 0:
		return data[0]
	case p >= 100:
		return data[n-1]
	}

	pos := (p / 100.0) * float64(n-1)
	lo := int(pos)
	hi := lo + 1
	if hi >= n {
		// Defensive: only reachable when pos lands exactly on the last index.
		return data[lo]
	}
	frac := pos - float64(lo)
	return data[lo]*(1-frac) + data[hi]*frac
}
|