112 lines
4.5 KiB
Go

package main
import (
	"context"
	"flag"
	"log"
	"os"
	"os/signal"
	"syscall"
	"time"

	"llm-api-benchmark-tool/pkg/client"
	"llm-api-benchmark-tool/pkg/concurrency"
	"llm-api-benchmark-tool/pkg/config"
	"llm-api-benchmark-tool/pkg/report"
	"llm-api-benchmark-tool/pkg/stats"
	"llm-api-benchmark-tool/pkg/tokenizer"
)
// main loads the YAML configuration named by the -config flag, verifies the
// tokenizer, runs the benchmark through the concurrency manager, logs a summary
// of the collected statistics and, when cfg.ReportFile is set, writes an HTML
// report.
//
// The benchmark context is cancelled on the first SIGINT/SIGTERM. In that case
// the partial results gathered so far are still summarized and reported. A
// second signal terminates the process immediately.
func main() {
	// Define command-line flags.
	configPath := flag.String("config", "config/config.yaml", "Path to the configuration YAML file")
	flag.Parse()

	// Load configuration.
	cfg, cfgErr := config.LoadConfig(*configPath)
	if cfgErr != nil {
		log.Fatalf("Failed to load config: %v", cfgErr)
	}
	log.Println("Configuration loaded successfully.")
	log.Printf("API Endpoint: %s", cfg.APIEndpoint)

	// Tokenizer setup happens in the tokenizer package's init(); surface any
	// failure it recorded before doing real work.
	if tknErr := tokenizer.CheckInitStatus(); tknErr != nil {
		log.Fatalf("Tokenizer initialization failed: %v", tknErr)
	}
	log.Println("Tokenizer initialization checked successfully.")

	// A non-positive per-request timeout falls back to the global timeout.
	requestTimeout := cfg.RequestTimeout
	if requestTimeout <= 0 {
		requestTimeout = cfg.Timeout
	}
	// FastHTTPClient serves both modes internally: fasthttp for non-streaming
	// requests and net/http for streaming ones.
	httpClient := client.NewFastHTTPClient(requestTimeout)
	log.Printf("Using HTTP client implementation: FastHTTPClient (non-stream timeout: %s)", requestTimeout)
	log.Println("HTTP Client initialized.")

	// Initialize Stats Collector.
	statsCollector := stats.NewStatsCollector()
	log.Println("Stats Collector initialized.")

	// Initialize Concurrency Manager.
	benchmarkParams := concurrency.BenchmarkParams{
		HTTPClient:      httpClient,
		TargetURL:       cfg.APIEndpoint,
		APIKey:          cfg.APIKey,
		TotalDuration:   cfg.Timeout,
		Concurrency:     cfg.Concurrency,
		RequestInterval: time.Duration(cfg.RateLimit) * time.Millisecond, // cfg.RateLimit is in milliseconds
		IsStreaming:     cfg.Streaming,
		PromptTokens:    cfg.PromptTokens,
		Model:           cfg.Model,
	}
	concurrencyManager := concurrency.NewManager(benchmarkParams, statsCollector)
	log.Println("Concurrency Manager initialized.")
	log.Printf("Benchmark settings: Duration=%s, Concurrency=%d, RateLimit=%s, Streaming=%v, PromptTokens=%d, Model=%s",
		benchmarkParams.TotalDuration, benchmarkParams.Concurrency, benchmarkParams.RequestInterval, benchmarkParams.IsStreaming, benchmarkParams.PromptTokens, benchmarkParams.Model)

	// --- Run Benchmark ---
	ctx, stop := interruptContext()
	defer stop()

	log.Println("Starting benchmark...")
	start := time.Now()
	runErr := concurrencyManager.RunBenchmark(ctx)
	elapsed := time.Since(start)

	// Only a signal cancels ctx here, because stop() has not run yet.
	interrupted := ctx.Err() != nil

	// By default, stats are computed over the configured duration, as before.
	// After an interruption, the time actually spent is used instead, so the
	// summary covers only the portion of the run that happened.
	statsWindow := benchmarkParams.TotalDuration
	switch {
	case interrupted:
		if runErr != nil {
			log.Printf("Benchmark run returned after interruption: %v", runErr)
		}
		log.Printf("Benchmark interrupted after %s; reporting partial results.", elapsed.Round(time.Millisecond))
		statsWindow = elapsed
	case runErr != nil:
		log.Fatalf("Benchmark run failed: %v", runErr)
	default:
		log.Println("Benchmark finished.")
	}

	// --- Get Final Stats ---
	finalStats := statsCollector.CalculateStats(statsWindow)
	log.Println("--- Benchmark Results ---")
	log.Printf("Total Duration: %.2fs", finalStats.TotalDuration.Seconds())
	log.Printf("Total Requests: %d", finalStats.TotalRequests)
	log.Printf("Successful Requests: %d", finalStats.SuccessfulRequests)
	log.Printf("Failed Requests: %d", finalStats.FailedRequests)
	if finalStats.TotalRequests > 0 {
		log.Printf("Success Rate: %.2f%%", float64(finalStats.SuccessfulRequests)/float64(finalStats.TotalRequests)*100)
	}
	log.Printf("Average QPS: %.2f", finalStats.AvgQPS)
	// Latency/TTFT/throughput figures are only printed when at least one
	// request succeeded.
	if finalStats.SuccessfulRequests > 0 {
		log.Printf("Average Latency (Successful): %s", finalStats.AvgLatency)
		log.Printf("Latency (Min/Max): %s / %s", finalStats.MinLatency, finalStats.MaxLatency)
		log.Printf("Latency (P90/P95/P99): %s / %s / %s", finalStats.P90Latency, finalStats.P95Latency, finalStats.P99Latency)
		log.Printf("Average TTFT (Successful): %s", finalStats.AvgTimeToFirstToken)
		log.Printf("TTFT (Min/Max): %s / %s", finalStats.MinTimeToFirstToken, finalStats.MaxTimeToFirstToken)
		log.Printf("TTFT (P90/P95/P99): %s / %s / %s", finalStats.P90TimeToFirstToken, finalStats.P95TimeToFirstToken, finalStats.P99TimeToFirstToken)
		log.Printf("Average Tokens/Second (Successful): %.2f", finalStats.AvgTokensPerSecond)
	}
	log.Println("-------------------------")

	// --- Generate Report ---
	// Report generation is best-effort: a failure is logged but does not
	// change the exit status.
	if cfg.ReportFile != "" {
		log.Printf("Generating HTML report to: %s", cfg.ReportFile)
		if err := report.GenerateHTMLReport(finalStats, cfg.ReportFile); err != nil {
			log.Printf("Error generating report: %v", err)
		} else {
			log.Println("Report generated successfully.")
		}
	}
}

// interruptContext returns a context that is cancelled by the first SIGINT or
// SIGTERM, together with a function that releases the signal registration.
//
// After the first signal, default signal handling is restored. A second Ctrl+C
// therefore terminates the process even if the work using the context does
// not observe cancellation. The helper goroutine exits as soon as the context
// is done, which happens on a signal or when the returned stop function is
// called.
func interruptContext() (context.Context, context.CancelFunc) {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	go func() {
		<-ctx.Done()
		stop() // restore default handling so a repeated signal kills the process
	}()
	return ctx, stop
}