171 lines
6.0 KiB
Go

package test
import (
"net/http"
"net/http/httptest"
"testing"
"llm-api-benchmark-tool/pkg/client"
"github.com/stretchr/testify/assert"
"io"
"fmt"
"time"
"sync"
)
// TestHTTPClientDo tests the non-streaming Do method of the HTTPClient
// implementation returned by client.NewFastHTTPClient.
//
// A local httptest server checks the outgoing request: method, path,
// Content-Type and Authorization headers, a positive Content-Length, and the
// exact body. It replies 200 with a JSON body and a custom X-Test-Header. The
// test then verifies that Do surfaces the status code, body and headers.
func TestHTTPClientDo(t *testing.T) {
	const (
		requestBody  = `{"message":"ping"}`
		responseBody = `{"reply":"pong"}`
	)

	// 1. Mock server setup. The handler runs on a server goroutine, so it only
	// records failures (assert / t.Errorf) and never calls t.Fatal/FailNow.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Basic validation of the incoming request
		assert.Equal(t, "POST", r.Method)
		assert.Equal(t, "/test-endpoint", r.URL.Path)
		assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))

		// The client is expected to send an explicit, non-zero Content-Length;
		// an unknown (-1) or zero length is reported as a test failure.
		if r.ContentLength <= 0 {
			t.Errorf("Server received zero or negative Content-Length: %d", r.ContentLength)
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte(`{"error":"missing or invalid content length"}`))
			return
		}
		bodyBytes := make([]byte, r.ContentLength)
		_, err := io.ReadFull(r.Body, bodyBytes)
		assert.NoError(t, err, "Server failed to read expected number of body bytes")
		assert.Equal(t, requestBody, string(bodyBytes))

		// Send mock response
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Test-Header", "test-value")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(responseBody))
	}))
	defer server.Close()

	// 2. Instantiate client implementation
	httpClient := client.NewFastHTTPClient() // Using fasthttp client

	// 3. Create mock Request
	req := &client.Request{
		Method: "POST",
		URL:    server.URL + "/test-endpoint",
		Body:   []byte(requestBody),
		Headers: map[string]string{
			"Content-Type":  "application/json",
			"Authorization": "Bearer test-key",
		},
	}

	// 4. Call Do method. Stop on failure: every assertion below dereferences
	// resp and would otherwise panic on a nil response, masking the real error.
	resp, err := httpClient.Do(req)
	if !assert.NoError(t, err) || !assert.NotNil(t, resp) {
		return
	}

	// 5. Assert Response correctness
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	assert.Equal(t, responseBody, string(resp.Body))
	assert.Equal(t, "application/json", resp.Headers["Content-Type"])
	assert.Equal(t, "test-value", resp.Headers["X-Test-Header"])
	// TODO: Add test cases for error handling (e.g., server error, network error)
}
// TestHTTPClientStream tests the streaming Stream method of the HTTPClient
// implementation returned by client.NewFastHTTPClient.
//
// A local httptest server emits a fixed sequence of SSE "data:" events, each
// flushed immediately:
//   - role metadata
//   - content
//   - reasoning_content
//   - more content
//   - [DONE]
//
// Once Stream is implemented, the test checks three things:
//   - the callback receives one chunk per event;
//   - the final chunk is the [DONE] marker;
//   - the first data chunk's timestamp gives a small TTFT.
//
// While Stream is still a stub, the only accepted error is the
// "not implemented yet" placeholder.
func TestHTTPClientStream(t *testing.T) {
	// 1. Mock SSE server setup
	mockSSEData := []string{
		`data: {"id":"1","choices":[{"delta":{"role":"assistant"}}]}`,                 // Initial metadata
		`data: {"id":"1","choices":[{"delta":{"content":"Hello"}}]}`,                  // First content
		`data: {"id":"1","choices":[{"delta":{"reasoning_content":" Thinking..."}}]}`, // Reasoning content
		`data: {"id":"1","choices":[{"delta":{"content":" world"}}]}`,                 // More content
		`data: [DONE]`,
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "POST", r.Method)
		assert.Equal(t, "/test-stream", r.URL.Path)
		assert.Equal(t, "Bearer test-key-stream", r.Header.Get("Authorization"))

		// This handler runs on a server goroutine, where t.Fatal/FailNow must not
		// be called. Record the failure and end the response instead.
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Error("Server does not support flushing")
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		for _, line := range mockSSEData {
			fmt.Fprintf(w, "%s\n\n", line)    // a blank line terminates each SSE event
			flusher.Flush()                   // Ensure data is sent immediately
			time.Sleep(10 * time.Millisecond) // Small delay between chunks
		}
	}))
	defer server.Close()

	// 2. Instantiate client implementation
	httpClient := client.NewFastHTTPClient()

	// 3. Create mock Request
	req := &client.Request{
		Method: "POST",
		URL:    server.URL + "/test-stream",
		Body:   []byte(`{"prompt":"test prompt","stream":true}`),
		Headers: map[string]string{
			"Content-Type":  "application/json",
			"Authorization": "Bearer test-key-stream",
		},
	}

	// 4. Define callback function. It records every chunk, plus the timestamp
	// of the first non-[DONE] chunk that carries data.
	var (
		mu             sync.Mutex // Protect access to shared variables
		receivedChunks []client.SSEChunk
		firstChunkTime int64
	)
	callback := func(chunk client.SSEChunk) error {
		mu.Lock()
		defer mu.Unlock()
		if firstChunkTime == 0 && len(chunk.Data) > 0 && !chunk.IsDone {
			firstChunkTime = chunk.Timestamp // Record time of first data chunk
		}
		receivedChunks = append(receivedChunks, chunk)
		return nil // No error from callback in this test
	}

	// 5. Call Stream method. startTime is in Unix nanoseconds; the TTFT math
	// below assumes SSEChunk.Timestamp uses the same unit — TODO confirm.
	startTime := time.Now().UnixNano()
	err := httpClient.Stream(req, callback)

	// 6. While Stream is a stub, the placeholder error is the only accepted
	// failure, and no chunk/TTFT assertions apply.
	if err != nil {
		assert.ErrorContains(t, err, "FastHTTPClient.Stream not implemented yet")
		return
	}

	// --- Assertions for when Stream IS implemented ---
	mu.Lock()
	defer mu.Unlock()
	if !assert.NotEmpty(t, receivedChunks, "Callback should have been called") {
		return // indexing the last chunk below would panic on an empty slice
	}
	assert.Len(t, receivedChunks, len(mockSSEData), "Should receive the same number of chunks as sent")

	// Verify last chunk is [DONE]
	lastChunk := receivedChunks[len(receivedChunks)-1]
	assert.True(t, lastChunk.IsDone, "Last chunk should be marked as IsDone")
	assert.Equal(t, "[DONE]", string(lastChunk.Data), "Last chunk data should be [DONE]")

	// Verify TTFT calculation (nanoseconds -> milliseconds)
	assert.Greater(t, firstChunkTime, startTime, "First chunk time should be after start time")
	ttftMs := float64(firstChunkTime-startTime) / 1e6
	t.Logf("Measured TTFT: %.2f ms", ttftMs)
	assert.Less(t, ttftMs, 100.0, "TTFT should be reasonably small in mock test") // Adjust threshold as needed
	// TODO: Add more detailed assertions on specific chunk content if needed
	// TODO: Add assertions for handling 'content' vs 'reasoning_content'
	// TODO: Add test cases for error handling (server error, invalid SSE, callback error)
}
// TODO: Add helper functions for mock server/client setup if needed