211 lines
8.2 KiB
Go
211 lines
8.2 KiB
Go
package test
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"html"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"llm-api-benchmark-tool/pkg/client"
|
|
"llm-api-benchmark-tool/pkg/concurrency"
|
|
"llm-api-benchmark-tool/pkg/config"
|
|
"llm-api-benchmark-tool/pkg/report"
|
|
"llm-api-benchmark-tool/pkg/stats"
|
|
)
|
|
|
|
// dms converts a whole number of milliseconds into a time.Duration.
// Negative inputs yield negative durations; zero yields zero.
func dms(ms int) time.Duration {
	d := time.Duration(ms)
	return time.Millisecond * d
}
|
|
|
|
// TestBenchmarkIntegration performs an end-to-end test of the benchmark tool.
|
|
func TestBenchmarkIntegration(t *testing.T) {
|
|
// --- 1. Set up Mock HTTP Server ---
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
|
// Simulate some processing time
|
|
time.Sleep(dms(50 + int(time.Now().UnixMilli()%50))) // 50-100ms latency
|
|
|
|
// Check if it's a streaming request based on header or request body field
|
|
// For simplicity, let's assume non-streaming for now
|
|
isStreaming := false // TODO: Detect streaming requests properly
|
|
|
|
if isStreaming {
|
|
// Simulate SSE streaming
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Send TTFT event quickly
|
|
time.Sleep(dms(10 + int(time.Now().UnixMilli()%10))) // 10-20ms TTFT
|
|
fmt.Fprintf(w, "id: msg1\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \"Hello\"}}\n\n")
|
|
flusher.Flush()
|
|
|
|
// Send subsequent events
|
|
time.Sleep(dms(30))
|
|
fmt.Fprintf(w, "id: msg2\nevent: content_block_delta\ndata: {\"type\": \"content_block_delta\", \"index\": 0, \"delta\": {\"type\": \"text_delta\", \"text\": \" world!\"}}\n\n")
|
|
flusher.Flush()
|
|
|
|
// Send stop event
|
|
time.Sleep(dms(10))
|
|
fmt.Fprintf(w, "id: msg3\nevent: message_stop\ndata: {\"type\": \"message_stop\"}\n\n")
|
|
flusher.Flush()
|
|
} else {
|
|
// Simulate non-streaming response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
// A simple, valid-looking JSON response
|
|
responseBody := `{
|
|
"id": "chatcmpl-mock123",
|
|
"object": "chat.completion",
|
|
"created": 1677652288,
|
|
"model": "gpt-mock-3.5",
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": "This is a mock response."
|
|
},
|
|
"finish_reason": "stop"
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": 9,
|
|
"completion_tokens": 12,
|
|
"total_tokens": 21
|
|
}
|
|
}`
|
|
fmt.Fprintln(w, responseBody)
|
|
}
|
|
})
|
|
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
// --- 2. Set up Configuration ---
|
|
testDuration := 100 * time.Millisecond // Short duration for testing
|
|
cfg := config.Config{
|
|
APIEndpoint: server.URL + "/v1/chat/completions", // Use mock server URL
|
|
Model: "gpt-mock-3.5",
|
|
Client: "fasthttp", // Use fasthttp for testing
|
|
Timeout: 5 * time.Second,
|
|
Concurrency: 2, // Low concurrency for test
|
|
Duration: testDuration,
|
|
Requests: 10, // Limit requests for test (Note: Duration takes precedence)
|
|
RateLimit: 0, // No rate limit for test
|
|
Streaming: false, // Test non-streaming first
|
|
PromptTokens: 10,
|
|
MaxTokens: 20,
|
|
OutputReport: filepath.Join(t.TempDir(), "integration_report.html"),
|
|
Headers: make(map[string]string),
|
|
Payload: make(map[string]interface{}),
|
|
}
|
|
|
|
// --- 3. Initialize Components ---
|
|
httpClient := client.NewFastHTTPClient()
|
|
require.NotNil(t, httpClient, "Failed to create HTTP client")
|
|
|
|
statsCollector := stats.NewStatsCollector()
|
|
|
|
params := concurrency.BenchmarkParams{
|
|
TargetURL: cfg.APIEndpoint, // Use correct field name
|
|
HTTPClient: httpClient,
|
|
TotalDuration: cfg.Duration,
|
|
Concurrency: cfg.Concurrency,
|
|
RequestInterval: 0, // No artificial delay between requests
|
|
IsStreaming: cfg.Streaming,
|
|
PromptTokens: cfg.PromptTokens,
|
|
}
|
|
|
|
manager := concurrency.NewManager(params, statsCollector)
|
|
require.NotNil(t, manager, "Manager should not be nil")
|
|
|
|
// --- 4. Run Benchmark ---
|
|
t.Logf("Running benchmark against mock server: %s", cfg.APIEndpoint)
|
|
startTime := time.Now()
|
|
err := manager.RunBenchmark(context.Background()) // Pass context
|
|
actualDuration := time.Since(startTime)
|
|
require.NoError(t, err, "Benchmark run failed")
|
|
|
|
// --- 5. Get Final Stats ---
|
|
finalStats := statsCollector.CalculateStats(actualDuration)
|
|
t.Logf("Benchmark finished. Final stats: %+v", finalStats)
|
|
|
|
// Explicit check before assertion
|
|
if finalStats.TotalRequests > 0 {
|
|
t.Logf("Check PASSED: finalStats.TotalRequests (%d) > 0", finalStats.TotalRequests)
|
|
} else {
|
|
t.Errorf("Check FAILED: finalStats.TotalRequests (%d) is not > 0", finalStats.TotalRequests)
|
|
}
|
|
|
|
// Replace assert.Greater with standard check due to potential testify issue
|
|
if !(finalStats.TotalRequests > 0) { // Explicitly check the condition
|
|
t.Errorf("Test FAILED: Expected TotalRequests (%d) to be greater than 0", finalStats.TotalRequests)
|
|
}
|
|
|
|
// Replace assert.LessOrEqual with standard check due to potential testify issue
|
|
if !(finalStats.FailedRequests <= 1) { // Explicitly check the condition
|
|
t.Errorf("Test FAILED: Expected FailedRequests (%d) to be less than or equal to 1", finalStats.FailedRequests)
|
|
}
|
|
|
|
// --- 6. Generate Report ---
|
|
err = report.GenerateHTMLReport(finalStats, cfg.OutputReport) // Pass report path string
|
|
require.NoError(t, err, "Failed to generate HTML report")
|
|
t.Logf("Report generated at: %s", cfg.OutputReport)
|
|
|
|
// --- 7. Verify Report Content ---
|
|
reportBytes, err := os.ReadFile(cfg.OutputReport)
|
|
require.NoError(t, err, "Failed to read generated report file")
|
|
reportContent := string(reportBytes)
|
|
|
|
// Check for basic summary stats in the report (adjust format as needed)
|
|
assert.Contains(t, reportContent, "Total Requests:", "Report missing actual summary data (e.g., 'Total Requests:')")
|
|
assert.Contains(t, reportContent, fmt.Sprintf("Total Requests:</td><td>%d", finalStats.TotalRequests), "Report missing or incorrect total requests")
|
|
assert.Contains(t, reportContent, fmt.Sprintf("Successful Requests:</td><td>%d", finalStats.SuccessfulRequests), "Report missing or incorrect successful requests")
|
|
assert.Contains(t, reportContent, "Avg Latency:", "Report missing avg latency")
|
|
|
|
// Check for chart divs
|
|
assert.Contains(t, reportContent, "<div id=\"latencyHistogram\"", "Report missing latency histogram div")
|
|
assert.Contains(t, reportContent, "<div id=\"ttftHistogram\"", "Report missing TTFT histogram div")
|
|
|
|
// Check for chart initialization JS (presence, not exact content)
|
|
assert.Contains(t, reportContent, "echarts.init(document.getElementById('latencyHistogram')", "Report missing latency chart JS init")
|
|
assert.Contains(t, reportContent, "echarts.init(document.getElementById('ttftHistogram')", "Report missing TTFT chart JS init")
|
|
|
|
// Check that histogram data is not just the default empty/"No Data" case
|
|
// We need to find the data part of the JS options
|
|
latencyOptionStart := strings.Index(reportContent, "let option_latencyHistogram = {")
|
|
require.True(t, latencyOptionStart > 0, "Latency chart options not found in JS")
|
|
|
|
latencyOptionEnd := strings.Index(reportContent[latencyOptionStart:], "};")
|
|
require.True(t, latencyOptionEnd > 0, "Latency chart options end not found")
|
|
|
|
latencyOptionJS := reportContent[latencyOptionStart : latencyOptionStart+latencyOptionEnd]
|
|
|
|
// Check that the default "No Data" category isn't present if we have data
|
|
if len(finalStats.LatencyData) > 0 {
|
|
assert.NotContains(t, latencyOptionJS, `"name":"Latency","type":"bar","data":[{"value":0}]`, "Latency chart shows default empty data when actual data exists")
|
|
assert.NotContains(t, latencyOptionJS, `xaxis":[{"data":["No Data"]}]`, "Latency chart shows 'No Data' category when actual data exists") // Check category
|
|
}
|
|
|
|
// Check HTML escaping for summary values (find an example value)
|
|
avgLatencyStr := html.EscapeString(finalStats.AvgLatency.String())
|
|
assert.Contains(t, reportContent, avgLatencyStr, "AvgLatency string not found or not escaped in HTML")
|
|
|
|
t.Log("Integration test completed successfully.")
|
|
}
|
|
|
|
// TODO: Add another integration test case for streaming requests
|