package test

import (
	"sync"
	"testing"
	"time"

	"llm-api-benchmark-tool/pkg/stats"

	"github.com/stretchr/testify/assert"
	// "gonum.org/v1/gonum/stat" // Might need later for percentile verification
)

// TestStatsCollector_BasicRecording records two successful results and one
// failed result on a fresh collector, then checks the counters, averages,
// min/max values, QPS, token rate and percentiles that CalculateStats reports
// for a 2-second run.
func TestStatsCollector_BasicRecording(t *testing.T) {
	const (
		ms = time.Millisecond
		// nsPerMs is one millisecond expressed in nanoseconds as a float64,
		// used where an expected percentile is not a whole number of ms.
		nsPerMs = float64(time.Millisecond)
		// percentileTolerance is the allowed error (0.1 ms, in nanoseconds)
		// for fractional percentile expectations.
		percentileTolerance = 0.1 * nsPerMs
	)
	check := assert.New(t)

	// Fixture: two successful non-streaming requests and one request that
	// failed quickly (no TTFT, no tokens).
	samples := []stats.RequestResult{
		{
			IsSuccess:        true,
			Latency:          100 * ms,
			TimeToFirstToken: 50 * ms,
			TotalTokens:      150,
			StreamLatency:    nil,
		},
		{
			IsSuccess:        true,
			Latency:          120 * ms,
			TimeToFirstToken: 60 * ms,
			TotalTokens:      160,
			StreamLatency:    nil,
		},
		{
			IsSuccess:        false,
			Latency:          50 * ms,
			TimeToFirstToken: 0,
			TotalTokens:      0,
			StreamLatency:    nil,
		},
	}

	collector := stats.NewStatsCollector()
	for _, sample := range samples {
		collector.RecordResult(sample)
	}

	// Aggregate over an example total run duration of two seconds.
	finalStats := collector.CalculateStats(2 * time.Second)

	// Request counters.
	check.Equal(int64(3), finalStats.TotalRequests, "Total requests should be 3")
	check.Equal(int64(2), finalStats.SuccessfulRequests, "Successful requests should be 2")
	check.Equal(int64(1), finalStats.FailedRequests, "Failed requests should be 1")

	// Averages over the successful requests only, compared as float seconds:
	// latency (100 + 120) / 2 = 110 ms, TTFT (50 + 60) / 2 = 55 ms.
	check.InDelta(110*ms.Seconds(), finalStats.AvgLatency.Seconds(), 0.01, "Average latency mismatch")
	check.InDelta(55*ms.Seconds(), finalStats.AvgTimeToFirstToken.Seconds(), 0.01, "Average TTFT mismatch")

	// Extremes over the successful requests only.
	check.Equal(100*ms, finalStats.MinLatency, "Min latency mismatch")
	check.Equal(120*ms, finalStats.MaxLatency, "Max latency mismatch")
	check.Equal(50*ms, finalStats.MinTimeToFirstToken, "Min TTFT mismatch")
	check.Equal(60*ms, finalStats.MaxTimeToFirstToken, "Max TTFT mismatch")

	// QPS: 3 requests / 2 seconds = 1.5.
	check.InDelta(1.5, finalStats.AvgQPS, 0.01, "Average QPS mismatch")

	// Token rate over the successful requests: (150 + 160) tokens divided by
	// (100 ms + 120 ms) = 310 / 0.22, roughly 1409 tokens/sec. The delta is
	// wider here to absorb floating-point error.
	check.InDelta(310/0.22, finalStats.AvgTokensPerSecond, 1, "Average tokens/sec mismatch")

	// Latency percentiles, interpolated by hand between 100 ms and 120 ms:
	//   P90 = 100*0.10 + 120*0.90 = 118
	//   P95 = 100*0.05 + 120*0.95 = 119
	//   P99 = 100*0.01 + 120*0.99 = 119.8
	check.Equal(118*ms, finalStats.P90Latency, "P90 Latency mismatch")
	check.Equal(119*ms, finalStats.P95Latency, "P95 Latency mismatch")
	check.InDelta(119.8*nsPerMs, float64(finalStats.P99Latency.Nanoseconds()), percentileTolerance, "P99 Latency mismatch")

	// TTFT percentiles, interpolated by hand between 50 ms and 60 ms:
	//   P90 = 50*0.10 + 60*0.90 = 59
	//   P95 = 50*0.05 + 60*0.95 = 59.5
	//   P99 = 50*0.01 + 60*0.99 = 59.9
	check.Equal(59*ms, finalStats.P90TimeToFirstToken, "P90 TTFT mismatch")
	check.InDelta(59.5*nsPerMs, float64(finalStats.P95TimeToFirstToken.Nanoseconds()), percentileTolerance, "P95 TTFT mismatch")
	check.InDelta(59.9*nsPerMs, float64(finalStats.P99TimeToFirstToken.Nanoseconds()), percentileTolerance, "P99 TTFT mismatch")
}

//
TestStatsCollector_ConcurrentRecording tests thread-safety. func TestStatsCollector_ConcurrentRecording(t *testing.T) { collector := stats.NewStatsCollector() numGoroutines := 50 numRequestsPerGoroutine := 100 totalRequests := numGoroutines * numRequestsPerGoroutine var wg sync.WaitGroup wg.Add(numGoroutines) for i := 0; i < numGoroutines; i++ { go func(routineID int) { defer wg.Done() for j := 0; j < numRequestsPerGoroutine; j++ { latency := time.Duration(100+routineID+j) * time.Millisecond ttft := time.Duration(50+routineID+j) * time.Millisecond collector.RecordResult(stats.RequestResult{ IsSuccess: true, Latency: latency, TimeToFirstToken: ttft, TotalTokens: 100, StreamLatency: nil, }) } }(i) } wg.Wait() finalStats := collector.CalculateStats(10 * time.Second) // Arbitrary duration assert.Equal(t, int64(totalRequests), finalStats.TotalRequests, "Total requests should match after concurrent writes") assert.Equal(t, int64(totalRequests), finalStats.SuccessfulRequests, "All requests should be successful") assert.Equal(t, int64(0), finalStats.FailedRequests, "No requests should fail") // Add more assertions for concurrent results if necessary }