197 lines
5.4 KiB
Go
197 lines
5.4 KiB
Go
package client
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// FastHTTPClient implements the HTTPClient interface with two transports:
// fasthttp serves ordinary (non-streaming) requests, subject to a configurable
// request timeout, while the standard net/http client serves streaming ones.
//
// The struct holds sync.Pool values, so it must not be copied after first use;
// share it by pointer (NewFastHTTPClient returns one).
type FastHTTPClient struct {
	// Pools of reusable fasthttp request and response objects.
	reqPool  sync.Pool
	respPool sync.Pool

	// stdClient is the net/http client used for streaming requests.
	stdClient *http.Client

	// requestTimeout is the timeout configured for non-streaming requests.
	requestTimeout time.Duration
}
|
|
|
|
// NewFastHTTPClient creates a new FastHTTPClient.
|
|
func NewFastHTTPClient(requestTimeout time.Duration) *FastHTTPClient {
|
|
return &FastHTTPClient{
|
|
reqPool: sync.Pool{
|
|
New: func() interface{} {
|
|
return fasthttp.AcquireRequest()
|
|
},
|
|
},
|
|
respPool: sync.Pool{
|
|
New: func() interface{} {
|
|
return fasthttp.AcquireResponse()
|
|
},
|
|
},
|
|
stdClient: &http.Client{
|
|
// Configure standard client (e.g., transport, timeout)
|
|
Timeout: 30 * time.Second, // Example timeout
|
|
},
|
|
requestTimeout: requestTimeout,
|
|
}
|
|
}
|
|
|
|
// Do sends a non-streaming request using fasthttp.
|
|
func (c *FastHTTPClient) Do(req *Request) (*Response, error) {
|
|
// 1. Acquire fasthttp request from pool
|
|
freq := c.reqPool.Get().(*fasthttp.Request)
|
|
defer c.reqPool.Put(freq) // Ensure request is put back even on error
|
|
freq.Reset()
|
|
|
|
// 2. Populate fasthttp request from input *Request
|
|
freq.SetRequestURI(req.URL)
|
|
freq.Header.SetMethod(req.Method)
|
|
if req.Body != nil {
|
|
freq.SetBodyRaw(req.Body) // Use SetBodyRaw to avoid copying
|
|
}
|
|
for k, v := range req.Headers {
|
|
freq.Header.Set(k, v)
|
|
}
|
|
|
|
// 3. Acquire fasthttp response from pool
|
|
fresp := c.respPool.Get().(*fasthttp.Response)
|
|
defer c.respPool.Put(fresp) // Ensure response is put back
|
|
fresp.Reset()
|
|
|
|
// 4. Perform the request
|
|
// Ensure request/response are released even if DoTimeout panics (though unlikely)
|
|
// Place deferred release calls *before* the operation that might fail
|
|
defer fasthttp.ReleaseRequest(freq)
|
|
defer fasthttp.ReleaseResponse(fresp)
|
|
|
|
// Perform the request using DoTimeout
|
|
err := fasthttp.DoTimeout(freq, fresp, c.requestTimeout)
|
|
if err != nil {
|
|
// Don't release again here, defer handles it
|
|
return nil, fmt.Errorf("fasthttp client timeout or connection error for %s %s: %w", req.Method, req.URL, err)
|
|
}
|
|
|
|
// 5. Populate output *Response from fasthttp response
|
|
resp := &Response{
|
|
StatusCode: fresp.StatusCode(),
|
|
Body: append([]byte(nil), fresp.Body()...), // Important: Copy body as fresp will be reused
|
|
Headers: make(map[string]string),
|
|
}
|
|
fresp.Header.VisitAll(func(key, value []byte) {
|
|
resp.Headers[string(key)] = string(value)
|
|
})
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
// Stream handles streaming SSE requests using the standard net/http client.
|
|
func (c *FastHTTPClient) Stream(req *Request, callback func(chunk SSEChunk) error) error {
|
|
// 1. Create standard http request
|
|
var bodyReader io.Reader
|
|
if req.Body != nil {
|
|
bodyReader = bytes.NewReader(req.Body)
|
|
}
|
|
// Use context for potential cancellation/timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Example: 60s timeout for the whole stream
|
|
defer cancel()
|
|
|
|
stdReq, err := http.NewRequestWithContext(ctx, req.Method, req.URL, bodyReader)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create standard http request for %s: %w", req.URL, err)
|
|
}
|
|
// Copy headers
|
|
for k, v := range req.Headers {
|
|
stdReq.Header.Set(k, v)
|
|
}
|
|
// Ensure Content-Type is set if needed (and not already set)
|
|
if bodyReader != nil && stdReq.Header.Get("Content-Type") == "" {
|
|
stdReq.Header.Set("Content-Type", "application/json")
|
|
}
|
|
// Ensure Accept header is set for SSE
|
|
if stdReq.Header.Get("Accept") == "" {
|
|
stdReq.Header.Set("Accept", "text/event-stream")
|
|
}
|
|
// Ensure Connection header for streaming
|
|
if stdReq.Header.Get("Connection") == "" {
|
|
stdReq.Header.Set("Connection", "keep-alive")
|
|
}
|
|
|
|
// 2. Perform the request
|
|
resp, err := c.stdClient.Do(stdReq)
|
|
if err != nil {
|
|
return fmt.Errorf("standard http client Do failed for %s: %w", req.URL, err)
|
|
}
|
|
defer resp.Body.Close() // Ensure body is closed
|
|
|
|
// 3. Check status code
|
|
if resp.StatusCode != http.StatusOK {
|
|
bodyBytes, readErr := io.ReadAll(resp.Body)
|
|
if readErr != nil {
|
|
return fmt.Errorf("http stream error status %d for %s, failed to read body: %w", resp.StatusCode, req.URL, readErr)
|
|
}
|
|
return fmt.Errorf("http stream error status %d for %s: %s", resp.StatusCode, req.URL, string(bodyBytes))
|
|
}
|
|
|
|
// 4. Process the stream
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
ssePrefix := []byte("data: ")
|
|
var processingErr error
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Bytes()
|
|
if !bytes.HasPrefix(line, ssePrefix) {
|
|
continue
|
|
}
|
|
|
|
data := bytes.TrimPrefix(line, ssePrefix)
|
|
ts := time.Now().UnixNano()
|
|
var isDone bool
|
|
|
|
if bytes.Equal(data, []byte("[DONE]")) {
|
|
isDone = true
|
|
}
|
|
|
|
dataCopy := make([]byte, len(data))
|
|
copy(dataCopy, data)
|
|
|
|
chunk := SSEChunk{
|
|
Data: dataCopy,
|
|
Timestamp: ts,
|
|
IsDone: isDone,
|
|
}
|
|
|
|
if err := callback(chunk); err != nil {
|
|
if processingErr == nil {
|
|
processingErr = fmt.Errorf("sse stream user callback error: %w", err)
|
|
}
|
|
// Stop processing loop on user callback error
|
|
break
|
|
}
|
|
|
|
if isDone {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Check for scanner errors
|
|
if err := scanner.Err(); err != nil {
|
|
if processingErr == nil {
|
|
processingErr = fmt.Errorf("error reading sse stream for %s: %w", req.URL, err)
|
|
}
|
|
}
|
|
|
|
// Return processing error if any occurred during scan or callback
|
|
if processingErr != nil {
|
|
return processingErr
|
|
}
|
|
|
|
return nil
|
|
}
|