
package concurrency

import (
	"context"
	"encoding/json"
	"log"
	"sync"
	"time"

	"llm-api-benchmark-tool/pkg/client"
	"llm-api-benchmark-tool/pkg/stats"
	"llm-api-benchmark-tool/pkg/tokenizer"
)

// BenchmarkParams holds the parameters for the benchmark run.
type BenchmarkParams struct {
	HTTPClient      client.HTTPClient      // The HTTP client to use for requests
	TargetURL       string                 // The target API endpoint
	APIKey          string                 // API Key for authorization
	Model           string                 // Model name to use
	PromptTokens    int                    // Number of tokens for the generated prompt
	TotalDuration   time.Duration          // How long the benchmark should run
	Concurrency     int                    // Number of concurrent workers/requests allowed simultaneously
	RequestInterval time.Duration          // Minimum delay between starting *any* two requests globally
	IsStreaming     bool                   // Indicates if the request expects a streaming response
	Payload         map[string]interface{} // Optional: Custom payload fields
}
// Manager coordinates the benchmark execution based on the parameters.
type Manager struct {
	params         BenchmarkParams
	statsCollector *stats.StatsCollector
}
// NewManager creates a new concurrency manager.
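//
// A minimal usage sketch (the stats.NewStatsCollector constructor and the
// httpClient value are placeholders for illustration; use whatever constructors
// pkg/stats and pkg/client actually expose):
//
//	collector := stats.NewStatsCollector() // hypothetical constructor
//	mgr := concurrency.NewManager(concurrency.BenchmarkParams{
//		HTTPClient:      httpClient,
//		TargetURL:       "https://api.example.com/v1/chat/completions",
//		APIKey:          "sk-...",
//		Model:           "example-model",
//		PromptTokens:    128,
//		TotalDuration:   30 * time.Second,
//		Concurrency:     4,
//		RequestInterval: 100 * time.Millisecond, // 0 disables the ticker path
//		IsStreaming:     true,
//	}, collector)
//	if err := mgr.RunBenchmark(context.Background()); err != nil {
//		log.Fatalf("benchmark failed: %v", err)
//	}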
func NewManager(params BenchmarkParams, statsCollector *stats.StatsCollector) *Manager {
	return &Manager{
		params:         params,
		statsCollector: statsCollector,
	}
}
// RunBenchmark executes the benchmark according to the configured parameters.
// It uses the provided context for cancellation.
func (m *Manager) RunBenchmark(ctx context.Context) error {
log.Printf("RunBenchmark called. RateLimit: %d", m.params.RequestInterval)
if m.params.RequestInterval > 0 {
return m.runBenchmarkWithTicker(ctx)
}
return m.runBenchmarkWithoutTicker(ctx)
}
// runBenchmarkWithTicker executes the benchmark with a fixed rate limit.
func (m *Manager) runBenchmarkWithTicker(ctx context.Context) error {
log.Printf("Starting runBenchmarkWithTicker. Duration: %s, RateLimit: %d, Concurrency: %d",
m.params.TotalDuration, m.params.RequestInterval, m.params.Concurrency)
runCtx, cancel := context.WithTimeout(ctx, m.params.TotalDuration)
defer cancel()
var wg sync.WaitGroup
concurrencyLimiter := make(chan struct{}, m.params.Concurrency)
ticker := time.NewTicker(m.params.RequestInterval)
defer ticker.Stop()
log.Printf("Ticker interval: %s", m.params.RequestInterval)
	wg.Add(1)
	go func() {
		defer wg.Done()
		requestCounter := 0
		for {
			select {
			case <-runCtx.Done():
				return
			case <-ticker.C:
				// Start a request only if a concurrency slot is free; otherwise
				// drop this tick so the global start rate is never exceeded.
				select {
				case concurrencyLimiter <- struct{}{}:
					wg.Add(1)
					go func(workerID int) {
						defer wg.Done()
						defer func() { <-concurrencyLimiter }()
						if runCtx.Err() != nil {
							return
						}
						promptText, err := tokenizer.GeneratePrompt(m.params.PromptTokens)
						if err != nil {
							log.Printf("Worker %d: Failed to generate prompt: %v", workerID, err)
							return
						}
						req := m.createRequest(promptText)
						if req == nil {
							log.Printf("Worker %d: Failed to build request, skipping.", workerID)
							return
						}
						startTime := time.Now()
						var result stats.RequestResult
						log.Printf("Worker %d making request (Streaming: %v)...", workerID, m.params.IsStreaming)
						if m.params.IsStreaming {
							var firstChunkTime time.Time
							var totalTokens int
							streamErr := m.params.HTTPClient.Stream(req, func(chunk client.SSEChunk) error {
								if firstChunkTime.IsZero() {
									firstChunkTime = time.Now()
								}
								chunkTokens, countErr := tokenizer.CountTokensInText(string(chunk.Data))
								if countErr != nil {
									log.Printf("Worker %d: Error counting stream chunk tokens: %v", workerID, countErr)
								} else {
									totalTokens += chunkTokens
								}
								return nil
							})
							latency := time.Since(startTime)
							// Only report a time-to-first-token if at least one chunk arrived.
							var ttft time.Duration
							if !firstChunkTime.IsZero() {
								ttft = firstChunkTime.Sub(startTime)
							}
							result = stats.RequestResult{
								IsSuccess:        streamErr == nil,
								Latency:          latency,
								TimeToFirstToken: ttft,
								TotalTokens:      totalTokens,
							}
							log.Printf("Worker %d finished streaming request. Success: %v", workerID, result.IsSuccess)
							m.statsCollector.RecordResult(result)
						} else {
							resp, doErr := m.params.HTTPClient.Do(req)
							latency := time.Since(startTime)
							var statusCode int
							var totalTokens int
							if doErr != nil {
								log.Printf("Worker %d failed to send request (non-stream): %v", workerID, doErr)
							} else if resp != nil {
								statusCode = resp.StatusCode
								var countErr error
								totalTokens, countErr = tokenizer.CountTokensInText(string(resp.Body))
								if countErr != nil {
									log.Printf("Worker %d: Error counting non-stream response tokens: %v", workerID, countErr)
								}
							}
							result = stats.RequestResult{
								IsSuccess:        doErr == nil && statusCode >= 200 && statusCode < 300,
								Latency:          latency,
								TimeToFirstToken: latency,
								TotalTokens:      totalTokens,
							}
							log.Printf("Worker %d finished non-streaming request. Success: %v, Status: %d", workerID, result.IsSuccess, statusCode)
							m.statsCollector.RecordResult(result)
						}
					}(requestCounter)
					requestCounter++
				default:
					log.Printf("Skipping ticker event, concurrency limiter full (Size: %d)", len(concurrencyLimiter))
				}
			}
		}
	}()
	// Wait once for the dispatcher goroutine and any in-flight workers to finish.
	wg.Wait()
	log.Printf("Ticker stopped; all outstanding requests have finished.")
	if runCtx.Err() == context.DeadlineExceeded {
		// Reaching the configured TotalDuration is the normal way for the run to end.
		return nil
	}
	return runCtx.Err()
}
// runBenchmarkWithoutTicker handles the benchmark execution when no rate limiting (RequestInterval <= 0) is needed.
// It runs requests as fast as possible, respecting the concurrency limit.
func (m *Manager) runBenchmarkWithoutTicker(ctx context.Context) error {
log.Printf("Starting runBenchmarkWithoutTicker. Duration: %s, Concurrency: %d",
m.params.TotalDuration, m.params.Concurrency)
runCtx, cancel := context.WithTimeout(ctx, m.params.TotalDuration)
defer cancel()
var wg sync.WaitGroup
concurrencyLimiter := make(chan struct{}, m.params.Concurrency)
for i := 0; i < m.params.Concurrency; i++ {
concurrencyLimiter <- struct{}{}
}
log.Printf("Starting %d workers", m.params.Concurrency)
	for i := 0; i < m.params.Concurrency; i++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			log.Printf("Worker %d started", workerID)
			for {
				select {
				case <-runCtx.Done():
					log.Printf("Worker %d exiting (context done before acquiring slot).", workerID)
					return
				case <-concurrencyLimiter:
					log.Printf("Worker %d acquired concurrency slot.", workerID)
					if runCtx.Err() != nil {
						log.Printf("Worker %d exiting (context done after acquiring slot).", workerID)
						concurrencyLimiter <- struct{}{}
						return
					}
					promptText, err := tokenizer.GeneratePrompt(m.params.PromptTokens)
					if err != nil {
						log.Printf("Worker %d failed to generate prompt: %v. Skipping request.", workerID, err)
						concurrencyLimiter <- struct{}{}
						continue
					}
					req := m.createRequest(promptText)
					if req == nil {
						log.Printf("Worker %d failed to build request. Skipping request.", workerID)
						concurrencyLimiter <- struct{}{}
						continue
					}
					startTime := time.Now()
					var result stats.RequestResult
					log.Printf("Worker %d making request (Streaming: %v)...", workerID, m.params.IsStreaming)
					if m.params.IsStreaming {
						var firstChunkTime time.Time
						var totalTokens int
						streamErr := m.params.HTTPClient.Stream(req, func(chunk client.SSEChunk) error {
							if firstChunkTime.IsZero() {
								firstChunkTime = time.Now()
							}
							chunkTokens, countErr := tokenizer.CountTokensInText(string(chunk.Data))
							if countErr != nil {
								log.Printf("Worker %d: Error counting stream chunk tokens: %v", workerID, countErr)
							} else {
								totalTokens += chunkTokens
							}
							return nil
						})
						latency := time.Since(startTime)
						// Only report a time-to-first-token if at least one chunk arrived.
						var ttft time.Duration
						if !firstChunkTime.IsZero() {
							ttft = firstChunkTime.Sub(startTime)
						}
						result = stats.RequestResult{
							IsSuccess:        streamErr == nil,
							Latency:          latency,
							TimeToFirstToken: ttft,
							TotalTokens:      totalTokens,
						}
						log.Printf("Worker %d finished streaming request. Success: %v", workerID, result.IsSuccess)
					} else {
						resp, doErr := m.params.HTTPClient.Do(req)
						latency := time.Since(startTime)
						var statusCode int
						var totalTokens int
						if doErr != nil {
							log.Printf("Worker %d failed to send request (non-stream): %v", workerID, doErr)
						} else if resp != nil {
							statusCode = resp.StatusCode
							var countErr error
							totalTokens, countErr = tokenizer.CountTokensInText(string(resp.Body))
							if countErr != nil {
								log.Printf("Worker %d: Error counting non-stream response tokens: %v", workerID, countErr)
							}
						}
						result = stats.RequestResult{
							IsSuccess:        doErr == nil && statusCode >= 200 && statusCode < 300,
							Latency:          latency,
							TimeToFirstToken: latency,
							TotalTokens:      totalTokens,
						}
						log.Printf("Worker %d finished non-streaming request. Success: %v, Status: %d", workerID, result.IsSuccess, statusCode)
					}
					m.statsCollector.RecordResult(result)
					// Release the slot before the next iteration; a deferred release
					// would only run when the goroutine exits and would starve the loop.
					log.Printf("Worker %d releasing concurrency slot.", workerID)
					concurrencyLimiter <- struct{}{}
				}
			}
		}(i)
	}
	<-runCtx.Done()
	log.Printf("Benchmark duration (%s) finished or context cancelled.", m.params.TotalDuration)
	wg.Wait()
	log.Println("runBenchmarkWithoutTicker finished.")
	return nil
}
// createRequest builds a new client.Request for the benchmark from the given
// prompt. The prompt is generated by the caller; this method only constructs
// the JSON request body and headers.
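//
// For illustration, the marshaled body looks roughly like this (values depend
// on the configured params; entries from Payload that don't collide with these
// keys are merged in as well):
//
//	{
//	  "model": "example-model",
//	  "messages": [{"role": "user", "content": "<generated prompt>"}],
//	  "stream": true
//	}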
func (m *Manager) createRequest(prompt string) *client.Request {
	payload := map[string]interface{}{
		"model":    m.params.Model,
		"messages": []map[string]string{{"role": "user", "content": prompt}},
		"stream":   m.params.IsStreaming,
	}
	// Merge optional custom payload fields (e.g. sampling parameters) without
	// overriding the core fields set above.
	for k, v := range m.params.Payload {
		if _, exists := payload[k]; !exists {
			payload[k] = v
		}
	}
	body, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Error marshaling request body: %v", err)
		return nil
	}
	return &client.Request{
		URL:      m.params.TargetURL,
		Method:   "POST",
		Headers:  map[string]string{"Content-Type": "application/json", "Authorization": "Bearer " + m.params.APIKey},
		Body:     body,
		IsStream: m.params.IsStreaming,
	}
}
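
// Testing note: a stub client makes it possible to exercise the Manager without
// a live endpoint. The sketch below is an assumption about pkg/client's shapes,
// inferred only from how HTTPClient, Request, Response, and SSEChunk are used in
// this file (it assumes HTTPClient is an interface with Do and Stream methods);
// adjust it to the real definitions in pkg/client before use.
//
//	type stubClient struct{}
//
//	func (s stubClient) Do(req *client.Request) (*client.Response, error) {
//		return &client.Response{StatusCode: 200, Body: []byte("ok")}, nil
//	}
//
//	func (s stubClient) Stream(req *client.Request, onChunk func(client.SSEChunk) error) error {
//		return onChunk(client.SSEChunk{Data: []byte("hello")})
//	}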