112 lines
4.5 KiB
Go
112 lines
4.5 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"log"
|
|
"time"
|
|
"context"
|
|
|
|
"llm-api-benchmark-tool/pkg/client"
|
|
"llm-api-benchmark-tool/pkg/config"
|
|
"llm-api-benchmark-tool/pkg/concurrency"
|
|
"llm-api-benchmark-tool/pkg/stats"
|
|
"llm-api-benchmark-tool/pkg/tokenizer"
|
|
"llm-api-benchmark-tool/pkg/report"
|
|
)
|
|
|
|
func main() {
|
|
// Define command-line flags
|
|
configPath := flag.String("config", "config/config.yaml", "Path to the configuration YAML file")
|
|
flag.Parse()
|
|
|
|
// Load configuration
|
|
cfg, cfgErr := config.LoadConfig(*configPath)
|
|
if cfgErr != nil {
|
|
log.Fatalf("Failed to load config: %v", cfgErr)
|
|
}
|
|
|
|
// Check Tokenizer Initialization Status (happens in init() func)
|
|
tknErr := tokenizer.CheckInitStatus()
|
|
if tknErr != nil {
|
|
log.Fatalf("Tokenizer initialization failed: %v", tknErr)
|
|
}
|
|
log.Println("Tokenizer initialization checked successfully.")
|
|
|
|
// Initialize HTTP Client
|
|
// Determine the effective request timeout
|
|
requestTimeout := cfg.RequestTimeout
|
|
if requestTimeout <= 0 {
|
|
requestTimeout = cfg.Timeout // Fallback to global timeout
|
|
}
|
|
// Always create FastHTTPClient; it handles both streaming (net/http) and non-streaming (fasthttp)
|
|
httpClient := client.NewFastHTTPClient(requestTimeout)
|
|
|
|
log.Println("Configuration loaded successfully.")
|
|
log.Printf("API Endpoint: %s", cfg.APIEndpoint)
|
|
// The actual client used (fasthttp for Do, net/http for Stream) is handled internally by FastHTTPClient
|
|
log.Printf("Using HTTP client implementation: FastHTTPClient (non-stream timeout: %s)", requestTimeout)
|
|
log.Println("HTTP Client initialized.")
|
|
|
|
// Initialize Stats Collector
|
|
statsCollector := stats.NewStatsCollector()
|
|
log.Println("Stats Collector initialized.")
|
|
|
|
// Initialize Concurrency Manager
|
|
benchmarkParams := concurrency.BenchmarkParams{
|
|
HTTPClient: httpClient,
|
|
TargetURL: cfg.APIEndpoint,
|
|
APIKey: cfg.APIKey,
|
|
TotalDuration: cfg.Timeout,
|
|
Concurrency: cfg.Concurrency,
|
|
RequestInterval: time.Duration(cfg.RateLimit) * time.Millisecond, // Convert ms to time.Duration
|
|
IsStreaming: cfg.Streaming,
|
|
PromptTokens: cfg.PromptTokens,
|
|
Model: cfg.Model,
|
|
// TODO: Add other params like API Key, Model from cfg if needed
|
|
}
|
|
concurrencyManager := concurrency.NewManager(benchmarkParams, statsCollector)
|
|
log.Println("Concurrency Manager initialized.")
|
|
log.Printf("Benchmark settings: Duration=%s, Concurrency=%d, RateLimit=%s, Streaming=%v, PromptTokens=%d, Model=%s",
|
|
benchmarkParams.TotalDuration, benchmarkParams.Concurrency, benchmarkParams.RequestInterval, benchmarkParams.IsStreaming, benchmarkParams.PromptTokens, benchmarkParams.Model)
|
|
|
|
// --- Run Benchmark ---
|
|
log.Println("Starting benchmark...")
|
|
ctx := context.Background()
|
|
if err := concurrencyManager.RunBenchmark(ctx); err != nil {
|
|
log.Fatalf("Benchmark run failed: %v", err)
|
|
}
|
|
log.Println("Benchmark finished.")
|
|
|
|
// --- Get Final Stats ---
|
|
finalStats := statsCollector.CalculateStats(benchmarkParams.TotalDuration)
|
|
log.Println("--- Benchmark Results ---")
|
|
log.Printf("Total Duration: %.2fs", finalStats.TotalDuration.Seconds())
|
|
log.Printf("Total Requests: %d", finalStats.TotalRequests)
|
|
log.Printf("Successful Requests: %d", finalStats.SuccessfulRequests)
|
|
log.Printf("Failed Requests: %d", finalStats.FailedRequests)
|
|
if finalStats.TotalRequests > 0 {
|
|
log.Printf("Success Rate: %.2f%%", float64(finalStats.SuccessfulRequests)/float64(finalStats.TotalRequests)*100)
|
|
}
|
|
log.Printf("Average QPS: %.2f", finalStats.AvgQPS)
|
|
if finalStats.SuccessfulRequests > 0 {
|
|
log.Printf("Average Latency (Successful): %s", finalStats.AvgLatency)
|
|
log.Printf("Latency (Min/Max): %s / %s", finalStats.MinLatency, finalStats.MaxLatency)
|
|
log.Printf("Latency (P90/P95/P99): %s / %s / %s", finalStats.P90Latency, finalStats.P95Latency, finalStats.P99Latency)
|
|
log.Printf("Average TTFT (Successful): %s", finalStats.AvgTimeToFirstToken)
|
|
log.Printf("TTFT (Min/Max): %s / %s", finalStats.MinTimeToFirstToken, finalStats.MaxTimeToFirstToken)
|
|
log.Printf("TTFT (P90/P95/P99): %s / %s / %s", finalStats.P90TimeToFirstToken, finalStats.P95TimeToFirstToken, finalStats.P99TimeToFirstToken)
|
|
log.Printf("Average Tokens/Second (Successful): %.2f", finalStats.AvgTokensPerSecond)
|
|
}
|
|
log.Println("-------------------------")
|
|
|
|
// --- Generate Report ---
|
|
if cfg.ReportFile != "" {
|
|
log.Printf("Generating HTML report to: %s", cfg.ReportFile)
|
|
if err := report.GenerateHTMLReport(finalStats, cfg.ReportFile); err != nil {
|
|
log.Printf("Error generating report: %v", err)
|
|
} else {
|
|
log.Println("Report generated successfully.")
|
|
}
|
|
}
|
|
}
|