// Package concurrency coordinates concurrent benchmark execution against an
// LLM API endpoint, supporting both rate-limited (ticker-driven) and
// unthrottled worker-pool modes.
package concurrency

import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"sync"
	"time"

	"llm-api-benchmark-tool/pkg/client"
	"llm-api-benchmark-tool/pkg/stats"
	"llm-api-benchmark-tool/pkg/tokenizer"
)

// BenchmarkParams holds the parameters for the benchmark run.
type BenchmarkParams struct {
	HTTPClient      client.HTTPClient      // The HTTP client to use for requests
	TargetURL       string                 // The target API endpoint
	APIKey          string                 // API Key for authorization
	Model           string                 // Model name to use
	PromptTokens    int                    // Number of tokens for the generated prompt
	TotalDuration   time.Duration          // How long the benchmark should run
	Concurrency     int                    // Number of concurrent workers/requests allowed simultaneously
	RequestInterval time.Duration          // Minimum delay between starting *any* two requests globally
	IsStreaming     bool                   // Indicates if the request expects a streaming response
	Payload         map[string]interface{} // Optional: Custom payload fields merged into the request body
}

// Manager coordinates the benchmark execution based on the parameters.
type Manager struct {
	params         BenchmarkParams
	statsCollector *stats.StatsCollector
}

// NewManager creates a new concurrency manager.
func NewManager(params BenchmarkParams, statsCollector *stats.StatsCollector) *Manager {
	return &Manager{
		params:         params,
		statsCollector: statsCollector,
	}
}

// RunBenchmark executes the benchmark according to the configured parameters.
// It uses the provided context for cancellation. A positive RequestInterval
// selects the rate-limited (ticker) mode; otherwise requests run as fast as
// the concurrency limit allows.
func (m *Manager) RunBenchmark(ctx context.Context) error {
	// %s prints the Duration in human-readable form ("500ms"); %d would print
	// raw nanoseconds.
	log.Printf("RunBenchmark called. RateLimit: %s", m.params.RequestInterval)
	if m.params.RequestInterval > 0 {
		return m.runBenchmarkWithTicker(ctx)
	}
	return m.runBenchmarkWithoutTicker(ctx)
}

