package benchmark import ( "bytes" "context" "encoding/json" "fmt" "math" "math/rand" "strings" "sync" "time" "github.com/montanaflynn/stats" "github.com/valyala/fasthttp" "github.com/acmestudio/llm-api-benchmark-tool/config" "github.com/acmestudio/llm-api-benchmark-tool/logger" "github.com/acmestudio/llm-api-benchmark-tool/prompt" ) // Benchmark 表示基准测试执行器 type Benchmark struct { config *config.Config generator *prompt.Generator client *fasthttp.Client rand *rand.Rand } // Response 表示API响应 type Response struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` Choices []Choice `json:"choices"` Usage Usage `json:"usage"` } // StreamResponse 表示流式API响应 type StreamResponse struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` Choices []StreamChoice `json:"choices"` } // StreamChoice 表示流式响应中的选择 type StreamChoice struct { Delta StreamDelta `json:"delta"` Index int `json:"index"` FinishReason string `json:"finish_reason"` } // StreamDelta 表示流式响应中的增量内容 type StreamDelta struct { Role string `json:"role"` Content string `json:"content"` } // Choice 表示API响应中的选择 type Choice struct { Message Message `json:"message"` Index int `json:"index"` FinishReason string `json:"finish_reason"` } // Message 表示聊天消息 type Message struct { Role string `json:"role"` Content string `json:"content"` } // Usage 表示API响应中的使用情况 type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } // Request 表示API请求 type Request struct { Model string `json:"model"` Messages []Message `json:"messages"` MaxTokens int `json:"max_tokens"` Temperature float64 `json:"temperature"` Stream bool `json:"stream"` } // Result 表示单个API请求的结果 type Result struct { Timestamp time.Time // 请求时间戳 Concurrency int // 并发级别 PromptType string // 提示词类型 (short/long) PromptTokenCount int // 提示词token数量 PromptTokens int // 提示词token数量(API返回) CompletionTokens int // 完成token数量 TotalTokens int // 总token数量 ResponseTime time.Duration // 响应时间 Response string // 响应内容 Error error // 错误信息 } // BenchmarkResults 表示基准测试结果 type BenchmarkResults struct { Results []Result // 所有请求结果 StartTime time.Time // 开始时间 EndTime time.Time // 结束时间 Config *config.Config // 配置信息 ConcurrencyData map[int][]Result // 按并发步骤分组的结果 TotalRequests int // 总请求数 SuccessRequests int // 成功请求数 FailedRequests int // 失败请求数 AvgResponseTime time.Duration // 平均响应时间 P50ResponseTime time.Duration // 响应时间50百分位 P90ResponseTime time.Duration // 响应时间90百分位 P99ResponseTime time.Duration // 响应时间99百分位 QPS float64 // 每秒请求数 AvgPromptTokens int // 平均提示词token数 AvgCompletionTokens int // 平均完成token数 AvgTotalTokens int // 平均总token数 AvgTokenRate float64 // 平均token生成速率 } // NewBenchmark 创建新的基准测试执行器 func NewBenchmark(cfg *config.Config) (*Benchmark, error) { // 初始化提示词生成器 logger.Debug("初始化提示词生成器...") gen, err := prompt.NewGenerator(cfg) if err != nil { return nil, fmt.Errorf("初始化提示词生成器失败: %w", err) } logger.Debug("提示词生成器初始化成功") // 初始化HTTP客户端 logger.Debug("初始化HTTP客户端...") client := &fasthttp.Client{ MaxConnsPerHost: 10000, ReadTimeout: time.Duration(cfg.Timeout) * time.Second, WriteTimeout: time.Duration(cfg.Timeout) * time.Second, } logger.Debug("HTTP客户端初始化成功,超时设置: %d秒", cfg.Timeout) // 初始化随机数生成器 source := rand.NewSource(time.Now().UnixNano()) return &Benchmark{ config: cfg, generator: gen, client: client, rand: rand.New(source), }, nil } // Run 执行基准测试 func (b *Benchmark) Run() (*BenchmarkResults, error) { // 创建结果对象 results := &BenchmarkResults{ StartTime: time.Now(), Config: b.config, ConcurrencyData: make(map[int][]Result), } // 执行每个并发步骤 for _, concurrency := range b.config.Concurrency.Steps { logger.Info("开始执行并发步骤: %d 并发用户", concurrency) logger.Info("预计持续时间: %d秒", b.config.Concurrency.DurationPerStep) // 执行单个并发步骤 stepResults, err := b.runConcurrencyStep(concurrency) if err != nil { return nil, fmt.Errorf("执行并发步骤 %d 失败: %w", concurrency, err) } // 统计成功率 successCount := 0 for _, result := range stepResults { if result.Error == nil { successCount++ } // 添加到总结果 results.Results = append(results.Results, result) } // 记录步骤结果 results.ConcurrencyData[concurrency] = stepResults // 输出步骤统计 successRate := 0.0 if len(stepResults) > 0 { successRate = float64(successCount) / float64(len(stepResults)) * 100 } logger.Info("完成并发步骤: %d 并发用户,共 %d 个请求,成功率: %.2f%%", concurrency, len(stepResults), successRate) } results.EndTime = time.Now() logger.Info("所有并发步骤测试完成") return results, nil } // runConcurrencyStep 执行单个并发步骤 func (b *Benchmark) runConcurrencyStep(concurrencyStep int) ([]Result, error) { var results []Result var resultsMutex sync.Mutex // 创建上下文,用于控制测试时间 ctx, cancel := context.WithTimeout( context.Background(), time.Duration(b.config.Concurrency.DurationPerStep)*time.Second, ) defer cancel() // 创建工作通道 workChan := make(chan struct{}, concurrencyStep) // 统计变量 var statsMutex sync.Mutex totalRequests := 0 successRequests := 0 failedRequests := 0 totalResponseTime := time.Duration(0) // 启动统计协程 statsTicker := time.NewTicker(5 * time.Second) defer statsTicker.Stop() go func() { startTime := time.Now() for { select { case <-ctx.Done(): return case <-statsTicker.C: statsMutex.Lock() currentTotal := totalRequests currentSuccess := successRequests currentFailed := failedRequests var avgResponseTime time.Duration if successRequests > 0 { avgResponseTime = totalResponseTime / time.Duration(successRequests) } statsMutex.Unlock() elapsedTime := time.Since(startTime).Seconds() qps := float64(currentTotal) / elapsedTime logger.Info("实时统计 - 请求: %d (成功: %d, 失败: %d), QPS: %.2f, 平均响应时间: %v", currentTotal, currentSuccess, currentFailed, qps, avgResponseTime) } } }() // 启动工作协程 var wg sync.WaitGroup logger.Debug("启动 %d 个工作协程...", concurrencyStep) for i := 0; i < concurrencyStep; i++ { wg.Add(1) go func(workerID int) { defer wg.Done() logger.Debug("工作协程 #%d 已启动", workerID) for { select { case <-ctx.Done(): // 测试时间结束 logger.Debug("工作协程 #%d 收到结束信号", workerID) return case workChan <- struct{}{}: // 执行请求前先更新计数器 statsMutex.Lock() totalRequests++ requestID := totalRequests statsMutex.Unlock() logger.Debug("工作协程 #%d 开始执行请求 #%d", workerID, requestID) // 执行请求 result := b.executeRequest(concurrencyStep) // 更新统计 statsMutex.Lock() if result.Error == nil { successRequests++ totalResponseTime += result.ResponseTime logger.Debug("请求 #%d 成功,响应时间: %v", requestID, result.ResponseTime) } else { failedRequests++ logger.Debug("请求 #%d 失败: %v", requestID, result.Error) } statsMutex.Unlock() // 添加结果 resultsMutex.Lock() results = append(results, result) resultsMutex.Unlock() // 根据泊松分布等待 waitTime := b.getPoissonWaitTime() logger.Debug("请求 #%d 完成,等待 %v 后发送下一个请求", requestID, waitTime) time.Sleep(waitTime) <-workChan } } }(i) } // 等待所有工作协程完成 logger.Debug("等待所有工作协程完成...") wg.Wait() logger.Debug("所有工作协程已完成") return results, nil } // executeRequest 执行单个请求 func (b *Benchmark) executeRequest(concurrencyStep int) Result { result := Result{ Timestamp: time.Now(), Concurrency: concurrencyStep, } // 选择提示词类型 promptType := b.generator.SelectPromptType() result.PromptType = promptType // 生成提示词 prompt, err := b.generator.GeneratePrompt(promptType) if err != nil { result.Error = fmt.Errorf("生成提示词失败: %w", err) return result } // 创建请求 req := Request{ Model: b.config.API.Model, Messages: []Message{ { Role: "user", Content: prompt, }, }, MaxTokens: 2048, Temperature: 0.7, Stream: true, // 启用流式响应 } // 序列化请求 reqBody, err := json.Marshal(req) if err != nil { result.Error = fmt.Errorf("序列化请求失败: %w", err) return result } // 创建HTTP请求 httpReq := fasthttp.AcquireRequest() httpResp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(httpReq) defer fasthttp.ReleaseResponse(httpResp) httpReq.SetRequestURI(b.config.API.Endpoint) httpReq.Header.SetMethod("POST") httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+b.config.API.APIKey) httpReq.Header.Set("Accept", "text/event-stream") // 设置接受流式响应 httpReq.SetBody(reqBody) // 记录开始时间 startTime := time.Now() // 发送请求 logger.Debug("发送请求到 %s,提示词类型: %s,提示词长度: %d字符", b.config.API.Endpoint, promptType, len(prompt)) err = b.client.Do(httpReq, httpResp) // 计算响应时间 responseTime := time.Since(startTime) result.ResponseTime = responseTime // 处理错误 if err != nil { result.Error = fmt.Errorf("发送请求失败: %w", err) logger.Error("请求失败: %v", err) return result } // 检查状态码 statusCode := httpResp.StatusCode() if statusCode != 200 { result.Error = fmt.Errorf("API返回错误状态码: %d, 响应: %s", statusCode, httpResp.Body()) logger.Error("API返回错误状态码: %d, 响应: %s", statusCode, httpResp.Body()) return result } // 处理流式响应 if req.Stream { return b.handleStreamResponse(httpResp, result, startTime) } // 处理非流式响应 return b.handleNormalResponse(httpResp, result) } // handleStreamResponse 处理流式响应 func (b *Benchmark) handleStreamResponse(httpResp *fasthttp.Response, result Result, startTime time.Time) Result { body := httpResp.Body() responseContent := "" tokenCount := 0 // 打印响应头和原始响应体的前100个字符,用于调试 logger.Debug("响应头: %s", httpResp.Header.String()) if len(body) > 0 { previewLen := 100 if len(body) < previewLen { previewLen = len(body) } logger.Debug("响应体预览: %s", string(body[:previewLen])) } else { logger.Warn("收到空响应体") result.Error = fmt.Errorf("收到空响应体") return result } // 解析流式响应 // 流式响应格式为: data: {...}\n\ndata: {...}\n\ndata: [DONE]\n\n lines := bytes.Split(body, []byte("\n\n")) if len(lines) == 0 { logger.Warn("无法解析流式响应") result.Error = fmt.Errorf("无法解析流式响应") return result } logger.Debug("解析到 %d 个数据块", len(lines)) // 超时控制 timeout := time.After(30 * time.Second) // 使用通道处理解析结果 doneChan := make(chan struct{}) go func() { for _, line := range lines { if len(line) == 0 { continue } // 检查是否为数据行 if !bytes.HasPrefix(line, []byte("data: ")) { continue } // 移除 "data: " 前缀 data := line[6:] // 检查是否为结束标记 if bytes.Equal(data, []byte("[DONE]")) { logger.Debug("收到流式响应结束标记") break } // 解析JSON var chunk map[string]interface{} if err := json.Unmarshal(data, &chunk); err != nil { logger.Warn("解析JSON数据块失败: %v, 数据: %s", err, string(data)) continue } // 提取内容 choices, ok := chunk["choices"].([]interface{}) if !ok || len(choices) == 0 { logger.Warn("无法从数据块中提取choices") continue } choice, ok := choices[0].(map[string]interface{}) if !ok { logger.Warn("无法从choices中提取第一个元素") continue } delta, ok := choice["delta"].(map[string]interface{}) if !ok { logger.Warn("无法从choice中提取delta") continue } content, ok := delta["content"].(string) if !ok { // 某些delta可能没有content,这是正常的 continue } // 累加内容 responseContent += content tokenCount += len(strings.Split(content, " ")) } close(doneChan) }() // 等待处理完成或超时 select { case <-doneChan: // 处理完成 case <-timeout: logger.Warn("解析流式响应超时") result.Error = fmt.Errorf("解析流式响应超时") return result } // 计算响应时间 result.ResponseTime = time.Since(startTime) // 设置结果 result.Response = responseContent result.PromptTokenCount = len(result.PromptType) / 2 // 简单估计 result.PromptTokens = len(result.PromptType) / 2 // 简单估计 result.CompletionTokens = tokenCount result.TotalTokens = result.PromptTokens + result.CompletionTokens logger.Debug("流式响应处理完成,响应时间: %v, 提示词Token: %d, 完成Token: %d, 总Token: %d", result.ResponseTime, result.PromptTokens, result.CompletionTokens, result.TotalTokens) return result } // handleNormalResponse 处理普通响应 func (b *Benchmark) handleNormalResponse(httpResp *fasthttp.Response, result Result) Result { // 解析响应 var response Response if err := json.Unmarshal(httpResp.Body(), &response); err != nil { result.Error = fmt.Errorf("解析响应失败: %w", err) logger.Error("解析响应失败: %v", err) return result } // 记录Token数 result.PromptTokens = response.Usage.PromptTokens result.CompletionTokens = response.Usage.CompletionTokens result.TotalTokens = response.Usage.TotalTokens // 记录详细日志 logger.Debug("请求完成 - 类型: %s, 提示词Token: %d, 响应Token: %d, 响应时间: %v", result.PromptType, result.PromptTokens, result.CompletionTokens, result.ResponseTime) return result } // getPoissonWaitTime 根据泊松分布获取等待时间 func (b *Benchmark) getPoissonWaitTime() time.Duration { // 生成泊松分布随机数 lambda := b.config.PoissonLambda L := math.Exp(-lambda) k := 0 p := 1.0 for p > L { k++ p *= b.rand.Float64() } // 转换为等待时间(毫秒) waitTime := time.Duration(k * 100) * time.Millisecond // 限制最大等待时间为1秒 if waitTime > time.Second { waitTime = time.Second } return waitTime } // convertToFloat64 将整数切片转换为浮点数切片 func convertToFloat64(ints []int) []float64 { floats := make([]float64, len(ints)) for i, v := range ints { floats[i] = float64(v) } return floats } // calculateStats 计算统计数据 func (b *Benchmark) calculateStats(results []Result) *BenchmarkResults { var stats BenchmarkResults // 过滤出成功的请求 var successResults []Result for _, r := range results { if r.Error == nil { successResults = append(successResults, r) } } // 如果没有成功的请求,返回空结果 if len(successResults) == 0 { return &stats } // 计算基本统计数据 stats.TotalRequests = len(results) stats.SuccessRequests = len(successResults) stats.FailedRequests = stats.TotalRequests - stats.SuccessRequests // 计算响应时间统计 var responseTimes []float64 for _, r := range successResults { responseTimes = append(responseTimes, float64(r.ResponseTime.Milliseconds())) } // 计算平均响应时间 avg, _ := stats.Mean(responseTimes) stats.AvgResponseTime = time.Duration(avg) * time.Millisecond // 计算响应时间百分位数 p50, _ := stats.Percentile(responseTimes, 50) p90, _ := stats.Percentile(responseTimes, 90) p99, _ := stats.Percentile(responseTimes, 99) stats.P50ResponseTime = time.Duration(p50) * time.Millisecond stats.P90ResponseTime = time.Duration(p90) * time.Millisecond stats.P99ResponseTime = time.Duration(p99) * time.Millisecond // 计算QPS totalDuration := successResults[len(successResults)-1].Timestamp.Sub(successResults[0].Timestamp).Seconds() if totalDuration > 0 { stats.QPS = float64(stats.SuccessRequests) / totalDuration } // 计算Token统计 var promptTokens, completionTokens, totalTokens []int for _, r := range successResults { promptTokens = append(promptTokens, r.PromptTokens) completionTokens = append(completionTokens, r.CompletionTokens) totalTokens = append(totalTokens, r.TotalTokens) } // 计算平均Token数 if len(promptTokens) > 0 { promptAvg, _ := stats.Mean(convertToFloat64(promptTokens)) completionAvg, _ := stats.Mean(convertToFloat64(completionTokens)) totalAvg, _ := stats.Mean(convertToFloat64(totalTokens)) stats.AvgPromptTokens = int(promptAvg) stats.AvgCompletionTokens = int(completionAvg) stats.AvgTotalTokens = int(totalAvg) } // 计算Token生成速率 (tokens/s) var tokenRates []float64 for _, r := range successResults { if r.ResponseTime.Seconds() > 0 { tokenRate := float64(r.CompletionTokens) / r.ResponseTime.Seconds() tokenRates = append(tokenRates, tokenRate) } } if len(tokenRates) > 0 { tokenRateAvg, _ := stats.Mean(tokenRates) stats.AvgTokenRate = tokenRateAvg } return &stats } func (b *BenchmarkResults) Mean(data []float64) (float64, error) { return stats.Mean(data) } func (b *BenchmarkResults) Percentile(data []float64, p float64) (float64, error) { return stats.Percentile(data, p) }