125 lines
4.6 KiB
Go
125 lines
4.6 KiB
Go
package test
|
|
|
|
import (
|
|
"context"
|
|
"fmt" // Added for createTestRequest
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"llm-api-benchmark-tool/pkg/client" // Needs client types
|
|
"llm-api-benchmark-tool/pkg/concurrency" // Target package (to be created)
|
|
"llm-api-benchmark-tool/pkg/stats" // Added for stats collector
|
|
)
|
|
|
|
// --- Mock HTTP Client ---
|
|
|
|
type MockHTTPClient struct {
|
|
DoFunc func(req *client.Request) (*client.Response, error)
|
|
StreamFunc func(req *client.Request, callback func(chunk client.SSEChunk) error) error
|
|
ReqCount atomic.Int64 // Track calls
|
|
}
|
|
|
|
func (m *MockHTTPClient) Do(req *client.Request) (*client.Response, error) {
|
|
m.ReqCount.Add(1)
|
|
if m.DoFunc != nil {
|
|
return m.DoFunc(req)
|
|
}
|
|
// Default mock response for Do
|
|
return &client.Response{
|
|
StatusCode: 200,
|
|
Body: []byte(`{"mock": "response"}`),
|
|
}, nil
|
|
}
|
|
|
|
func (m *MockHTTPClient) Stream(req *client.Request, callback func(chunk client.SSEChunk) error) error {
|
|
m.ReqCount.Add(1)
|
|
if m.StreamFunc != nil {
|
|
return m.StreamFunc(req, callback)
|
|
}
|
|
// Default mock stream behavior
|
|
go func() { // Simulate receiving chunks
|
|
time.Sleep(5 * time.Millisecond) // Simulate TTFT
|
|
// Ensure callback is not nil before calling
|
|
if callback != nil {
|
|
callback(client.SSEChunk{Data: []byte(`{"mock": "chunk1"}`), Timestamp: time.Now().UnixNano()})
|
|
time.Sleep(10 * time.Millisecond)
|
|
callback(client.SSEChunk{Data: []byte(`[DONE]`), Timestamp: time.Now().UnixNano(), IsDone: true})
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
// --- Test Concurrency Manager ---
|
|
|
|
func TestConcurrencyManager_RunBenchmark(t *testing.T) {
|
|
mockClient := &MockHTTPClient{}
|
|
|
|
params := concurrency.BenchmarkParams{
|
|
HTTPClient: mockClient,
|
|
TargetURL: "http://mock.test/api",
|
|
TotalDuration: 100 * time.Millisecond, // Short duration for testing
|
|
Concurrency: 10,
|
|
RequestInterval: 5 * time.Millisecond, // Interval between starting new requests
|
|
IsStreaming: false, // Test non-streaming first
|
|
PromptTokens: 50, // Example prompt size
|
|
// Add other necessary params like API Key, Model later
|
|
}
|
|
|
|
// Create the manager (package and struct to be defined)
|
|
statsCollector := stats.NewStatsCollector() // Create mock stats collector
|
|
manager := concurrency.NewManager(params, statsCollector) // Assuming a constructor
|
|
|
|
// Run the benchmark
|
|
// TODO: Define the result struct if RunBenchmark returns results
|
|
err := manager.RunBenchmark(context.Background()) // Use background context for now
|
|
|
|
require.NoError(t, err, "RunBenchmark should complete without error")
|
|
|
|
// Assertions
|
|
// 1. Check if roughly the correct number of requests were made.
|
|
// Expected rate = Concurrency / RequestInterval (adjust if interval is per worker)
|
|
// Expected requests ~ TotalDuration / RequestInterval (simplified, needs refinement)
|
|
// Let's just check if *some* requests were made for now.
|
|
madeRequests := mockClient.ReqCount.Load()
|
|
assert.Greater(t, madeRequests, int64(0), "Should have made at least one request")
|
|
t.Logf("Total requests made by mock client: %d", madeRequests)
|
|
|
|
// 2. Test with Streaming
|
|
mockClientStreaming := &MockHTTPClient{}
|
|
paramsStreaming := params
|
|
paramsStreaming.HTTPClient = mockClientStreaming
|
|
paramsStreaming.IsStreaming = true
|
|
paramsStreaming.TotalDuration = 150 * time.Millisecond // Slightly longer for streaming
|
|
|
|
statsCollectorStreaming := stats.NewStatsCollector() // Create mock stats collector
|
|
managerStreaming := concurrency.NewManager(paramsStreaming, statsCollectorStreaming)
|
|
errStreaming := managerStreaming.RunBenchmark(context.Background())
|
|
|
|
require.NoError(t, errStreaming, "RunBenchmark (streaming) should complete without error")
|
|
madeRequestsStreaming := mockClientStreaming.ReqCount.Load()
|
|
assert.Greater(t, madeRequestsStreaming, int64(0), "Should have made at least one streaming request")
|
|
t.Logf("Total streaming requests made by mock client: %d", madeRequestsStreaming)
|
|
|
|
|
|
// TODO: Add test case for context cancellation/timeout
|
|
// TODO: Add test case for handling client errors
|
|
// TODO: Add more precise assertions on request count and timing if possible
|
|
}
|
|
|
|
// Helper function to create a basic request (might move to a shared test util)
|
|
func createTestRequest(url string, isStreaming bool, promptTokens int) *client.Request {
|
|
// In a real scenario, we'd use the tokenizer here
|
|
mockPrompt := fmt.Sprintf("mock prompt %d tokens", promptTokens)
|
|
body := []byte(fmt.Sprintf(`{"prompt": "%s", "stream": %v}`, mockPrompt, isStreaming))
|
|
return &client.Request{
|
|
URL: url,
|
|
Method: "POST",
|
|
Headers: map[string]string{"Content-Type": "application/json"},
|
|
Body: body,
|
|
}
|
|
}
|