// runBenchmarkWithTicker executes the benchmark with a fixed rate limit:
// a request may start at most once per RequestInterval, and never more than
// Concurrency requests are in flight at once. Ticker events that fire while
// the concurrency limit is saturated are skipped (no backlog builds up).
func (m *Manager) runBenchmarkWithTicker(ctx context.Context) error {
	log.Printf("Starting runBenchmarkWithTicker. Duration: %s, RateLimit: %s, Concurrency: %d",
		m.params.TotalDuration, m.params.RequestInterval, m.params.Concurrency)

	runCtx, cancel := context.WithTimeout(ctx, m.params.TotalDuration)
	defer cancel()

	var wg sync.WaitGroup
	// Each in-flight request holds one slot in this buffered channel.
	concurrencyLimiter := make(chan struct{}, m.params.Concurrency)

	ticker := time.NewTicker(m.params.RequestInterval)
	defer ticker.Stop()
	log.Printf("Ticker interval: %s", m.params.RequestInterval)

	// Producer goroutine: on every tick, launch a worker if a concurrency
	// slot is free; otherwise drop the tick.
	wg.Add(1)
	go func() {
		defer wg.Done()
		requestCounter := 0
		for {
			select {
			case <-runCtx.Done():
				return
			case <-ticker.C:
				select {
				case concurrencyLimiter <- struct{}{}:
					wg.Add(1)
					go func(workerID int) {
						defer wg.Done()
						defer func() { <-concurrencyLimiter }()
						// Don't start work that began after the deadline.
						if runCtx.Err() != nil {
							return
						}
						m.executeRequest(runCtx, workerID)
					}(requestCounter)
					requestCounter++
				default:
					// Non-blocking send failed: all slots busy.
					log.Printf("Skipping ticker event, concurrency limiter full (Size: %d)", len(concurrencyLimiter))
				}
			}
		}
	}()

	// Wait once for the producer and all outstanding workers. (The original
	// code waited, conditionally returned, then logged and waited a second
	// time — the second wait and its log line were dead code.)
	wg.Wait()

	// Hitting the configured duration is the normal way the benchmark ends,
	// not an error.
	if errors.Is(runCtx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return runCtx.Err()
}

// runBenchmarkWithoutTicker handles the benchmark execution when no rate
// limiting (RequestInterval <= 0) is needed. It runs requests as fast as
// possible, respecting the concurrency limit, until TotalDuration elapses or
// the parent context is cancelled.
func (m *Manager) runBenchmarkWithoutTicker(ctx context.Context) error {
	log.Printf("Starting runBenchmarkWithoutTicker. Duration: %s, Concurrency: %d",
		m.params.TotalDuration, m.params.Concurrency)

	runCtx, cancel := context.WithTimeout(ctx, m.params.TotalDuration)
	defer cancel()

	var wg sync.WaitGroup
	// Pre-fill the limiter with one token per allowed concurrent request.
	concurrencyLimiter := make(chan struct{}, m.params.Concurrency)
	for i := 0; i < m.params.Concurrency; i++ {
		concurrencyLimiter <- struct{}{}
	}

	log.Printf("Starting %d workers", m.params.Concurrency)
	for i := 0; i < m.params.Concurrency; i++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			log.Printf("Worker %d started", workerID)
			for {
				select {
				case <-runCtx.Done():
					log.Printf("Worker %d exiting (context done before acquiring slot).", workerID)
					return
				case <-concurrencyLimiter:
					log.Printf("Worker %d acquired concurrency slot.", workerID)
					if runCtx.Err() != nil {
						concurrencyLimiter <- struct{}{}
						log.Printf("Worker %d exiting (context done after acquiring slot).", workerID)
						return
					}
					m.executeRequest(runCtx, workerID)
					// BUG FIX: the original released the slot in a deferred
					// closure, which only runs at goroutine exit. After one
					// iteration per worker the limiter was empty forever, so
					// each worker made exactly one request and then blocked
					// until the deadline. Release the slot every iteration.
					concurrencyLimiter <- struct{}{}
				}
			}
		}(i)
	}

	<-runCtx.Done()
	log.Printf("Benchmark duration (%s) finished or context cancelled.", m.params.TotalDuration)
	wg.Wait()
	log.Println("runBenchmarkWithoutTicker finished.")
	return nil
}

// executeRequest performs one complete benchmark request (prompt generation,
// HTTP call, token counting) and records the result with the stats collector.
// It is shared by both run modes; the original file duplicated this logic
// inline in each mode, and the two copies had drifted apart (the ticker copy
// let a token-counting error mark a successful request as failed).
func (m *Manager) executeRequest(ctx context.Context, workerID int) {
	promptText, err := tokenizer.GeneratePrompt(m.params.PromptTokens)
	if err != nil {
		log.Printf("Worker %d failed to generate prompt: %v. Skipping request.", workerID, err)
		return
	}

	req := m.createRequest(promptText)
	// createRequest returns nil if the payload cannot be marshaled; the
	// original passed the nil request straight to the HTTP client.
	if req == nil {
		log.Printf("Worker %d: could not build request. Skipping.", workerID)
		return
	}

	startTime := time.Now()
	log.Printf("Worker %d making request (Streaming: %v)...", workerID, m.params.IsStreaming)

	var result stats.RequestResult
	if m.params.IsStreaming {
		result = m.doStreamingRequest(workerID, req, startTime)
	} else {
		result = m.doNonStreamingRequest(workerID, req, startTime)
	}
	m.statsCollector.RecordResult(result)
}

