171 lines
6.0 KiB
Go
171 lines
6.0 KiB
Go
package test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"llm-api-benchmark-tool/pkg/client"
|
|
"github.com/stretchr/testify/assert"
|
|
"io"
|
|
"fmt"
|
|
"time"
|
|
"sync"
|
|
)
|
|
|
|
// TestHTTPClientDo tests the non-streaming Do method of the HTTPClient interface.
|
|
func TestHTTPClientDo(t *testing.T) {
|
|
// 1. Mock server setup
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Basic validation of the incoming request
|
|
assert.Equal(t, "POST", r.Method)
|
|
assert.Equal(t, "/test-endpoint", r.URL.Path)
|
|
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
|
assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
|
|
|
|
if r.ContentLength <= 0 {
|
|
t.Errorf("Server received zero or negative Content-Length: %d", r.ContentLength)
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_, _ = w.Write([]byte(`{"error":"missing or invalid content length"}`))
|
|
return
|
|
}
|
|
|
|
bodyBytes := make([]byte, r.ContentLength)
|
|
_, err := io.ReadFull(r.Body, bodyBytes)
|
|
assert.NoError(t, err, "Server failed to read expected number of body bytes")
|
|
assert.Equal(t, `{"message":"ping"}`, string(bodyBytes))
|
|
|
|
// Send mock response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("X-Test-Header", "test-value")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"reply":"pong"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
// 2. Instantiate client implementation
|
|
httpClient := client.NewFastHTTPClient() // Using fasthttp client
|
|
|
|
// 3. Create mock Request
|
|
req := &client.Request{
|
|
Method: "POST",
|
|
URL: server.URL + "/test-endpoint",
|
|
Body: []byte(`{"message":"ping"}`),
|
|
Headers: map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Authorization": "Bearer test-key",
|
|
},
|
|
}
|
|
|
|
// 4. Call Do method
|
|
resp, err := httpClient.Do(req)
|
|
|
|
// 5. Assert Response correctness
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, resp)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, `{"reply":"pong"}`, string(resp.Body))
|
|
assert.Equal(t, "application/json", resp.Headers["Content-Type"])
|
|
assert.Equal(t, "test-value", resp.Headers["X-Test-Header"])
|
|
|
|
// TODO: Add test cases for error handling (e.g., server error, network error)
|
|
}
|
|
|
|
// TestHTTPClientStream tests the streaming Stream method of the HTTPClient interface.
|
|
// It should verify SSE parsing, callback invocation, TTFT measurement, and handling of [DONE] message.
|
|
func TestHTTPClientStream(t *testing.T) {
|
|
// 1. Mock SSE server setup
|
|
mockSSEData := []string{
|
|
`data: {"id":"1","choices":[{"delta":{"role":"assistant"}}]}`, // Initial metadata
|
|
`data: {"id":"1","choices":[{"delta":{"content":"Hello"}}]}`, // First content
|
|
`data: {"id":"1","choices":[{"delta":{"reasoning_content":" Thinking..."}}]}`, // Reasoning content
|
|
`data: {"id":"1","choices":[{"delta":{"content":" world"}}]}`, // More content
|
|
`data: [DONE]`,
|
|
}
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "POST", r.Method)
|
|
assert.Equal(t, "/test-stream", r.URL.Path)
|
|
assert.Equal(t, "Bearer test-key-stream", r.Header.Get("Authorization"))
|
|
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
t.Fatal("Server does not support flushing")
|
|
}
|
|
|
|
for _, line := range mockSSEData {
|
|
fmt.Fprintf(w, "%s\n\n", line)
|
|
flusher.Flush() // Ensure data is sent immediately
|
|
time.Sleep(10 * time.Millisecond) // Small delay between chunks
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
// 2. Instantiate client implementation
|
|
httpClient := client.NewFastHTTPClient()
|
|
|
|
// 3. Create mock Request
|
|
req := &client.Request{
|
|
Method: "POST",
|
|
URL: server.URL + "/test-stream",
|
|
Body: []byte(`{"prompt":"test prompt","stream":true}`),
|
|
Headers: map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Authorization": "Bearer test-key-stream",
|
|
},
|
|
}
|
|
|
|
// 4. Define callback function
|
|
var receivedChunks []client.SSEChunk
|
|
var firstChunkTime int64
|
|
var mu sync.Mutex // Protect access to shared variables
|
|
|
|
callback := func(chunk client.SSEChunk) error {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
if firstChunkTime == 0 && len(chunk.Data) > 0 && !chunk.IsDone {
|
|
firstChunkTime = chunk.Timestamp // Record time of first data chunk
|
|
}
|
|
receivedChunks = append(receivedChunks, chunk)
|
|
// fmt.Printf("Callback received: IsDone=%v, Data=%s\n", chunk.IsDone, string(chunk.Data))
|
|
return nil // No error from callback in this test
|
|
}
|
|
|
|
// 5. Call Stream method (EXPECTED TO FAIL INITIALLY)
|
|
startTime := time.Now().UnixNano()
|
|
err := httpClient.Stream(req, callback)
|
|
|
|
// 6. Assert results (WILL FAIL INITIALLY)
|
|
// assert.NoError(t, err) // Initially, this will fail
|
|
if err != nil {
|
|
assert.ErrorContains(t, err, "FastHTTPClient.Stream not implemented yet")
|
|
// Since it's not implemented, we don't assert chunk content or TTFT yet
|
|
} else {
|
|
// --- Assertions for when Stream IS implemented ---
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
assert.NotEmpty(t, receivedChunks, "Callback should have been called")
|
|
assert.Len(t, receivedChunks, len(mockSSEData), "Should receive the same number of chunks as sent")
|
|
|
|
// Verify last chunk is [DONE]
|
|
lastChunk := receivedChunks[len(receivedChunks)-1]
|
|
assert.True(t, lastChunk.IsDone, "Last chunk should be marked as IsDone")
|
|
assert.Equal(t, "[DONE]", string(lastChunk.Data), "Last chunk data should be [DONE]")
|
|
|
|
// Verify TTFT calculation (placeholder)
|
|
assert.Greater(t, firstChunkTime, startTime, "First chunk time should be after start time")
|
|
ttftMs := float64(firstChunkTime-startTime) / 1e6
|
|
t.Logf("Measured TTFT: %.2f ms", ttftMs)
|
|
assert.Less(t, ttftMs, 100.0, "TTFT should be reasonably small in mock test") // Adjust threshold as needed
|
|
|
|
// TODO: Add more detailed assertions on specific chunk content if needed
|
|
// TODO: Add assertions for handling 'content' vs 'reasoning_content'
|
|
// TODO: Add test cases for error handling (server error, invalid SSE, callback error)
|
|
}
|
|
|
|
}
|
|
|
|
// TODO: Add helper functions for mock server/client setup if needed
|