128 lines
5.3 KiB
Go
128 lines
5.3 KiB
Go
package test
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"llm-api-benchmark-tool/pkg/stats"
|
|
"github.com/stretchr/testify/assert"
|
|
// "gonum.org/v1/gonum/stat" // Might need later for percentile verification
|
|
)
|
|
|
|
// TestStatsCollector_BasicRecording tests basic recording and aggregation.
|
|
func TestStatsCollector_BasicRecording(t *testing.T) {
|
|
collector := stats.NewStatsCollector()
|
|
|
|
// Record some sample results
|
|
collector.RecordResult(stats.RequestResult{
|
|
IsSuccess: true,
|
|
Latency: 100 * time.Millisecond,
|
|
TimeToFirstToken: 50 * time.Millisecond, // Example TTFT
|
|
TotalTokens: 150, // Example token count
|
|
StreamLatency: nil, // Not a streaming result for this example
|
|
})
|
|
collector.RecordResult(stats.RequestResult{
|
|
IsSuccess: true,
|
|
Latency: 120 * time.Millisecond,
|
|
TimeToFirstToken: 60 * time.Millisecond,
|
|
TotalTokens: 160,
|
|
StreamLatency: nil,
|
|
})
|
|
collector.RecordResult(stats.RequestResult{
|
|
IsSuccess: false, // Failed request
|
|
Latency: 50 * time.Millisecond, // Failed quickly
|
|
TimeToFirstToken: 0, // No TTFT for failed requests
|
|
TotalTokens: 0, // No tokens for failed requests
|
|
StreamLatency: nil,
|
|
})
|
|
|
|
// Calculate final stats
|
|
finalStats := collector.CalculateStats(2 * time.Second) // Example total duration
|
|
|
|
assert.Equal(t, int64(3), finalStats.TotalRequests, "Total requests should be 3")
|
|
assert.Equal(t, int64(2), finalStats.SuccessfulRequests, "Successful requests should be 2")
|
|
assert.Equal(t, int64(1), finalStats.FailedRequests, "Failed requests should be 1")
|
|
|
|
// Basic average latency check (only successful)
|
|
// Expected avg = (100 + 120) / 2 = 110 ms
|
|
// Compare float64 seconds
|
|
assert.InDelta(t, 110*time.Millisecond.Seconds(), finalStats.AvgLatency.Seconds(), 0.01, "Average latency mismatch")
|
|
|
|
// Basic average TTFT check (only successful)
|
|
// Expected avg = (50 + 60) / 2 = 55 ms
|
|
// Compare float64 seconds
|
|
assert.InDelta(t, 55*time.Millisecond.Seconds(), finalStats.AvgTimeToFirstToken.Seconds(), 0.01, "Average TTFT mismatch")
|
|
|
|
// Check Min/Max Latency (successful only)
|
|
assert.Equal(t, 100*time.Millisecond, finalStats.MinLatency, "Min latency mismatch")
|
|
assert.Equal(t, 120*time.Millisecond, finalStats.MaxLatency, "Max latency mismatch")
|
|
|
|
// Check Min/Max TTFT (successful only)
|
|
assert.Equal(t, 50*time.Millisecond, finalStats.MinTimeToFirstToken, "Min TTFT mismatch")
|
|
assert.Equal(t, 60*time.Millisecond, finalStats.MaxTimeToFirstToken, "Max TTFT mismatch")
|
|
|
|
// Basic QPS check
|
|
// Expected QPS = 3 requests / 2 seconds = 1.5
|
|
assert.InDelta(t, 1.5, finalStats.AvgQPS, 0.01, "Average QPS mismatch")
|
|
|
|
// Basic Token Rate check (only successful)
|
|
// Total tokens = 150 + 160 = 310
|
|
// Total successful time = 100ms + 120ms = 220ms = 0.22s
|
|
// Expected rate = 310 / 0.22 = ~1409 tokens/sec
|
|
assert.InDelta(t, 310/0.22, finalStats.AvgTokensPerSecond, 1, "Average tokens/sec mismatch") // Larger delta due to potential floating point issues
|
|
|
|
// Check Percentiles (calculated manually for [100, 120] ms latency)
|
|
// P90 = 100*(1-0.9) + 120*0.9 = 10 + 108 = 118
|
|
// P95 = 100*(1-0.95) + 120*0.95 = 5 + 114 = 119
|
|
// P99 = 100*(1-0.99) + 120*0.99 = 1 + 118.8 = 119.8
|
|
assert.Equal(t, 118*time.Millisecond, finalStats.P90Latency, "P90 Latency mismatch")
|
|
assert.Equal(t, 119*time.Millisecond, finalStats.P95Latency, "P95 Latency mismatch")
|
|
assert.InDelta(t, 119.8*float64(time.Millisecond), float64(finalStats.P99Latency.Nanoseconds()), 0.1*float64(time.Millisecond), "P99 Latency mismatch")
|
|
|
|
// Check Percentiles (calculated manually for [50, 60] ms TTFT)
|
|
// P90 = 50*0.1 + 60*0.9 = 5 + 54 = 59
|
|
// P95 = 50*0.05 + 60*0.95 = 2.5 + 57 = 59.5
|
|
// P99 = 50*0.01 + 60*0.99 = 0.5 + 59.4 = 59.9
|
|
assert.Equal(t, 59*time.Millisecond, finalStats.P90TimeToFirstToken, "P90 TTFT mismatch")
|
|
assert.InDelta(t, 59.5*float64(time.Millisecond), float64(finalStats.P95TimeToFirstToken.Nanoseconds()), 0.1*float64(time.Millisecond), "P95 TTFT mismatch")
|
|
assert.InDelta(t, 59.9*float64(time.Millisecond), float64(finalStats.P99TimeToFirstToken.Nanoseconds()), 0.1*float64(time.Millisecond), "P99 TTFT mismatch")
|
|
}
|
|
|
|
// TestStatsCollector_ConcurrentRecording tests thread-safety.
|
|
func TestStatsCollector_ConcurrentRecording(t *testing.T) {
|
|
collector := stats.NewStatsCollector()
|
|
numGoroutines := 50
|
|
numRequestsPerGoroutine := 100
|
|
totalRequests := numGoroutines * numRequestsPerGoroutine
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func(routineID int) {
|
|
defer wg.Done()
|
|
for j := 0; j < numRequestsPerGoroutine; j++ {
|
|
latency := time.Duration(100+routineID+j) * time.Millisecond
|
|
ttft := time.Duration(50+routineID+j) * time.Millisecond
|
|
collector.RecordResult(stats.RequestResult{
|
|
IsSuccess: true,
|
|
Latency: latency,
|
|
TimeToFirstToken: ttft,
|
|
TotalTokens: 100,
|
|
StreamLatency: nil,
|
|
})
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
finalStats := collector.CalculateStats(10 * time.Second) // Arbitrary duration
|
|
|
|
assert.Equal(t, int64(totalRequests), finalStats.TotalRequests, "Total requests should match after concurrent writes")
|
|
assert.Equal(t, int64(totalRequests), finalStats.SuccessfulRequests, "All requests should be successful")
|
|
assert.Equal(t, int64(0), finalStats.FailedRequests, "No requests should fail")
|
|
// Add more assertions for concurrent results if necessary
|
|
}
|