309 lines
10 KiB
Go
309 lines
10 KiB
Go
package concurrency
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"llm-api-benchmark-tool/pkg/client"
|
|
"llm-api-benchmark-tool/pkg/stats"
|
|
"llm-api-benchmark-tool/pkg/tokenizer"
|
|
)
|
|
|
|
// BenchmarkParams holds the parameters for the benchmark run.
// RequestInterval > 0 selects the rate-limited (ticker) execution path;
// otherwise requests are issued as fast as the concurrency limit allows.
type BenchmarkParams struct {
	HTTPClient client.HTTPClient // The HTTP client to use for requests
	TargetURL string // The target API endpoint
	APIKey string // API Key for authorization
	Model string // Model name to use
	PromptTokens int // Number of tokens for the generated prompt
	TotalDuration time.Duration // How long the benchmark should run
	Concurrency int // Number of concurrent workers/requests allowed simultaneously
	RequestInterval time.Duration // Minimum delay between starting *any* two requests globally
	IsStreaming bool // Indicates if the request expects a streaming response
	Payload map[string]interface{} // Optional: Custom payload fields merged into the request body
}
|
|
|
|
// Manager coordinates the benchmark execution based on the parameters.
// Construct it with NewManager; the configuration is supplied once and
// read by the benchmark loops.
type Manager struct {
	params BenchmarkParams // benchmark configuration supplied at construction
	statsCollector *stats.StatsCollector // destination for per-request results
}
|
|
|
|
// NewManager creates a new concurrency manager.
|
|
func NewManager(params BenchmarkParams, statsCollector *stats.StatsCollector) *Manager {
|
|
return &Manager{
|
|
params: params,
|
|
statsCollector: statsCollector,
|
|
}
|
|
}
|
|
|
|
// RunBenchmark executes the benchmark according to the configured parameters.
|
|
// It uses the provided context for cancellation.
|
|
func (m *Manager) RunBenchmark(ctx context.Context) error {
|
|
log.Printf("RunBenchmark called. RateLimit: %d", m.params.RequestInterval)
|
|
if m.params.RequestInterval > 0 {
|
|
return m.runBenchmarkWithTicker(ctx)
|
|
}
|
|
return m.runBenchmarkWithoutTicker(ctx)
|
|
}
|
|
|
|
// runBenchmarkWithTicker executes the benchmark with a fixed rate limit.
|
|
func (m *Manager) runBenchmarkWithTicker(ctx context.Context) error {
|
|
log.Printf("Starting runBenchmarkWithTicker. Duration: %s, RateLimit: %d, Concurrency: %d",
|
|
m.params.TotalDuration, m.params.RequestInterval, m.params.Concurrency)
|
|
|
|
runCtx, cancel := context.WithTimeout(ctx, m.params.TotalDuration)
|
|
defer cancel()
|
|
|
|
var wg sync.WaitGroup
|
|
concurrencyLimiter := make(chan struct{}, m.params.Concurrency)
|
|
|
|
ticker := time.NewTicker(m.params.RequestInterval)
|
|
defer ticker.Stop()
|
|
|
|
log.Printf("Ticker interval: %s", m.params.RequestInterval)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
requestCounter := 0
|
|
for {
|
|
select {
|
|
case <-runCtx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
select {
|
|
case concurrencyLimiter <- struct{}{}:
|
|
wg.Add(1)
|
|
go func(workerID int) {
|
|
defer wg.Done()
|
|
defer func() { <-concurrencyLimiter }()
|
|
|
|
if runCtx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
promptText, err := tokenizer.GeneratePrompt(m.params.PromptTokens)
|
|
if err != nil {
|
|
log.Printf("Worker: Failed to generate prompt: %v\n", err)
|
|
return
|
|
}
|
|
|
|
req := m.createRequest(promptText)
|
|
startTime := time.Now()
|
|
var result stats.RequestResult
|
|
|
|
log.Printf("Worker %d making request (Streaming: %v)...", workerID, m.params.IsStreaming)
|
|
if m.params.IsStreaming {
|
|
var firstChunkTime time.Time
|
|
var totalTokens int
|
|
err = m.params.HTTPClient.Stream(req, func(chunk client.SSEChunk) error {
|
|
if firstChunkTime.IsZero() {
|
|
firstChunkTime = time.Now()
|
|
}
|
|
chunkTokens, err := tokenizer.CountTokensInText(string(chunk.Data))
|
|
if err != nil {
|
|
log.Printf("Worker %d: Error counting stream chunk tokens: %v", workerID, err)
|
|
} else {
|
|
totalTokens += chunkTokens
|
|
}
|
|
return nil
|
|
})
|
|
latency := time.Since(startTime)
|
|
result = stats.RequestResult{
|
|
IsSuccess: err == nil,
|
|
Latency: latency,
|
|
TimeToFirstToken: firstChunkTime.Sub(startTime),
|
|
TotalTokens: totalTokens,
|
|
}
|
|
log.Printf("Worker %d finished streaming request. Success: %v", workerID, result.IsSuccess)
|
|
m.statsCollector.RecordResult(result)
|
|
} else {
|
|
resp, err := m.params.HTTPClient.Do(req)
|
|
latency := time.Since(startTime)
|
|
var statusCode int
|
|
var totalTokens int
|
|
if resp != nil {
|
|
statusCode = resp.StatusCode
|
|
bodyBytes := resp.Body
|
|
bodyString := string(bodyBytes)
|
|
totalTokens, err = tokenizer.CountTokensInText(bodyString)
|
|
if err != nil {
|
|
log.Printf("Worker %d: Error counting non-stream response tokens: %v", workerID, err)
|
|
}
|
|
}
|
|
result = stats.RequestResult{
|
|
IsSuccess: err == nil && statusCode >= 200 && statusCode < 300,
|
|
Latency: latency,
|
|
TimeToFirstToken: latency,
|
|
TotalTokens: totalTokens,
|
|
}
|
|
log.Printf("Worker %d finished non-streaming request. Success: %v, Status: %d", workerID, result.IsSuccess, statusCode)
|
|
m.statsCollector.RecordResult(result)
|
|
}
|
|
|
|
}(requestCounter)
|
|
requestCounter++
|
|
default:
|
|
log.Printf("Skipping ticker event, concurrency limiter full (Size: %d)", len(concurrencyLimiter))
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
if runCtx.Err() == context.DeadlineExceeded {
|
|
return nil
|
|
}
|
|
log.Printf("Ticker stopped. Waiting for %d workers to finish outstanding requests...", m.params.Concurrency)
|
|
wg.Wait()
|
|
return runCtx.Err()
|
|
}
|
|
|
|
// runBenchmarkWithoutTicker handles the benchmark execution when no rate limiting (RequestInterval <= 0) is needed.
|
|
// It runs requests as fast as possible, respecting the concurrency limit.
|
|
func (m *Manager) runBenchmarkWithoutTicker(ctx context.Context) error {
|
|
log.Printf("Starting runBenchmarkWithoutTicker. Duration: %s, Concurrency: %d",
|
|
m.params.TotalDuration, m.params.Concurrency)
|
|
|
|
runCtx, cancel := context.WithTimeout(ctx, m.params.TotalDuration)
|
|
defer cancel()
|
|
|
|
var wg sync.WaitGroup
|
|
concurrencyLimiter := make(chan struct{}, m.params.Concurrency)
|
|
|
|
for i := 0; i < m.params.Concurrency; i++ {
|
|
concurrencyLimiter <- struct{}{}
|
|
}
|
|
|
|
log.Printf("Starting %d workers", m.params.Concurrency)
|
|
for i := 0; i < m.params.Concurrency; i++ {
|
|
wg.Add(1)
|
|
go func(workerID int) {
|
|
defer wg.Done()
|
|
log.Printf("Worker %d started", workerID)
|
|
for {
|
|
select {
|
|
case <-runCtx.Done():
|
|
log.Printf("Worker %d exiting (context done before acquiring slot).", workerID)
|
|
return
|
|
case <-concurrencyLimiter:
|
|
log.Printf("Worker %d acquired concurrency slot.", workerID)
|
|
defer func() {
|
|
log.Printf("Worker %d releasing concurrency slot.", workerID)
|
|
concurrencyLimiter <- struct{}{}
|
|
}()
|
|
|
|
if runCtx.Err() != nil {
|
|
log.Printf("Worker %d exiting (context done after acquiring slot).", workerID)
|
|
return
|
|
}
|
|
|
|
promptText, err := tokenizer.GeneratePrompt(m.params.PromptTokens)
|
|
if err != nil {
|
|
log.Printf("Worker %d failed to generate prompt: %v. Skipping request.", workerID, err)
|
|
continue
|
|
}
|
|
|
|
req := m.createRequest(promptText)
|
|
startTime := time.Now()
|
|
var result stats.RequestResult
|
|
var totalTokens int
|
|
firstChunkTime := time.Time{}
|
|
|
|
log.Printf("Worker %d making request (Streaming: %v)...", workerID, m.params.IsStreaming)
|
|
if m.params.IsStreaming {
|
|
err = m.params.HTTPClient.Stream(req, func(chunk client.SSEChunk) error {
|
|
if firstChunkTime.IsZero() {
|
|
firstChunkTime = time.Now()
|
|
}
|
|
chunkTokens, err := tokenizer.CountTokensInText(string(chunk.Data))
|
|
if err != nil {
|
|
log.Printf("Worker %d: Error counting stream chunk tokens: %v", workerID, err)
|
|
} else {
|
|
totalTokens += chunkTokens
|
|
}
|
|
return nil
|
|
})
|
|
latency := time.Since(startTime)
|
|
result = stats.RequestResult{
|
|
IsSuccess: err == nil,
|
|
Latency: latency,
|
|
TimeToFirstToken: firstChunkTime.Sub(startTime),
|
|
TotalTokens: totalTokens,
|
|
}
|
|
log.Printf("Worker %d finished streaming request. Success: %v", workerID, result.IsSuccess)
|
|
m.statsCollector.RecordResult(result)
|
|
} else {
|
|
resp, doErr := m.params.HTTPClient.Do(req)
|
|
endTime := time.Now()
|
|
latency := endTime.Sub(startTime)
|
|
var statusCode int
|
|
var totalTokens int
|
|
if doErr != nil {
|
|
log.Printf("Worker %d failed to send request (non-stream): %v", workerID, doErr)
|
|
result = stats.RequestResult{
|
|
IsSuccess: false,
|
|
Latency: latency,
|
|
TimeToFirstToken: latency,
|
|
TotalTokens: 0,
|
|
}
|
|
} else {
|
|
statusCode = resp.StatusCode
|
|
bodyBytes := resp.Body
|
|
bodyString := string(bodyBytes)
|
|
totalTokens, err = tokenizer.CountTokensInText(bodyString)
|
|
if err != nil {
|
|
log.Printf("Worker %d: Error counting non-stream response tokens: %v", workerID, err)
|
|
}
|
|
result = stats.RequestResult{
|
|
IsSuccess: resp != nil && resp.StatusCode >= 200 && resp.StatusCode < 300,
|
|
Latency: latency,
|
|
TimeToFirstToken: latency,
|
|
TotalTokens: totalTokens,
|
|
}
|
|
log.Printf("Worker %d finished non-streaming request. Success: %v, Status: %d", workerID, result.IsSuccess, statusCode)
|
|
}
|
|
m.statsCollector.RecordResult(result)
|
|
}
|
|
}
|
|
}
|
|
}(i)
|
|
}
|
|
|
|
<-runCtx.Done()
|
|
log.Printf("Benchmark duration (%s) finished or context cancelled.", m.params.TotalDuration)
|
|
|
|
wg.Wait()
|
|
log.Println("runBenchmarkWithoutTicker finished.")
|
|
return nil
|
|
}
|
|
|
|
// createRequest creates a new client.Request object for the benchmark.
|
|
// It generates a prompt using the tokenizer and constructs the request body.
|
|
func (m *Manager) createRequest(prompt string) *client.Request {
|
|
payload := map[string]interface{}{
|
|
"model": m.params.Model,
|
|
"messages": []map[string]string{{"role": "user", "content": prompt}},
|
|
"stream": m.params.IsStreaming,
|
|
}
|
|
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
log.Printf("Error marshaling request body: %v", err)
|
|
return nil
|
|
}
|
|
|
|
return &client.Request{
|
|
URL: m.params.TargetURL,
|
|
Method: "POST",
|
|
Headers: map[string]string{"Content-Type": "application/json", "Authorization": "Bearer " + m.params.APIKey},
|
|
Body: body,
|
|
IsStream: m.params.IsStreaming,
|
|
}
|
|
}
|