llm-api-benchmark-tool/test/integration_test.go

211 lines
8.2 KiB
Go

package test
import (
	"context"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"llm-api-benchmark-tool/pkg/client"
	"llm-api-benchmark-tool/pkg/concurrency"
	"llm-api-benchmark-tool/pkg/config"
	"llm-api-benchmark-tool/pkg/report"
	"llm-api-benchmark-tool/pkg/stats"
)
// dms converts an integer count of milliseconds into a time.Duration.
func dms(ms int) time.Duration {
	return time.Millisecond * time.Duration(ms)
}
// TestBenchmarkIntegration performs an end-to-end test of the benchmark tool:
// it starts a mock OpenAI-style chat-completions server, runs the concurrency
// manager against it, collects stats, generates the HTML report, and verifies
// the report's contents.
func TestBenchmarkIntegration(t *testing.T) {
	// --- 1. Set up Mock HTTP Server ---
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
		// Simulate some processing time: 50-100ms latency.
		time.Sleep(dms(50 + int(time.Now().UnixMilli()%50)))
		// Detect streaming requests from the "stream" field of the JSON
		// request body (OpenAI-compatible convention). A missing or
		// malformed body is treated as non-streaming, best effort.
		var reqBody struct {
			Stream bool `json:"stream"`
		}
		if raw, err := io.ReadAll(r.Body); err == nil {
			_ = json.Unmarshal(raw, &reqBody)
		}
		if reqBody.Stream {
			// Simulate SSE streaming.
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")
			flusher, ok := w.(http.Flusher)
			if !ok {
				http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
				return
			}
			// Send the first token quickly so TTFT lands in the 10-20ms range.
			time.Sleep(dms(10 + int(time.Now().UnixMilli()%10)))
			fmt.Fprintf(w, "id: msg1\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Hello\"}}\n\n")
			flusher.Flush()
			// Send subsequent events.
			time.Sleep(dms(30))
			fmt.Fprintf(w, "id: msg2\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \" world!\"}}\n\n")
			flusher.Flush()
			// Send stop event.
			time.Sleep(dms(10))
			fmt.Fprintf(w, "id: msg3\nevent: message_stop\ndata: {\"type\": \"message_stop\"}\n\n")
			flusher.Flush()
		} else {
			// Simulate a non-streaming response with a valid-looking
			// chat-completion JSON body.
			w.Header().Set("Content-Type", "application/json")
			responseBody := `{
"id": "chatcmpl-mock123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-mock-3.5",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a mock response."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}`
			fmt.Fprintln(w, responseBody)
		}
	})
	server := httptest.NewServer(mux)
	defer server.Close()

	// --- 2. Set up Configuration ---
	testDuration := 100 * time.Millisecond // Short duration for testing
	cfg := config.Config{
		APIEndpoint:  server.URL + "/v1/chat/completions", // Use mock server URL
		Model:        "gpt-mock-3.5",
		Client:       "fasthttp", // Use fasthttp for testing
		Timeout:      5 * time.Second,
		Concurrency:  2, // Low concurrency for test
		Duration:     testDuration,
		Requests:     10, // Limit requests for test (Note: Duration takes precedence)
		RateLimit:    0,  // No rate limit for test
		Streaming:    false, // Test non-streaming first
		PromptTokens: 10,
		MaxTokens:    20,
		OutputReport: filepath.Join(t.TempDir(), "integration_report.html"),
		Headers:      make(map[string]string),
		Payload:      make(map[string]interface{}),
	}

	// --- 3. Initialize Components ---
	httpClient := client.NewFastHTTPClient()
	require.NotNil(t, httpClient, "Failed to create HTTP client")
	statsCollector := stats.NewStatsCollector()
	params := concurrency.BenchmarkParams{
		TargetURL:       cfg.APIEndpoint,
		HTTPClient:      httpClient,
		TotalDuration:   cfg.Duration,
		Concurrency:     cfg.Concurrency,
		RequestInterval: 0, // No artificial delay between requests
		IsStreaming:     cfg.Streaming,
		PromptTokens:    cfg.PromptTokens,
	}
	manager := concurrency.NewManager(params, statsCollector)
	require.NotNil(t, manager, "Manager should not be nil")

	// --- 4. Run Benchmark ---
	t.Logf("Running benchmark against mock server: %s", cfg.APIEndpoint)
	startTime := time.Now()
	err := manager.RunBenchmark(context.Background())
	actualDuration := time.Since(startTime)
	require.NoError(t, err, "Benchmark run failed")

	// --- 5. Get Final Stats ---
	finalStats := statsCollector.CalculateStats(actualDuration)
	t.Logf("Benchmark finished. Final stats: %+v", finalStats)
	// At least one request must have completed within the test window.
	if finalStats.TotalRequests <= 0 {
		t.Errorf("Expected TotalRequests (%d) to be greater than 0", finalStats.TotalRequests)
	}
	// Allow at most one failure (e.g. a request cut off at shutdown).
	if finalStats.FailedRequests > 1 {
		t.Errorf("Expected FailedRequests (%d) to be less than or equal to 1", finalStats.FailedRequests)
	}

	// --- 6. Generate Report ---
	err = report.GenerateHTMLReport(finalStats, cfg.OutputReport)
	require.NoError(t, err, "Failed to generate HTML report")
	t.Logf("Report generated at: %s", cfg.OutputReport)

	// --- 7. Verify Report Content ---
	reportBytes, err := os.ReadFile(cfg.OutputReport)
	require.NoError(t, err, "Failed to read generated report file")
	reportContent := string(reportBytes)
	// Check for basic summary stats in the report (adjust format as needed).
	assert.Contains(t, reportContent, "Total Requests:", "Report missing actual summary data (e.g., 'Total Requests:')")
	assert.Contains(t, reportContent, fmt.Sprintf("Total Requests:</td><td>%d", finalStats.TotalRequests), "Report missing or incorrect total requests")
	assert.Contains(t, reportContent, fmt.Sprintf("Successful Requests:</td><td>%d", finalStats.SuccessfulRequests), "Report missing or incorrect successful requests")
	assert.Contains(t, reportContent, "Avg Latency:", "Report missing avg latency")
	// Check for chart divs.
	assert.Contains(t, reportContent, "<div id=\"latencyHistogram\"", "Report missing latency histogram div")
	assert.Contains(t, reportContent, "<div id=\"ttftHistogram\"", "Report missing TTFT histogram div")
	// Check for chart initialization JS (presence, not exact content).
	assert.Contains(t, reportContent, "echarts.init(document.getElementById('latencyHistogram')", "Report missing latency chart JS init")
	assert.Contains(t, reportContent, "echarts.init(document.getElementById('ttftHistogram')", "Report missing TTFT chart JS init")
	// Check that histogram data is not just the default empty/"No Data" case.
	// Locate the data portion of the JS options; strings.Index returns -1
	// when not found, so "found" is >= 0 (index 0 is a valid match).
	latencyOptionStart := strings.Index(reportContent, "let option_latencyHistogram = {")
	require.True(t, latencyOptionStart >= 0, "Latency chart options not found in JS")
	latencyOptionEnd := strings.Index(reportContent[latencyOptionStart:], "};")
	require.True(t, latencyOptionEnd >= 0, "Latency chart options end not found")
	latencyOptionJS := reportContent[latencyOptionStart : latencyOptionStart+latencyOptionEnd]
	// Check that the default "No Data" category isn't present if we have data.
	if len(finalStats.LatencyData) > 0 {
		assert.NotContains(t, latencyOptionJS, `"name":"Latency","type":"bar","data":[{"value":0}]`, "Latency chart shows default empty data when actual data exists")
		assert.NotContains(t, latencyOptionJS, `xaxis":[{"data":["No Data"]}]`, "Latency chart shows 'No Data' category when actual data exists")
	}
	// Check HTML escaping for summary values (find an example value).
	avgLatencyStr := html.EscapeString(finalStats.AvgLatency.String())
	assert.Contains(t, reportContent, avgLatencyStr, "AvgLatency string not found or not escaped in HTML")
	t.Log("Integration test completed successfully.")
}
// TODO: Add another integration test case for streaming requests