// doStreamingRequest issues a streaming (SSE) request, measuring time to
// first chunk and counting tokens across all received chunks.
func (m *Manager) doStreamingRequest(workerID int, req *client.Request, startTime time.Time) stats.RequestResult {
	var firstChunkTime time.Time
	var totalTokens int

	err := m.params.HTTPClient.Stream(req, func(chunk client.SSEChunk) error {
		if firstChunkTime.IsZero() {
			firstChunkTime = time.Now()
		}
		chunkTokens, cntErr := tokenizer.CountTokensInText(string(chunk.Data))
		if cntErr != nil {
			// A counting failure degrades the token total but does not abort
			// the stream.
			log.Printf("Worker %d: Error counting stream chunk tokens: %v", workerID, cntErr)
		} else {
			totalTokens += chunkTokens
		}
		return nil
	})

	latency := time.Since(startTime)
	// BUG FIX: if no chunk ever arrived, firstChunkTime is the zero Time and
	// firstChunkTime.Sub(startTime) would be a huge negative duration that
	// poisons the TTFT statistics. Report zero instead.
	var ttft time.Duration
	if !firstChunkTime.IsZero() {
		ttft = firstChunkTime.Sub(startTime)
	}

	result := stats.RequestResult{
		IsSuccess:        err == nil,
		Latency:          latency,
		TimeToFirstToken: ttft,
		TotalTokens:      totalTokens,
	}
	log.Printf("Worker %d finished streaming request. Success: %v", workerID, result.IsSuccess)
	return result
}

// doNonStreamingRequest issues a regular request and counts tokens in the
// full response body. Success is determined by the transport error and the
// HTTP status code only; token-counting failures are logged but never flip a
// successful request to failed.
func (m *Manager) doNonStreamingRequest(workerID int, req *client.Request, startTime time.Time) stats.RequestResult {
	resp, doErr := m.params.HTTPClient.Do(req)
	latency := time.Since(startTime)

	var statusCode int
	var totalTokens int
	if doErr != nil {
		log.Printf("Worker %d failed to send request (non-stream): %v", workerID, doErr)
	} else if resp != nil {
		statusCode = resp.StatusCode
		tokens, cntErr := tokenizer.CountTokensInText(string(resp.Body))
		if cntErr != nil {
			log.Printf("Worker %d: Error counting non-stream response tokens: %v", workerID, cntErr)
		} else {
			totalTokens = tokens
		}
	}

	result := stats.RequestResult{
		IsSuccess:        doErr == nil && statusCode >= 200 && statusCode < 300,
		Latency:          latency,
		TimeToFirstToken: latency, // whole body arrives at once
		TotalTokens:      totalTokens,
	}
	log.Printf("Worker %d finished non-streaming request. Success: %v, Status: %d", workerID, result.IsSuccess, statusCode)
	return result
}

// createRequest builds a client.Request for the benchmark with the given
// prompt. Custom fields from Params.Payload (e.g. temperature, max_tokens)
// are merged in first; the standard model/messages/stream fields always take
// precedence so the benchmark semantics cannot be overridden. Returns nil if
// the body cannot be marshaled.
func (m *Manager) createRequest(prompt string) *client.Request {
	payload := make(map[string]interface{}, len(m.params.Payload)+3)
	// FIX: BenchmarkParams.Payload was declared and documented as "custom
	// payload fields" but the original implementation silently ignored it.
	for k, v := range m.params.Payload {
		payload[k] = v
	}
	payload["model"] = m.params.Model
	payload["messages"] = []map[string]string{{"role": "user", "content": prompt}}
	payload["stream"] = m.params.IsStreaming

	body, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Error marshaling request body: %v", err)
		return nil
	}
	return &client.Request{
		URL:    m.params.TargetURL,
		Method: "POST",
		Headers: map[string]string{
			"Content-Type":  "application/json",
			"Authorization": "Bearer " + m.params.APIKey,
		},
		Body:     body,
		IsStream: m.params.IsStreaming,
	}
}