package test

import (
	"context"
	"encoding/json"
	"fmt"
	"html"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"llm-api-benchmark-tool/pkg/client"
	"llm-api-benchmark-tool/pkg/concurrency"
	"llm-api-benchmark-tool/pkg/config"
	"llm-api-benchmark-tool/pkg/report"
	"llm-api-benchmark-tool/pkg/stats"
)

// dms converts a millisecond count into a time.Duration.
func dms(ms int) time.Duration {
	return time.Duration(ms) * time.Millisecond
}

// TestBenchmarkIntegration performs an end-to-end test of the benchmark tool:
// it stands up a mock chat-completions HTTP server, runs a short non-streaming
// benchmark against it, and verifies both the collected stats and the content
// of the generated HTML report.
func TestBenchmarkIntegration(t *testing.T) {
	// --- 1. Set up Mock HTTP Server ---
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
		// Simulate some processing time: 50-100ms latency.
		time.Sleep(dms(50 + int(time.Now().UnixMilli()%50)))

		// Detect streaming requests from the "stream" field of the JSON
		// request body. A missing or malformed body is treated as
		// non-streaming (best effort, so the decode error is ignored).
		var reqBody struct {
			Stream bool `json:"stream"`
		}
		_ = json.NewDecoder(r.Body).Decode(&reqBody)
		isStreaming := reqBody.Stream

		if isStreaming {
			// Simulate SSE streaming.
			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")
			flusher, ok := w.(http.Flusher)
			if !ok {
				http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
				return
			}
			// Send TTFT event quickly: 10-20ms.
			time.Sleep(dms(10 + int(time.Now().UnixMilli()%10)))
			fmt.Fprintf(w, "id: msg1\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Hello\"}}\n\n")
			flusher.Flush()
			// Send subsequent events.
			time.Sleep(dms(30))
			fmt.Fprintf(w, "id: msg2\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \" world!\"}}\n\n")
			flusher.Flush()
			// Send stop event.
			time.Sleep(dms(10))
			fmt.Fprintf(w, "id: msg3\nevent: message_stop\ndata: {\"type\": \"message_stop\"}\n\n")
			flusher.Flush()
		} else {
			// Simulate non-streaming response.
			w.Header().Set("Content-Type", "application/json")
			// A simple, valid-looking JSON response.
			responseBody := `{ "id": "chatcmpl-mock123", "object": "chat.completion", "created": 1677652288, "model": "gpt-mock-3.5", "choices": [{ "index": 0, "message": { "role": "assistant", "content": "This is a mock response." }, "finish_reason": "stop" }], "usage": { "prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21 } }`
			fmt.Fprintln(w, responseBody)
		}
	})
	server := httptest.NewServer(mux)
	defer server.Close()

	// --- 2. Set up Configuration ---
	testDuration := 100 * time.Millisecond // Short duration for testing
	cfg := config.Config{
		APIEndpoint:  server.URL + "/v1/chat/completions", // Use mock server URL
		Model:        "gpt-mock-3.5",
		Client:       "fasthttp", // Use fasthttp for testing
		Timeout:      5 * time.Second,
		Concurrency:  2, // Low concurrency for test
		Duration:     testDuration,
		Requests:     10, // Limit requests for test (Note: Duration takes precedence)
		RateLimit:    0,  // No rate limit for test
		Streaming:    false, // Test non-streaming first; see TODO at end of file
		PromptTokens: 10,
		MaxTokens:    20,
		OutputReport: filepath.Join(t.TempDir(), "integration_report.html"),
		Headers:      make(map[string]string),
		Payload:      make(map[string]interface{}),
	}

	// --- 3. Initialize Components ---
	httpClient := client.NewFastHTTPClient()
	require.NotNil(t, httpClient, "Failed to create HTTP client")
	statsCollector := stats.NewStatsCollector()
	params := concurrency.BenchmarkParams{
		TargetURL:       cfg.APIEndpoint,
		HTTPClient:      httpClient,
		TotalDuration:   cfg.Duration,
		Concurrency:     cfg.Concurrency,
		RequestInterval: 0, // No artificial delay between requests
		IsStreaming:     cfg.Streaming,
		PromptTokens:    cfg.PromptTokens,
	}
	manager := concurrency.NewManager(params, statsCollector)
	require.NotNil(t, manager, "Manager should not be nil")

	// --- 4. Run Benchmark ---
	t.Logf("Running benchmark against mock server: %s", cfg.APIEndpoint)
	startTime := time.Now()
	err := manager.RunBenchmark(context.Background())
	actualDuration := time.Since(startTime)
	require.NoError(t, err, "Benchmark run failed")

	// --- 5. Get Final Stats ---
	finalStats := statsCollector.CalculateStats(actualDuration)
	t.Logf("Benchmark finished. Final stats: %+v", finalStats)

	// One explicit check per condition. (The original repeated the
	// TotalRequests check three times while working around a suspected
	// testify issue; the duplicates added no coverage.)
	if finalStats.TotalRequests == 0 {
		t.Errorf("Test FAILED: Expected TotalRequests (%d) to be greater than 0", finalStats.TotalRequests)
	}
	if finalStats.FailedRequests > 1 {
		t.Errorf("Test FAILED: Expected FailedRequests (%d) to be less than or equal to 1", finalStats.FailedRequests)
	}

	// --- 6. Generate Report ---
	err = report.GenerateHTMLReport(finalStats, cfg.OutputReport)
	require.NoError(t, err, "Failed to generate HTML report")
	t.Logf("Report generated at: %s", cfg.OutputReport)

	// --- 7. Verify Report Content ---
	reportBytes, err := os.ReadFile(cfg.OutputReport)
	require.NoError(t, err, "Failed to read generated report file")
	reportContent := string(reportBytes)

	// Check for basic summary stats in the report (adjust format as needed).
	assert.Contains(t, reportContent, "Total Requests:", "Report missing actual summary data (e.g., 'Total Requests:')")
	assert.Contains(t, reportContent, fmt.Sprintf("Total Requests:%d", finalStats.TotalRequests), "Report missing or incorrect total requests")
	assert.Contains(t, reportContent, fmt.Sprintf("Successful Requests:%d", finalStats.SuccessfulRequests), "Report missing or incorrect successful requests")
	assert.Contains(t, reportContent, "Avg Latency:", "Report missing avg latency")

	// Check for chart markup and locate the latency chart's option object in
	// the embedded JS.
	// NOTE(review): the two search markers below were reconstructed from a
	// corrupted section of this file — confirm them against the actual report
	// template emitted by report.GenerateHTMLReport.
	assert.Contains(t, reportContent, "latencyChart", "Latency chart div not found in report")
	latencyOptionStart := strings.Index(reportContent, "option_latency")
	require.True(t, latencyOptionStart > 0, "Latency chart options not found in JS")
	latencyOptionEnd := strings.Index(reportContent[latencyOptionStart:], "};")
	require.True(t, latencyOptionEnd > 0, "Latency chart options end not found")
	latencyOptionJS := reportContent[latencyOptionStart : latencyOptionStart+latencyOptionEnd]

	// Check that the default "No Data" category isn't present if we have data.
	if len(finalStats.LatencyData) > 0 {
		assert.NotContains(t, latencyOptionJS, `"name":"Latency","type":"bar","data":[{"value":0}]`, "Latency chart shows default empty data when actual data exists")
		assert.NotContains(t, latencyOptionJS, `xaxis":[{"data":["No Data"]}]`, "Latency chart shows 'No Data' category when actual data exists")
	}

	// Check HTML escaping for summary values (find an example value).
	avgLatencyStr := html.EscapeString(finalStats.AvgLatency.String())
	assert.Contains(t, reportContent, avgLatencyStr, "AvgLatency string not found or not escaped in HTML")

	t.Log("Integration test completed successfully.")
}

// TODO: Add another integration test case for streaming requests