197 lines
5.4 KiB
Go

package client
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/valyala/fasthttp"
)
// FastHTTPClient implements the HTTPClient interface with two transports:
// fasthttp serves plain request/response calls, and the standard net/http
// client serves Server-Sent Events streaming. The configurable request
// timeout applies to the non-streaming path only.
type FastHTTPClient struct {
	// Object pools for fasthttp requests and responses.
	reqPool, respPool sync.Pool

	// Standard library client used for streaming requests.
	stdClient *http.Client

	// Upper bound for a single non-streaming request.
	requestTimeout time.Duration
}
// NewFastHTTPClient creates a FastHTTPClient whose non-streaming requests are
// bounded by requestTimeout.
//
// The net/http client used for streaming deliberately has no overall
// http.Client.Timeout. That timeout also covers reading the response body, so
// it would cut off any SSE stream that runs longer than the limit. Instead,
// the transport bounds only the wait for response headers (30s). The full
// exchange is expected to be bounded by a per-request context deadline, as
// Stream does.
func NewFastHTTPClient(requestTimeout time.Duration) *FastHTTPClient {
	const streamHeaderTimeout = 30 * time.Second

	var transport *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		// Clone keeps the default proxy, dialer, TLS and HTTP/2 settings.
		transport = dt.Clone()
	} else {
		// http.DefaultTransport was replaced with a non-*http.Transport;
		// start from a minimal transport instead of panicking.
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.ResponseHeaderTimeout = streamHeaderTimeout

	return &FastHTTPClient{
		reqPool: sync.Pool{
			New: func() interface{} {
				return fasthttp.AcquireRequest()
			},
		},
		respPool: sync.Pool{
			New: func() interface{} {
				return fasthttp.AcquireResponse()
			},
		},
		stdClient:      &http.Client{Transport: transport},
		requestTimeout: requestTimeout,
	}
}
// Do sends a non-streaming request using fasthttp and returns a fully
// materialized Response (status code, copied body, and headers).
//
// The fasthttp request/response objects come from fasthttp's own pool
// (AcquireRequest/AcquireResponse). Each is returned to that pool exactly once
// (ReleaseRequest/ReleaseResponse). They must not also be put into a second
// pool: an object owned by two pools can be handed to two concurrent callers
// at once.
//
// When c.requestTimeout is positive, the call is bounded by
// fasthttp.DoTimeout. Otherwise it uses fasthttp.Do with no client-side
// deadline, because DoTimeout fails immediately with a non-positive timeout.
//
// Response headers are collected into a map, so for a repeated header name
// only the last value is kept.
func (c *FastHTTPClient) Do(req *Request) (*Response, error) {
	freq := fasthttp.AcquireRequest()
	defer fasthttp.ReleaseRequest(freq)
	fresp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseResponse(fresp)

	freq.SetRequestURI(req.URL)
	freq.Header.SetMethod(req.Method)
	if req.Body != nil {
		// SetBodyRaw references req.Body without copying. This is safe because
		// freq is released before Do returns.
		freq.SetBodyRaw(req.Body)
	}
	for k, v := range req.Headers {
		freq.Header.Set(k, v)
	}

	var err error
	if c.requestTimeout > 0 {
		err = fasthttp.DoTimeout(freq, fresp, c.requestTimeout)
	} else {
		err = fasthttp.Do(freq, fresp)
	}
	if err != nil {
		return nil, fmt.Errorf("fasthttp client timeout or connection error for %s %s: %w", req.Method, req.URL, err)
	}

	resp := &Response{
		StatusCode: fresp.StatusCode(),
		// fresp's buffers are recycled once it is released, so copy the body.
		Body:    append([]byte(nil), fresp.Body()...),
		Headers: make(map[string]string),
	}
	fresp.Header.VisitAll(func(key, value []byte) {
		resp.Headers[string(key)] = string(value)
	})
	return resp, nil
}
// Stream sends req with the standard net/http client and invokes callback once
// for every Server-Sent Events "data:" line in the response body.
//
// The entire exchange (connect, headers and body) is bounded by a 60-second
// context deadline. Headers the caller did not set are defaulted:
// Content-Type: application/json (only when req.Body is non-nil),
// Accept: text/event-stream, and Connection: keep-alive.
//
// Each SSEChunk carries:
//   - its own copy of the payload (the text after "data:", with one optional
//     leading space removed);
//   - a time.Now().UnixNano() receive timestamp;
//   - IsDone, set when the payload is exactly "[DONE]".
//
// Stream returns nil after delivering a [DONE] chunk or at end of body. It
// returns an error when:
//   - the request cannot be built or sent;
//   - the status is not 200 (the message includes up to 64 KiB of the body);
//   - callback returns an error (wrapped);
//   - reading the body fails.
func (c *FastHTTPClient) Stream(req *Request, callback func(chunk SSEChunk) error) error {
	const (
		streamTimeout = 60 * time.Second
		// Cap on how much of a non-200 body is read into the error message,
		// so a misbehaving server cannot force an unbounded read.
		maxErrorBody = 64 << 10
		// bufio.Scanner's default 64 KiB token limit is easily exceeded by a
		// single large JSON event; allow lines up to 1 MiB.
		maxLineSize = 1 << 20
	)

	var bodyReader io.Reader
	if req.Body != nil {
		bodyReader = bytes.NewReader(req.Body)
	}

	ctx, cancel := context.WithTimeout(context.Background(), streamTimeout)
	defer cancel()

	stdReq, err := http.NewRequestWithContext(ctx, req.Method, req.URL, bodyReader)
	if err != nil {
		return fmt.Errorf("failed to create standard http request for %s: %w", req.URL, err)
	}
	for k, v := range req.Headers {
		stdReq.Header.Set(k, v)
	}
	if bodyReader != nil && stdReq.Header.Get("Content-Type") == "" {
		stdReq.Header.Set("Content-Type", "application/json")
	}
	if stdReq.Header.Get("Accept") == "" {
		stdReq.Header.Set("Accept", "text/event-stream")
	}
	if stdReq.Header.Get("Connection") == "" {
		stdReq.Header.Set("Connection", "keep-alive")
	}

	resp, err := c.stdClient.Do(stdReq)
	if err != nil {
		return fmt.Errorf("standard http client Do failed for %s: %w", req.URL, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		if readErr != nil {
			return fmt.Errorf("http stream error status %d for %s, failed to read body: %w", resp.StatusCode, req.URL, readErr)
		}
		return fmt.Errorf("http stream error status %d for %s: %s", resp.StatusCode, req.URL, string(bodyBytes))
	}

	ssePrefix := []byte("data:")
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64<<10), maxLineSize)
	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, ssePrefix) {
			continue
		}
		// In the SSE format, one space after the colon is optional and not
		// part of the value: "data: x" and "data:x" both carry "x".
		data := bytes.TrimPrefix(line[len(ssePrefix):], []byte(" "))
		ts := time.Now().UnixNano()

		// scanner.Bytes() is overwritten by the next Scan, so give the
		// callback its own copy.
		dataCopy := make([]byte, len(data))
		copy(dataCopy, data)

		isDone := string(data) == "[DONE]"
		if err := callback(SSEChunk{Data: dataCopy, Timestamp: ts, IsDone: isDone}); err != nil {
			return fmt.Errorf("sse stream user callback error: %w", err)
		}
		if isDone {
			return nil
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("error reading sse stream for %s: %w", req.URL, err)
	}
	return nil
}