package test

import (
	"context"
	"encoding/json"
	"fmt"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"llm-api-benchmark-tool/pkg/client"      // request/response and SSE chunk types used by the mock
	"llm-api-benchmark-tool/pkg/concurrency" // package under test
	"llm-api-benchmark-tool/pkg/stats"       // stats collector handed to the manager
)

// --- Mock HTTP Client ---

// MockHTTPClient is a test double for the benchmark HTTP client.
// When DoFunc or StreamFunc is non-nil it replaces the default canned behavior
// of Do or Stream respectively. ReqCount counts every Do and Stream call; it is
// an atomic counter, so it may be updated by concurrent workers and read by the test.
type MockHTTPClient struct {
	DoFunc     func(req *client.Request) (*client.Response, error)
	StreamFunc func(req *client.Request, callback func(chunk client.SSEChunk) error) error
	ReqCount   atomic.Int64 // number of Do/Stream calls made so far
}

// Do records the call and returns DoFunc's result, or, when DoFunc is nil,
// a canned 200 response with a small JSON body.
func (m *MockHTTPClient) Do(req *client.Request) (*client.Response, error) {
	m.ReqCount.Add(1)
	if m.DoFunc != nil {
		return m.DoFunc(req)
	}
	return &client.Response{
		StatusCode: 200,
		Body:       []byte(`{"mock": "response"}`),
	}, nil
}

// Stream records the call and delegates to StreamFunc when it is set.
// Otherwise it simulates a short SSE stream: after ~5ms (simulated time to
// first token) it delivers one data chunk, then after ~10ms a final [DONE]
// chunk with IsDone set. A nil callback makes the default stream a no-op.
//
// Chunks are delivered synchronously, so Stream returns only once the simulated
// stream has finished. This means no callback can fire after Stream (or the
// benchmark, or the test) has returned. The first error returned by callback
// aborts the stream and is returned to the caller.
func (m *MockHTTPClient) Stream(req *client.Request, callback func(chunk client.SSEChunk) error) error {
	m.ReqCount.Add(1)
	if m.StreamFunc != nil {
		return m.StreamFunc(req, callback)
	}
	if callback == nil {
		return nil
	}

	time.Sleep(5 * time.Millisecond) // simulated time to first token
	if err := callback(client.SSEChunk{Data: []byte(`{"mock": "chunk1"}`), Timestamp: time.Now().UnixNano()}); err != nil {
		return err
	}
	time.Sleep(10 * time.Millisecond)
	return callback(client.SSEChunk{Data: []byte(`[DONE]`), Timestamp: time.Now().UnixNano(), IsDone: true})
}

// --- Test Concurrency Manager ---

// TestConcurrencyManager_RunBenchmark runs short benchmarks against
// MockHTTPClient, once in non-streaming mode and once in streaming mode.
// For each run it checks that RunBenchmark returns no error and that the mock
// saw at least one request. The two modes are independent subtests, so a
// failure in one does not skip the other.
func TestConcurrencyManager_RunBenchmark(t *testing.T) {
	// baseParams returns the shared short-running configuration wired to c.
	baseParams := func(c *MockHTTPClient) concurrency.BenchmarkParams {
		return concurrency.BenchmarkParams{
			HTTPClient:      c,
			TargetURL:       "http://mock.test/api",
			TotalDuration:   100 * time.Millisecond, // short duration for testing
			Concurrency:     10,
			RequestInterval: 5 * time.Millisecond, // interval between starting new requests
			IsStreaming:     false,
			PromptTokens:    50, // example prompt size
			// TODO: add other necessary params (API key, model) later.
		}
	}

	t.Run("non-streaming", func(t *testing.T) {
		mockClient := &MockHTTPClient{}
		manager := concurrency.NewManager(baseParams(mockClient), stats.NewStatsCollector())

		err := manager.RunBenchmark(context.Background())
		require.NoError(t, err, "RunBenchmark should complete without error")

		// The exact request count depends on scheduling and on how
		// RequestInterval is applied (globally or per worker), so for now
		// only check that some requests were made.
		madeRequests := mockClient.ReqCount.Load()
		assert.Greater(t, madeRequests, int64(0), "Should have made at least one request")
		t.Logf("Total requests made by mock client: %d", madeRequests)
	})

	t.Run("streaming", func(t *testing.T) {
		mockClient := &MockHTTPClient{}
		params := baseParams(mockClient)
		params.IsStreaming = true
		params.TotalDuration = 150 * time.Millisecond // slightly longer: each default mock stream takes ~15ms

		manager := concurrency.NewManager(params, stats.NewStatsCollector())

		err := manager.RunBenchmark(context.Background())
		require.NoError(t, err, "RunBenchmark (streaming) should complete without error")

		madeRequests := mockClient.ReqCount.Load()
		assert.Greater(t, madeRequests, int64(0), "Should have made at least one streaming request")
		t.Logf("Total streaming requests made by mock client: %d", madeRequests)
	})

	// TODO: Add test case for context cancellation/timeout
	// TODO: Add test case for handling client errors
	// TODO: Add more precise assertions on request count and timing if possible
}

// createTestRequest builds a JSON POST request to url whose body carries a
// placeholder prompt (labelled with promptTokens) and the stream flag.
// It might move to a shared test utility later.
func createTestRequest(url string, isStreaming bool, promptTokens int) *client.Request {
	// TODO: generate a real prompt of promptTokens tokens with the tokenizer.
	payload := struct {
		Prompt string `json:"prompt"`
		Stream bool   `json:"stream"`
	}{
		Prompt: fmt.Sprintf("mock prompt %d tokens", promptTokens),
		Stream: isStreaming,
	}
	// Marshal cannot fail for a struct holding only a string and a bool.
	// Unlike formatting into a JSON template, it also escapes the prompt.
	body, _ := json.Marshal(payload)
	return &client.Request{
		URL:     url,
		Method:  "POST",
		Headers: map[string]string{"Content-Type": "application/json"},
		Body:    body,
	}
}