2025-04-21 18:19:09 +08:00

674 lines
18 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package benchmark
import (
"bytes"
"context"
"encoding/json"
"fmt"
"math"
"math/rand"
"strings"
"sync"
"time"
"github.com/montanaflynn/stats"
"github.com/valyala/fasthttp"
"github.com/acmestudio/llm-api-benchmark-tool/config"
"github.com/acmestudio/llm-api-benchmark-tool/logger"
"github.com/acmestudio/llm-api-benchmark-tool/prompt"
)
// Benchmark is the benchmark executor. It holds the run configuration, the
// prompt generator, a shared fasthttp client, and a PRNG used for request
// pacing.
type Benchmark struct {
	config    *config.Config
	generator *prompt.Generator
	client    *fasthttp.Client
	rand      *rand.Rand
}

// Response is a non-streaming OpenAI-style chat completion API response.
type Response struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}

// StreamResponse is a single chunk of a streaming API response.
type StreamResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// StreamChoice is one choice within a streaming response chunk.
type StreamChoice struct {
	Delta        StreamDelta `json:"delta"`
	Index        int         `json:"index"`
	FinishReason string      `json:"finish_reason"`
}

// StreamDelta is the incremental content carried by a streaming chunk.
type StreamDelta struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Choice is one choice within a non-streaming API response.
type Choice struct {
	Message      Message `json:"message"`
	Index        int     `json:"index"`
	FinishReason string  `json:"finish_reason"`
}

// Message is a single chat message (role plus content).
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Usage reports the token accounting returned by the API.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Request is the chat completion API request payload.
type Request struct {
	Model       string    `json:"model"`
	Messages    []Message `json:"messages"`
	MaxTokens   int       `json:"max_tokens"`
	Temperature float64   `json:"temperature"`
	Stream      bool      `json:"stream"`
}

// Result captures the outcome of a single API request.
type Result struct {
	Timestamp        time.Time     // when the request was issued
	Concurrency      int           // concurrency level in effect
	PromptType       string        // prompt type ("short"/"long")
	PromptTokenCount int           // locally estimated prompt token count
	PromptTokens     int           // prompt token count (API-reported or estimated)
	CompletionTokens int           // completion token count
	TotalTokens      int           // total token count
	ResponseTime     time.Duration // wall-clock response time
	Response         string        // response content
	Error            error         // error; nil on success
}

