// llm-api-benchmark-tool/test/concurrency_test.go
//
// 125 lines, 4.6 KiB, Go

package test
import (
"context"
"fmt" // Added for createTestRequest
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"llm-api-benchmark-tool/pkg/client" // Needs client types
"llm-api-benchmark-tool/pkg/concurrency" // Target package (to be created)
"llm-api-benchmark-tool/pkg/stats" // Added for stats collector
)
// --- Mock HTTP Client ---
// MockHTTPClient is a test double for the benchmark's HTTP client. Every call
// to Do or Stream increments ReqCount; the optional DoFunc and StreamFunc
// hooks replace the canned default behavior of the corresponding method.
//
// ReqCount is an atomic.Int64, so a MockHTTPClient must not be copied after
// first use — share it by pointer.
type MockHTTPClient struct {
	// DoFunc, when non-nil, handles Do calls instead of the canned response.
	DoFunc func(req *client.Request) (*client.Response, error)
	// StreamFunc, when non-nil, handles Stream calls instead of the canned stream.
	StreamFunc func(req *client.Request, callback func(chunk client.SSEChunk) error) error
	// ReqCount tallies Do and Stream invocations; safe for concurrent use.
	ReqCount atomic.Int64
}
// Do records the call, then returns either whatever DoFunc produces (when the
// hook is set) or a canned 200 response with a small JSON body.
func (m *MockHTTPClient) Do(req *client.Request) (*client.Response, error) {
	m.ReqCount.Add(1)

	if hook := m.DoFunc; hook != nil {
		return hook(req)
	}

	canned := &client.Response{
		StatusCode: 200,
		Body:       []byte(`{"mock": "response"}`),
	}
	return canned, nil
}
// Stream records the call, then either delegates to StreamFunc (when set) or
// replays a canned two-chunk SSE stream: one JSON data chunk after a simulated
// time-to-first-token, followed by a terminal "[DONE]" chunk with IsDone set.
//
// The canned replay runs synchronously, so no callback can fire after Stream
// has returned. That keeps the mock from racing with the caller's post-stream
// bookkeeping and from leaving goroutines behind when a test ends. If the
// callback returns an error, the replay stops and that error is returned
// unchanged. A nil callback makes the canned replay a no-op.
func (m *MockHTTPClient) Stream(req *client.Request, callback func(chunk client.SSEChunk) error) error {
	m.ReqCount.Add(1)
	if m.StreamFunc != nil {
		return m.StreamFunc(req, callback)
	}
	if callback == nil {
		return nil
	}

	time.Sleep(5 * time.Millisecond) // simulate time-to-first-token
	first := client.SSEChunk{Data: []byte(`{"mock": "chunk1"}`), Timestamp: time.Now().UnixNano()}
	if err := callback(first); err != nil {
		return err
	}

	time.Sleep(10 * time.Millisecond) // simulate inter-chunk latency
	done := client.SSEChunk{Data: []byte(`[DONE]`), Timestamp: time.Now().UnixNano(), IsDone: true}
	return callback(done)
}
// --- Test Concurrency Manager ---
func TestConcurrencyManager_RunBenchmark(t *testing.T) {
mockClient := &MockHTTPClient{}
params := concurrency.BenchmarkParams{
HTTPClient: mockClient,
TargetURL: "http://mock.test/api",
TotalDuration: 100 * time.Millisecond, // Short duration for testing
Concurrency: 10,
RequestInterval: 5 * time.Millisecond, // Interval between starting new requests
IsStreaming: false, // Test non-streaming first
PromptTokens: 50, // Example prompt size
// Add other necessary params like API Key, Model later
}
// Create the manager (package and struct to be defined)
statsCollector := stats.NewStatsCollector() // Create mock stats collector
manager := concurrency.NewManager(params, statsCollector) // Assuming a constructor
// Run the benchmark
// TODO: Define the result struct if RunBenchmark returns results
err := manager.RunBenchmark(context.Background()) // Use background context for now
require.NoError(t, err, "RunBenchmark should complete without error")
// Assertions
// 1. Check if roughly the correct number of requests were made.
// Expected rate = Concurrency / RequestInterval (adjust if interval is per worker)
// Expected requests ~ TotalDuration / RequestInterval (simplified, needs refinement)
// Let's just check if *some* requests were made for now.
madeRequests := mockClient.ReqCount.Load()
assert.Greater(t, madeRequests, int64(0), "Should have made at least one request")
t.Logf("Total requests made by mock client: %d", madeRequests)
// 2. Test with Streaming
mockClientStreaming := &MockHTTPClient{}
paramsStreaming := params
paramsStreaming.HTTPClient = mockClientStreaming
paramsStreaming.IsStreaming = true
paramsStreaming.TotalDuration = 150 * time.Millisecond // Slightly longer for streaming
statsCollectorStreaming := stats.NewStatsCollector() // Create mock stats collector
managerStreaming := concurrency.NewManager(paramsStreaming, statsCollectorStreaming)
errStreaming := managerStreaming.RunBenchmark(context.Background())
require.NoError(t, errStreaming, "RunBenchmark (streaming) should complete without error")
madeRequestsStreaming := mockClientStreaming.ReqCount.Load()
assert.Greater(t, madeRequestsStreaming, int64(0), "Should have made at least one streaming request")
t.Logf("Total streaming requests made by mock client: %d", madeRequestsStreaming)
// TODO: Add test case for context cancellation/timeout
// TODO: Add test case for handling client errors
// TODO: Add more precise assertions on request count and timing if possible
}
// createTestRequest builds a POST request to url carrying a JSON body with a
// placeholder prompt (which mentions promptTokens) and the given stream flag.
// It may later move to a shared test utility.
//
// NOTE(review): the prompt is only a stand-in for tokenizer-generated text.
// The body is assembled by string formatting, so it is valid JSON only because
// the placeholder contains no characters that need escaping.
func createTestRequest(url string, isStreaming bool, promptTokens int) *client.Request {
	prompt := fmt.Sprintf("mock prompt %d tokens", promptTokens)
	payload := fmt.Sprintf(`{"prompt": "%s", "stream": %v}`, prompt, isStreaming)

	req := &client.Request{
		URL:     url,
		Method:  "POST",
		Headers: map[string]string{"Content-Type": "application/json"},
	}
	req.Body = []byte(payload)
	return req
}