400 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package benchmark
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"github.com/acmestudio/llm-api-benchmark-tool/config"
"github.com/acmestudio/llm-api-benchmark-tool/logger"
"github.com/acmestudio/llm-api-benchmark-tool/prompt"
)
// JSON wire types for the chat-completions API exercised by the benchmark.
// Field order and json tags define the encoded layout of requests.
type (
	// Message is a single chat message (a role plus its text).
	Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// Request is the JSON body POSTed to the configured endpoint.
	Request struct {
		Model       string    `json:"model"`
		Messages    []Message `json:"messages"`
		MaxTokens   int       `json:"max_tokens"`
		Temperature float64   `json:"temperature"`
		Stream      bool      `json:"stream"`
	}

	// Choice is one element of a response's "choices" array.
	Choice struct {
		Message      Message `json:"message"`
		Index        int     `json:"index"`
		FinishReason string  `json:"finish_reason"`

		// ReasoningContent is chain-of-thought text reported at the
		// choice level, if the provider sends it there.
		ReasoningContent string `json:"reasoning_content"`
	}

	// Usage is the token accounting object returned by the API.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	}

	// Response is a complete (non-streaming) chat-completion response.
	Response struct {
		ID      string   `json:"id"`
		Object  string   `json:"object"`
		Created int64    `json:"created"`
		Model   string   `json:"model"`
		Choices []Choice `json:"choices"`
		Usage   Usage    `json:"usage"`
	}
)
// Result records the outcome of a single benchmark request.
type Result struct {
	Timestamp        time.Time // when the request started being prepared
	Concurrency      int       // concurrency level of the step that issued it
	PromptType       string    // prompt category (short/long)
	PromptTokenCount int       // prompt tokens counted locally
	PromptTokens     int       // prompt tokens; replaced by the API-reported value when available
	CompletionTokens int       // completion tokens
	TotalTokens      int       // prompt + completion tokens

	ResponseTime   time.Duration // request latency measured from send
	FirstTokenTime time.Duration // time to first token (TTFT)

	Response         string // generated answer text
	ReasoningContent string // chain-of-thought text from reasoning models, if any
	Error            error  // non-nil when the request failed
}
// BenchmarkResults aggregates the output of a whole benchmark run.
//
// Run fills Results, StartTime, EndTime, Config and ConcurrencyData. The
// summary statistics below are not set by Run.
// NOTE(review): presumably computed by reporting code elsewhere — confirm.
type BenchmarkResults struct {
	Results         []Result         // every request result, in step order
	StartTime       time.Time        // when the run began
	EndTime         time.Time        // when the run finished
	Config          *config.Config   // configuration the run used
	ConcurrencyData map[int][]Result // results grouped by concurrency level

	TotalRequests   int // number of requests issued
	SuccessRequests int // number of successful requests
	FailedRequests  int // number of failed requests

	AvgResponseTime time.Duration // mean response time
	P50ResponseTime time.Duration // 50th-percentile response time
	P90ResponseTime time.Duration // 90th-percentile response time
	P99ResponseTime time.Duration // 99th-percentile response time

	QPS                 float64 // requests per second
	AvgPromptTokens     int     // mean prompt tokens
	AvgCompletionTokens int     // mean completion tokens
	AvgTotalTokens      int     // mean total tokens
	AvgTokenRate        float64 // mean token generation rate (presumably tokens/s — confirm)
}
// Benchmark executes load against the configured chat-completions endpoint.
// Create it with NewBenchmark.
type Benchmark struct {
	config     *config.Config    // benchmark settings
	generator  *prompt.Generator // selects, generates and token-counts prompts
	httpClient *http.Client      // shared client whose Timeout comes from config
	// rand drives request pacing. A *rand.Rand is not safe for concurrent
	// use by multiple goroutines.
	rand *rand.Rand
}
// NewBenchmark 创建新的基准测试执行器
func NewBenchmark(cfg *config.Config) (*Benchmark, error) {
gen, _ := prompt.NewGenerator(cfg)
return &Benchmark{
config: cfg,
generator: gen,
httpClient: &http.Client{Timeout: time.Duration(cfg.Timeout) * time.Second},
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}, nil
}
// Run 执行基准测试
func (b *Benchmark) Run() (*BenchmarkResults, error) {
results := &BenchmarkResults{
StartTime: time.Now(),
Config: b.config,
ConcurrencyData: make(map[int][]Result),
}
for _, concurrency := range b.config.Concurrency.Steps {
logger.Info("开始执行并发步骤: %d 并发用户", concurrency)
stepResults, err := b.runConcurrencyStep(concurrency)
if err != nil {
return nil, fmt.Errorf("执行并发步骤 %d 失败: %w", concurrency, err)
}
for _, result := range stepResults {
results.Results = append(results.Results, result)
}
results.ConcurrencyData[concurrency] = stepResults
}
results.EndTime = time.Now()
return results, nil
}
// executeRequest 执行单个请求
func (b *Benchmark) executeRequest(concurrencyStep int) Result {
result := Result{
Timestamp: time.Now(),
Concurrency: concurrencyStep,
}
// 选择提示词类型
promptType := b.generator.SelectPromptType()
result.PromptType = promptType
// 生成并统计提示词
prompt, err := b.generator.GeneratePrompt(promptType)
if err != nil {
result.Error = fmt.Errorf("生成提示词失败: %w", err)
return result
}
promptTokens := b.generator.CountTokens(prompt)
result.PromptTokenCount = promptTokens
result.PromptTokens = promptTokens
// 构造请求对象
reqObj := Request{
Model: b.config.API.Model,
Messages: []Message{{Role: "user", Content: prompt}},
MaxTokens: b.config.API.MaxTokens,
Temperature: b.config.API.Temperature,
Stream: true,
}
reqBody, err := json.Marshal(reqObj)
if err != nil {
result.Error = fmt.Errorf("序列化请求失败: %w", err)
return result
}
requestID := fmt.Sprintf("request-%s", result.Timestamp.Format("20060102-150405.000"))
logger.LogToFile(fmt.Sprintf("===== 请求 ID: %s =====\n并发: %d, 类型: %s, 端点: %s\n请求体:\n%s\n", requestID, concurrencyStep, promptType, b.config.API.Endpoint, string(reqBody)))
// 构造 HTTP 请求并绑定超时上下文
httpReq, err := http.NewRequest("POST", b.config.API.Endpoint, bytes.NewReader(reqBody))
if err != nil {
result.Error = fmt.Errorf("创建HTTP请求失败: %w", err)
return result
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+b.config.API.APIKey)
httpReq.Header.Set("Accept", "text/event-stream")
ctxReq, cancel := context.WithTimeout(httpReq.Context(), time.Duration(b.config.Timeout)*time.Second)
defer cancel()
httpReq = httpReq.WithContext(ctxReq)
startTime := time.Now()
resp, err := b.httpClient.Do(httpReq)
if err != nil {
result.Error = fmt.Errorf("发送请求失败: %w", err)
logger.Error("请求错误: %v", err)
return result
}
defer resp.Body.Close()
// 检查状态码
if resp.StatusCode != http.StatusOK {
result.Error = fmt.Errorf("HTTP状态码错误: %d", resp.StatusCode)
logger.Warn("HTTP状态码非200: %d", resp.StatusCode)
return result
}
// 记录响应时间
responseTime := time.Since(startTime)
result.ResponseTime = responseTime
// 非流式记录前4K
if !reqObj.Stream {
head, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
logger.LogToFile(fmt.Sprintf("===== 响应头部 ID: %s =====\n状态码: %d\n前4K响应体:\n%s\n", requestID, resp.StatusCode, string(head)))
resp.Body = io.NopCloser(io.MultiReader(bytes.NewReader(head), resp.Body))
}
if reqObj.Stream {
return b.handleStreamResponse(resp, result, startTime)
}
return b.handleNormalResponse(resp, result)
}
// runConcurrencyStep 执行单个并发步骤
func (b *Benchmark) runConcurrencyStep(concurrencyStep int) ([]Result, error) {
var results []Result
var resultsMutex sync.Mutex
// 创建上下文,用于控制测试时间
duration := b.config.Concurrency.DurationPerStep
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(duration)*time.Second)
defer cancel()
// 启动 worker goroutine
var wg sync.WaitGroup
for i := 0; i < concurrencyStep; i++ {
wg.Add(1)
go func(workerID int) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
res := b.executeRequest(concurrencyStep)
resultsMutex.Lock()
results = append(results, res)
resultsMutex.Unlock()
// 按泊松分布间隔
interval := b.rand.ExpFloat64() / b.config.PoissonLambda
time.Sleep(time.Duration(interval * float64(time.Second)))
}
}
}(i)
}
wg.Wait()
return results, nil
}
// handleStreamResponse 处理流式响应
func (b *Benchmark) handleStreamResponse(resp *http.Response, result Result, startTime time.Time) Result {
var builder strings.Builder
// 聚合思维链
var reasoningBuilder strings.Builder
tokenCount := 0
firstTokenReceived := false
var firstTokenTime time.Duration
reader := bufio.NewReaderSize(resp.Body, 4096)
responseHeaderTime := time.Now()
logger.Debug("接收到响应头,网络请求耗时: %v", responseHeaderTime.Sub(startTime))
chunkIdx := 0
for {
line, err := reader.ReadString('\n')
if err != nil {
if err != io.EOF {
logger.Warn("读取流式响应出错: %v", err)
}
break
}
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
data = strings.TrimSuffix(data, "\n")
if data == "[DONE]" {
logger.Debug("收到流式响应结束标记")
break
}
var chunk map[string]interface{}
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
logger.Warn("JSON解析失败: %v", err)
continue
}
choices, ok := chunk["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
choiceIface := choices[0]
choiceMap, ok := choiceIface.(map[string]interface{})
if !ok {
continue
}
deltaIface, ok := choiceMap["delta"]
deltaMap, ok := deltaIface.(map[string]interface{})
if !ok {
continue
}
contentVal, ok := deltaMap["content"]
content, ok := contentVal.(string)
if !ok || content == "" {
continue
}
now := time.Now()
builder.WriteString(content)
tokenCount += len(strings.Split(content, " "))
chunkIdx++
logger.LogToFile(fmt.Sprintf("[Chunk %d] %s 内容: %s", chunkIdx, now.Format(time.RFC3339Nano), content))
// 初次令牌检测:推理模型优先考虑思维链内容
rcVal, rcOk := deltaMap["reasoning_content"]
rcStr, _ := rcVal.(string)
if !firstTokenReceived {
if rcOk && rcStr != "" {
firstTokenTime = now.Sub(responseHeaderTime)
firstTokenReceived = true
logger.Debug("收到首个思维链令牌TTFT: %v", firstTokenTime)
} else if content != "" {
firstTokenTime = now.Sub(responseHeaderTime)
firstTokenReceived = true
logger.Debug("收到首个内容令牌TTFT: %v", firstTokenTime)
}
}
// 聚合思维链字段
if rcOk && rcStr != "" {
reasoningBuilder.WriteString(rcStr)
}
}
responseContent := builder.String()
result.Response = responseContent
// CompletionTokens 和总Token
result.CompletionTokens = tokenCount
result.TotalTokens = result.PromptTokens + result.CompletionTokens
// 设置思维链内容
result.ReasoningContent = reasoningBuilder.String()
// 设置响应时间和首个令牌时间 (TTFT)
result.ResponseTime = time.Since(startTime)
if firstTokenReceived {
result.FirstTokenTime = firstTokenTime
} else {
// 未检测到首个令牌,使用响应时间作为 TTFT
result.FirstTokenTime = result.ResponseTime
}
return result
}
// handleNormalResponse 处理普通响应
func (b *Benchmark) handleNormalResponse(resp *http.Response, result Result) Result {
var response Response
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
result.Error = fmt.Errorf("解析响应失败: %w", err)
logger.Error("解析响应失败: %v", err)
return result
}
// 先获取响应内容
if len(response.Choices) > 0 {
choice := response.Choices[0]
result.Response = choice.Message.Content
// 推理模型思维链
result.ReasoningContent = choice.ReasoningContent
}
// 根据 Usage 判断是否兼容推理模型
if response.Usage.TotalTokens > 0 {
result.PromptTokens = response.Usage.PromptTokens
result.CompletionTokens = response.Usage.CompletionTokens
result.TotalTokens = response.Usage.TotalTokens
} else {
// 推理模型未返回 Usage使用本地统计
result.PromptTokens = result.PromptTokenCount
result.CompletionTokens = b.generator.CountTokens(result.Response)
result.TotalTokens = result.PromptTokens + result.CompletionTokens
}
result.FirstTokenTime = result.ResponseTime
logger.Debug("非流式响应使用响应时间作为TTFT: %v", result.FirstTokenTime)
processedLog := fmt.Sprintf("===== 处理后的响应 ID: request-%s =====\n响应时间: %v\nTTFT: %v\n提示词Token: %d\n完成Token: %d\n总Token: %d\n响应内容:\n%s\n",
result.Timestamp.Format("20060102-150405.000"), result.ResponseTime, result.FirstTokenTime,
result.PromptTokens, result.CompletionTokens, result.TotalTokens, result.Response)
logger.LogToFile(processedLog)
logger.Debug("请求完成 - 类型: %s, 提示词Token: %d, 响应Token: %d, 响应时间: %v",
result.PromptType, result.PromptTokens, result.CompletionTokens, result.ResponseTime)
return result
}
// ... (其他代码保持不变)