// BenchmarkResults aggregates the outcome of a whole benchmark run.
type BenchmarkResults struct {
	Results             []Result         // every individual request result
	StartTime           time.Time        // run start time
	EndTime             time.Time        // run end time
	Config              *config.Config   // configuration used for the run
	ConcurrencyData     map[int][]Result // results grouped by concurrency step
	TotalRequests       int              // total number of requests
	SuccessRequests     int              // number of successful requests
	FailedRequests      int              // number of failed requests
	AvgResponseTime     time.Duration    // mean response time
	P50ResponseTime     time.Duration    // 50th percentile response time
	P90ResponseTime     time.Duration    // 90th percentile response time
	P99ResponseTime     time.Duration    // 99th percentile response time
	QPS                 float64          // requests per second
	AvgPromptTokens     int              // mean prompt tokens
	AvgCompletionTokens int              // mean completion tokens
	AvgTotalTokens      int              // mean total tokens
	AvgTokenRate        float64          // mean token generation rate (tokens/s)
}
// NewBenchmark builds a benchmark executor from the given configuration:
// a prompt generator, a shared fasthttp client whose read/write timeouts
// come from cfg.Timeout (seconds), and a time-seeded PRNG used for
// request pacing. It returns an error if the prompt generator cannot be
// initialized.
func NewBenchmark(cfg *config.Config) (*Benchmark, error) {
	// Prompt generator first: it is the only step that can fail.
	logger.Debug("初始化提示词生成器...")
	gen, err := prompt.NewGenerator(cfg)
	if err != nil {
		return nil, fmt.Errorf("初始化提示词生成器失败: %w", err)
	}
	logger.Debug("提示词生成器初始化成功")

	// One shared HTTP client for the whole run; connections are reused.
	logger.Debug("初始化HTTP客户端...")
	timeout := time.Duration(cfg.Timeout) * time.Second
	httpClient := &fasthttp.Client{
		MaxConnsPerHost: 10000,
		ReadTimeout:     timeout,
		WriteTimeout:    timeout,
	}
	logger.Debug("HTTP客户端初始化成功超时设置: %d秒", cfg.Timeout)

	b := &Benchmark{
		config:    cfg,
		generator: gen,
		client:    httpClient,
		rand:      rand.New(rand.NewSource(time.Now().UnixNano())),
	}
	return b, nil
}
// Run executes every configured concurrency step in order, accumulates all
// per-request results (both flat and grouped by step), and stamps the
// overall start/end times. It stops and returns an error if any step fails.
func (b *Benchmark) Run() (*BenchmarkResults, error) {
	out := &BenchmarkResults{
		StartTime:       time.Now(),
		Config:          b.config,
		ConcurrencyData: make(map[int][]Result),
	}

	for _, users := range b.config.Concurrency.Steps {
		logger.Info("开始执行并发步骤: %d 并发用户", users)
		logger.Info("预计持续时间: %d秒", b.config.Concurrency.DurationPerStep)

		stepResults, err := b.runConcurrencyStep(users)
		if err != nil {
			return nil, fmt.Errorf("执行并发步骤 %d 失败: %w", users, err)
		}

		// Tally successes while folding the step results into the flat list.
		succeeded := 0
		for _, r := range stepResults {
			if r.Error == nil {
				succeeded++
			}
			out.Results = append(out.Results, r)
		}
		out.ConcurrencyData[users] = stepResults

		rate := 0.0
		if n := len(stepResults); n > 0 {
			rate = float64(succeeded) / float64(n) * 100
		}
		logger.Info("完成并发步骤: %d 并发用户,共 %d 个请求,成功率: %.2f%%",
			users, len(stepResults), rate)
	}

	out.EndTime = time.Now()
	logger.Info("所有并发步骤测试完成")
	return out, nil
}
// runConcurrencyStep drives a single concurrency level for the configured
// per-step duration and returns one Result per request issued. It spawns
// concurrencyStep worker goroutines that loop "request, then Poisson-paced
// sleep" until the step context expires, plus one reporter goroutine that
// logs live statistics every 5 seconds.
func (b *Benchmark) runConcurrencyStep(concurrencyStep int) ([]Result, error) {
	var results []Result
	var resultsMutex sync.Mutex
	// Context bounds the whole step to DurationPerStep seconds; both the
	// workers and the reporter goroutine exit on ctx.Done().
	ctx, cancel := context.WithTimeout(
		context.Background(),
		time.Duration(b.config.Concurrency.DurationPerStep)*time.Second,
	)
	defer cancel()
	// Semaphore channel. NOTE(review): with exactly concurrencyStep workers
	// and a buffer of concurrencyStep, the send below can never block, so
	// this channel does not actually limit anything — confirm whether it
	// was meant to admit more workers than slots.
	workChan := make(chan struct{}, concurrencyStep)
	// Counters shared between the workers and the reporter goroutine; all
	// reads and writes go through statsMutex.
	var statsMutex sync.Mutex
	totalRequests := 0
	successRequests := 0
	failedRequests := 0
	totalResponseTime := time.Duration(0)
	// Periodic progress reporter.
	statsTicker := time.NewTicker(5 * time.Second)
	defer statsTicker.Stop()
	go func() {
		startTime := time.Now()
		for {
			select {
			case <-ctx.Done():
				return
			case <-statsTicker.C:
				// Snapshot the counters under the lock; log outside it.
				statsMutex.Lock()
				currentTotal := totalRequests
				currentSuccess := successRequests
				currentFailed := failedRequests
				var avgResponseTime time.Duration
				if successRequests > 0 {
					avgResponseTime = totalResponseTime / time.Duration(successRequests)
				}
				statsMutex.Unlock()
				elapsedTime := time.Since(startTime).Seconds()
				qps := float64(currentTotal) / elapsedTime
				logger.Info("实时统计 - 请求: %d (成功: %d, 失败: %d), QPS: %.2f, 平均响应时间: %v",
					currentTotal, currentSuccess, currentFailed, qps, avgResponseTime)
			}
		}
	}()
	// Worker goroutines.
	var wg sync.WaitGroup
	logger.Debug("启动 %d 个工作协程...", concurrencyStep)
	for i := 0; i < concurrencyStep; i++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			logger.Debug("工作协程 #%d 已启动", workerID)
			for {
				select {
				case <-ctx.Done():
					// Step duration elapsed.
					logger.Debug("工作协程 #%d 收到结束信号", workerID)
					return
				case workChan <- struct{}{}:
					// Count the request before executing it so requestID is
					// unique and monotonically increasing across workers.
					statsMutex.Lock()
					totalRequests++
					requestID := totalRequests
					statsMutex.Unlock()
					logger.Debug("工作协程 #%d 开始执行请求 #%d", workerID, requestID)
					// Issue the request (blocking; not interrupted by ctx).
					result := b.executeRequest(concurrencyStep)
					// Fold the outcome into the shared counters.
					statsMutex.Lock()
					if result.Error == nil {
						successRequests++
						totalResponseTime += result.ResponseTime
						logger.Debug("请求 #%d 成功,响应时间: %v", requestID, result.ResponseTime)
					} else {
						failedRequests++
						logger.Debug("请求 #%d 失败: %v", requestID, result.Error)
					}
					statsMutex.Unlock()
					// Record the result (appended in completion order).
					resultsMutex.Lock()
					results = append(results, result)
					resultsMutex.Unlock()
					// Poisson-distributed pause before the next request.
					// NOTE(review): the sleep (and the request itself) happens
					// while the semaphore slot is held.
					waitTime := b.getPoissonWaitTime()
					logger.Debug("请求 #%d 完成,等待 %v 后发送下一个请求", requestID, waitTime)
					time.Sleep(waitTime)
					<-workChan
				}
			}
		}(i)
	}
	// Wait for every worker to observe cancellation and exit; a worker that
	// is mid-request or mid-sleep finishes that iteration first.
	logger.Debug("等待所有工作协程完成...")
	wg.Wait()
	logger.Debug("所有工作协程已完成")
	return results, nil
}
// executeRequest issues one chat completion request (streaming enabled) and
// returns a populated Result. On any failure the partially filled Result is
// returned with Result.Error set; callers treat a nil Error as success.
func (b *Benchmark) executeRequest(concurrencyStep int) Result {
	result := Result{
		Timestamp:   time.Now(),
		Concurrency: concurrencyStep,
	}
	// Pick short/long prompt type according to the configured distribution.
	promptType := b.generator.SelectPromptType()
	result.PromptType = promptType
	// Generate the prompt text. NOTE: this local shadows the imported
	// "prompt" package for the rest of the function; harmless, since the
	// package is not referenced below this point.
	prompt, err := b.generator.GeneratePrompt(promptType)
	if err != nil {
		result.Error = fmt.Errorf("生成提示词失败: %w", err)
		return result
	}
	// Build the OpenAI-style chat completion payload.
	req := Request{
		Model: b.config.API.Model,
		Messages: []Message{
			{
				Role:    "user",
				Content: prompt,
			},
		},
		MaxTokens:   2048, // hard-coded completion cap; consider making configurable
		Temperature: 0.7,  // hard-coded sampling temperature
		Stream:      true, // request a streaming (SSE) response
	}
	// Serialize the request body.
	reqBody, err := json.Marshal(req)
	if err != nil {
		result.Error = fmt.Errorf("序列化请求失败: %w", err)
		return result
	}
	// fasthttp request/response objects are pooled; the deferred releases
	// run only after handleStreamResponse/handleNormalResponse returns, so
	// the response body remains valid while it is parsed.
	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	defer fasthttp.ReleaseResponse(httpResp)
	httpReq.SetRequestURI(b.config.API.Endpoint)
	httpReq.Header.SetMethod("POST")
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+b.config.API.APIKey)
	httpReq.Header.Set("Accept", "text/event-stream") // ask for an SSE stream
	httpReq.SetBody(reqBody)
	// Time the round trip.
	startTime := time.Now()
	logger.Debug("发送请求到 %s提示词类型: %s提示词长度: %d字符",
		b.config.API.Endpoint, promptType, len(prompt))
	// NOTE(review): fasthttp.Client.Do buffers the complete response before
	// returning, so even with Stream=true this measures total generation
	// time, not time-to-first-token — confirm this is the intended metric.
	err = b.client.Do(httpReq, httpResp)
	// Record elapsed time (the streaming handler recomputes it from the
	// same startTime, overwriting this value).
	responseTime := time.Since(startTime)
	result.ResponseTime = responseTime
	// Transport-level failure.
	if err != nil {
		result.Error = fmt.Errorf("发送请求失败: %w", err)
		logger.Error("请求失败: %v", err)
		return result
	}
	// Non-200 responses are treated as failures.
	statusCode := httpResp.StatusCode()
	if statusCode != 200 {
		result.Error = fmt.Errorf("API返回错误状态码: %d, 响应: %s", statusCode, httpResp.Body())
		logger.Error("API返回错误状态码: %d, 响应: %s", statusCode, httpResp.Body())
		return result
	}
	// req.Stream is always true above, so the streaming path is always
	// taken; the non-streaming branch is kept for completeness.
	if req.Stream {
		return b.handleStreamResponse(httpResp, result, startTime)
	}
	return b.handleNormalResponse(httpResp, result)
}
// handleStreamResponse parses a fully buffered SSE chat-completion stream
// (blank-line-separated "data: {...}" events terminated by "data: [DONE]"),
// concatenates the delta contents into Result.Response, fills in the token
// estimates, and recomputes Result.ResponseTime from startTime.
//
// Changes from the previous revision:
//   - The parse loop ran in a goroutine guarded by a 30-second time.After.
//     The body is already entirely in memory by the time this function is
//     called, so the parse cannot block; the timeout machinery added no
//     protection and, on its (unreachable in practice) firing path, would
//     abandon a still-running goroutine. The loop now runs inline.
//   - Content accumulation uses strings.Builder instead of repeated string
//     concatenation (linear rather than quadratic in response size).
func (b *Benchmark) handleStreamResponse(httpResp *fasthttp.Response, result Result, startTime time.Time) Result {
	body := httpResp.Body()
	// Log the response header and the first 100 bytes of the body for
	// debugging; an empty body is an error.
	logger.Debug("响应头: %s", httpResp.Header.String())
	if len(body) == 0 {
		logger.Warn("收到空响应体")
		result.Error = fmt.Errorf("收到空响应体")
		return result
	}
	previewLen := 100
	if len(body) < previewLen {
		previewLen = len(body)
	}
	logger.Debug("响应体预览: %s", string(body[:previewLen]))

	// SSE events are separated by blank lines:
	// data: {...}\n\ndata: {...}\n\ndata: [DONE]\n\n
	lines := bytes.Split(body, []byte("\n\n"))
	if len(lines) == 0 {
		logger.Warn("无法解析流式响应")
		result.Error = fmt.Errorf("无法解析流式响应")
		return result
	}
	logger.Debug("解析到 %d 个数据块", len(lines))

	var content strings.Builder
	tokenCount := 0
	for _, line := range lines {
		if len(line) == 0 {
			continue
		}
		// Only "data: ..." lines carry payload.
		if !bytes.HasPrefix(line, []byte("data: ")) {
			continue
		}
		data := line[len("data: "):]
		// "[DONE]" marks the end of the stream.
		if bytes.Equal(data, []byte("[DONE]")) {
			logger.Debug("收到流式响应结束标记")
			break
		}
		// Decode loosely as a generic map so malformed chunks are skipped
		// rather than failing the whole request.
		var chunk map[string]interface{}
		if err := json.Unmarshal(data, &chunk); err != nil {
			logger.Warn("解析JSON数据块失败: %v, 数据: %s", err, string(data))
			continue
		}
		choices, ok := chunk["choices"].([]interface{})
		if !ok || len(choices) == 0 {
			logger.Warn("无法从数据块中提取choices")
			continue
		}
		choice, ok := choices[0].(map[string]interface{})
		if !ok {
			logger.Warn("无法从choices中提取第一个元素")
			continue
		}
		delta, ok := choice["delta"].(map[string]interface{})
		if !ok {
			logger.Warn("无法从choice中提取delta")
			continue
		}
		text, ok := delta["content"].(string)
		if !ok {
			// Deltas without content (e.g. the initial role-only delta)
			// are expected; skip them.
			continue
		}
		content.WriteString(text)
		// Rough whitespace-split token estimate; streaming responses carry
		// no usage block, so an exact count is unavailable.
		tokenCount += len(strings.Split(text, " "))
	}

	result.ResponseTime = time.Since(startTime)
	result.Response = content.String()
	// FIXME: these "estimates" halve the length of the prompt *type label*
	// ("short"/"long"), not the prompt text — the actual prompt is not
	// passed into this function. Kept as-is to preserve the reported
	// numbers; fixing it requires threading the prompt text through.
	result.PromptTokenCount = len(result.PromptType) / 2
	result.PromptTokens = len(result.PromptType) / 2
	result.CompletionTokens = tokenCount
	result.TotalTokens = result.PromptTokens + result.CompletionTokens
	logger.Debug("流式响应处理完成,响应时间: %v, 提示词Token: %d, 完成Token: %d, 总Token: %d",
		result.ResponseTime, result.PromptTokens, result.CompletionTokens, result.TotalTokens)
	return result
}
// handleNormalResponse decodes a non-streaming chat completion body and
// copies the API-reported usage token counts into result. On a decode
// failure, result.Error is set and the result returned as-is.
func (b *Benchmark) handleNormalResponse(httpResp *fasthttp.Response, result Result) Result {
	var parsed Response
	err := json.Unmarshal(httpResp.Body(), &parsed)
	if err != nil {
		result.Error = fmt.Errorf("解析响应失败: %w", err)
		logger.Error("解析响应失败: %v", err)
		return result
	}

	// Token accounting comes straight from the API's usage block.
	usage := parsed.Usage
	result.PromptTokens = usage.PromptTokens
	result.CompletionTokens = usage.CompletionTokens
	result.TotalTokens = usage.TotalTokens

	logger.Debug("请求完成 - 类型: %s, 提示词Token: %d, 响应Token: %d, 响应时间: %v",
		result.PromptType, result.PromptTokens, result.CompletionTokens, result.ResponseTime)
	return result
}
// getPoissonWaitTime draws a Poisson-distributed pacing delay: it samples a
// count k via Knuth's multiplication method with rate b.config.PoissonLambda
// and maps it to k*100ms, capped at one second.
//
// Fix: the previous revision drew randomness from b.rand, a *rand.Rand that
// every worker goroutine shares without synchronization — *rand.Rand is not
// safe for concurrent use, so this was a data race. The package-level
// rand.Float64 uses an internally locked source and is safe to call from
// multiple goroutines; the distribution of the output is unchanged.
//
// NOTE(review): the loop returns Knuth's counter k rather than the textbook
// k-1, so for lambda > 0 the sample is shifted up by one draw (minimum wait
// 100ms). Left as-is to preserve existing pacing — confirm intent.
func (b *Benchmark) getPoissonWaitTime() time.Duration {
	lambda := b.config.PoissonLambda
	threshold := math.Exp(-lambda)
	k := 0
	p := 1.0
	for p > threshold {
		k++
		p *= rand.Float64()
	}
	// Convert the count to a wait time in 100ms units.
	waitTime := time.Duration(k*100) * time.Millisecond
	// Clamp to at most one second so pacing never stalls a worker for long.
	if waitTime > time.Second {
		waitTime = time.Second
	}
	return waitTime
}
// convertToFloat64 returns a new []float64 holding each element of ints
// converted to float64, in order.
func convertToFloat64(ints []int) []float64 {
	out := make([]float64, 0, len(ints))
	for _, n := range ints {
		out = append(out, float64(n))
	}
	return out
}
// calculateStats aggregates per-request results into a BenchmarkResults
// summary: request counts, latency mean and 50/90/99 percentiles, QPS, and
// average token counts/rates. Failed requests contribute only to the
// request counts; every other metric is computed over successes.
//
// Fix: the previous revision named its local accumulator "stats", shadowing
// the imported github.com/montanaflynn/stats package; the package functions
// were then only reachable through the Mean/Percentile wrapper methods on
// BenchmarkResults (which delegate straight back to the package). The local
// is renamed so the package can be called directly. Numeric behavior is
// unchanged.
func (b *Benchmark) calculateStats(results []Result) *BenchmarkResults {
	var summary BenchmarkResults
	// Keep only successful requests for latency/token statistics.
	successResults := make([]Result, 0, len(results))
	for _, r := range results {
		if r.Error == nil {
			successResults = append(successResults, r)
		}
	}
	// No successes: return the zero-valued summary. (The request counts are
	// also left at zero on this path, matching the previous behavior.)
	if len(successResults) == 0 {
		return &summary
	}
	summary.TotalRequests = len(results)
	summary.SuccessRequests = len(successResults)
	summary.FailedRequests = summary.TotalRequests - summary.SuccessRequests
	// Latency statistics, in milliseconds.
	responseTimes := make([]float64, 0, len(successResults))
	for _, r := range successResults {
		responseTimes = append(responseTimes, float64(r.ResponseTime.Milliseconds()))
	}
	avg, _ := stats.Mean(responseTimes)
	summary.AvgResponseTime = time.Duration(avg) * time.Millisecond
	p50, _ := stats.Percentile(responseTimes, 50)
	p90, _ := stats.Percentile(responseTimes, 90)
	p99, _ := stats.Percentile(responseTimes, 99)
	summary.P50ResponseTime = time.Duration(p50) * time.Millisecond
	summary.P90ResponseTime = time.Duration(p90) * time.Millisecond
	summary.P99ResponseTime = time.Duration(p99) * time.Millisecond
	// QPS over the span between the first and last successful request.
	// NOTE(review): this assumes successResults is ordered by timestamp,
	// which holds only if the caller appended results in completion order —
	// confirm against the collection path.
	totalDuration := successResults[len(successResults)-1].Timestamp.Sub(successResults[0].Timestamp).Seconds()
	if totalDuration > 0 {
		summary.QPS = float64(summary.SuccessRequests) / totalDuration
	}
	// Average token counts across successes.
	promptTokens := make([]int, 0, len(successResults))
	completionTokens := make([]int, 0, len(successResults))
	totalTokens := make([]int, 0, len(successResults))
	for _, r := range successResults {
		promptTokens = append(promptTokens, r.PromptTokens)
		completionTokens = append(completionTokens, r.CompletionTokens)
		totalTokens = append(totalTokens, r.TotalTokens)
	}
	if len(promptTokens) > 0 {
		promptAvg, _ := stats.Mean(convertToFloat64(promptTokens))
		completionAvg, _ := stats.Mean(convertToFloat64(completionTokens))
		totalAvg, _ := stats.Mean(convertToFloat64(totalTokens))
		summary.AvgPromptTokens = int(promptAvg)
		summary.AvgCompletionTokens = int(completionAvg)
		summary.AvgTotalTokens = int(totalAvg)
	}
	// Mean per-request generation rate (completion tokens per second);
	// zero-duration requests are skipped to avoid division by zero.
	var tokenRates []float64
	for _, r := range successResults {
		if r.ResponseTime.Seconds() > 0 {
			tokenRates = append(tokenRates, float64(r.CompletionTokens)/r.ResponseTime.Seconds())
		}
	}
	if len(tokenRates) > 0 {
		tokenRateAvg, _ := stats.Mean(tokenRates)
		summary.AvgTokenRate = tokenRateAvg
	}
	return &summary
}
// Mean returns the arithmetic mean of data; it is a thin delegate to
// stats.Mean from github.com/montanaflynn/stats and propagates that
// package's error (e.g. for empty input) unchanged.
func (b *BenchmarkResults) Mean(data []float64) (float64, error) {
	return stats.Mean(data)
}
// Percentile returns the p-th percentile of data; it is a thin delegate to
// stats.Percentile from github.com/montanaflynn/stats and propagates that
// package's error (e.g. for empty input or p out of range) unchanged.
func (b *BenchmarkResults) Percentile(data []float64, p float64) (float64, error) {
	return stats.Percentile(data, p)
}