400 lines
13 KiB
Go
400 lines
13 KiB
Go
package benchmark
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"math/rand"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/acmestudio/llm-api-benchmark-tool/config"
|
||
"github.com/acmestudio/llm-api-benchmark-tool/logger"
|
||
"github.com/acmestudio/llm-api-benchmark-tool/prompt"
|
||
)
|
||
|
||
// Message is a single chat message, serialized with the JSON keys "role"
// and "content".
type Message struct {
	// Role is the author of the message (for example "user").
	Role string `json:"role"`
	// Content is the message text.
	Content string `json:"content"`
}
|
||
|
||
// Request 表示API请求
|
||
type Request struct {
|
||
Model string `json:"model"`
|
||
Messages []Message `json:"messages"`
|
||
MaxTokens int `json:"max_tokens"`
|
||
Temperature float64 `json:"temperature"`
|
||
Stream bool `json:"stream"`
|
||
}
|
||
|
||
// Choice 表示API响应中的选择
|
||
type Choice struct {
|
||
Message Message `json:"message"`
|
||
Index int `json:"index"`
|
||
FinishReason string `json:"finish_reason"`
|
||
// reasoning_content: 思维链内容
|
||
ReasoningContent string `json:"reasoning_content"`
|
||
}
|
||
|
||
// Usage is the token accounting section of an API response.
type Usage struct {
	// PromptTokens is decoded from "prompt_tokens".
	PromptTokens int `json:"prompt_tokens"`
	// CompletionTokens is decoded from "completion_tokens".
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is decoded from "total_tokens".
	TotalTokens int `json:"total_tokens"`
}
|
||
|
||
// Response 表示API响应
|
||
type Response struct {
|
||
ID string `json:"id"`
|
||
Object string `json:"object"`
|
||
Created int64 `json:"created"`
|
||
Model string `json:"model"`
|
||
Choices []Choice `json:"choices"`
|
||
Usage Usage `json:"usage"`
|
||
}
|
||
|
||
// Result is the outcome of a single API request.
type Result struct {
	// Timestamp is the request timestamp.
	Timestamp time.Time
	// Concurrency is the concurrency level the request ran under.
	Concurrency int
	// PromptType is the prompt type (short/long).
	PromptType string
	// PromptTokenCount is the prompt token count.
	PromptTokenCount int
	// PromptTokens is the prompt token count as returned by the API.
	PromptTokens int
	// CompletionTokens is the completion token count.
	CompletionTokens int
	// TotalTokens is the total token count.
	TotalTokens int
	// ResponseTime is the response time.
	ResponseTime time.Duration
	// FirstTokenTime is the time to first token (TTFT).
	FirstTokenTime time.Duration
	// Response is the response content.
	Response string
	// ReasoningContent is the chain-of-thought content of reasoning models.
	ReasoningContent string
	// Error is the error information, if any.
	Error error
}
|
||
|
||
// BenchmarkResults 表示基准测试结果
|
||
type BenchmarkResults struct {
|
||
Results []Result // 所有请求结果
|
||
StartTime time.Time // 开始时间
|
||
EndTime time.Time // 结束时间
|
||
Config *config.Config // 配置信息
|
||
ConcurrencyData map[int][]Result // 按并发步骤分组的结果
|
||
TotalRequests int // 总请求数
|
||
SuccessRequests int // 成功请求数
|
||
FailedRequests int // 失败请求数
|
||
AvgResponseTime time.Duration // 平均响应时间
|
||
P50ResponseTime time.Duration // 响应时间50百分位
|
||
P90ResponseTime time.Duration // 响应时间90百分位
|
||
P99ResponseTime time.Duration // 响应时间99百分位
|
||
QPS float64 // 每秒请求数
|
||
AvgPromptTokens int // 平均提示词token数
|
||
AvgCompletionTokens int // 平均完成token数
|
||
AvgTotalTokens int // 平均总token数
|
||
AvgTokenRate float64 // 平均token生成速率
|
||
}
|
||
|
||
// Benchmark 表示基准测试执行器
|
||
type Benchmark struct {
|
||
config *config.Config
|
||
generator *prompt.Generator
|
||
httpClient *http.Client
|
||
rand *rand.Rand
|
||
}
|
||
|
||
// NewBenchmark 创建新的基准测试执行器
|
||
func NewBenchmark(cfg *config.Config) (*Benchmark, error) {
|
||
gen, _ := prompt.NewGenerator(cfg)
|
||
return &Benchmark{
|
||
config: cfg,
|
||
generator: gen,
|
||
httpClient: &http.Client{Timeout: time.Duration(cfg.Timeout) * time.Second},
|
||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||
}, nil
|
||
}
|
||
|
||
// Run 执行基准测试
|
||
func (b *Benchmark) Run() (*BenchmarkResults, error) {
|
||
results := &BenchmarkResults{
|
||
StartTime: time.Now(),
|
||
Config: b.config,
|
||
ConcurrencyData: make(map[int][]Result),
|
||
}
|
||
for _, concurrency := range b.config.Concurrency.Steps {
|
||
logger.Info("开始执行并发步骤: %d 并发用户", concurrency)
|
||
stepResults, err := b.runConcurrencyStep(concurrency)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("执行并发步骤 %d 失败: %w", concurrency, err)
|
||
}
|
||
for _, result := range stepResults {
|
||
results.Results = append(results.Results, result)
|
||
}
|
||
results.ConcurrencyData[concurrency] = stepResults
|
||
}
|
||
results.EndTime = time.Now()
|
||
return results, nil
|
||
}
|
||
|
||
// executeRequest 执行单个请求
|
||
func (b *Benchmark) executeRequest(concurrencyStep int) Result {
|
||
result := Result{
|
||
Timestamp: time.Now(),
|
||
Concurrency: concurrencyStep,
|
||
}
|
||
|
||
// 选择提示词类型
|
||
promptType := b.generator.SelectPromptType()
|
||
result.PromptType = promptType
|
||
|
||
// 生成并统计提示词
|
||
prompt, err := b.generator.GeneratePrompt(promptType)
|
||
if err != nil {
|
||
result.Error = fmt.Errorf("生成提示词失败: %w", err)
|
||
return result
|
||
}
|
||
promptTokens := b.generator.CountTokens(prompt)
|
||
result.PromptTokenCount = promptTokens
|
||
result.PromptTokens = promptTokens
|
||
|
||
// 构造请求对象
|
||
reqObj := Request{
|
||
Model: b.config.API.Model,
|
||
Messages: []Message{{Role: "user", Content: prompt}},
|
||
MaxTokens: b.config.API.MaxTokens,
|
||
Temperature: b.config.API.Temperature,
|
||
Stream: true,
|
||
}
|
||
|
||
reqBody, err := json.Marshal(reqObj)
|
||
if err != nil {
|
||
result.Error = fmt.Errorf("序列化请求失败: %w", err)
|
||
return result
|
||
}
|
||
|
||
requestID := fmt.Sprintf("request-%s", result.Timestamp.Format("20060102-150405.000"))
|
||
logger.LogToFile(fmt.Sprintf("===== 请求 ID: %s =====\n并发: %d, 类型: %s, 端点: %s\n请求体:\n%s\n", requestID, concurrencyStep, promptType, b.config.API.Endpoint, string(reqBody)))
|
||
|
||
// 构造 HTTP 请求并绑定超时上下文
|
||
httpReq, err := http.NewRequest("POST", b.config.API.Endpoint, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
result.Error = fmt.Errorf("创建HTTP请求失败: %w", err)
|
||
return result
|
||
}
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", "Bearer "+b.config.API.APIKey)
|
||
httpReq.Header.Set("Accept", "text/event-stream")
|
||
|
||
ctxReq, cancel := context.WithTimeout(httpReq.Context(), time.Duration(b.config.Timeout)*time.Second)
|
||
defer cancel()
|
||
httpReq = httpReq.WithContext(ctxReq)
|
||
|
||
startTime := time.Now()
|
||
resp, err := b.httpClient.Do(httpReq)
|
||
if err != nil {
|
||
result.Error = fmt.Errorf("发送请求失败: %w", err)
|
||
logger.Error("请求错误: %v", err)
|
||
return result
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 检查状态码
|
||
if resp.StatusCode != http.StatusOK {
|
||
result.Error = fmt.Errorf("HTTP状态码错误: %d", resp.StatusCode)
|
||
logger.Warn("HTTP状态码非200: %d", resp.StatusCode)
|
||
return result
|
||
}
|
||
|
||
// 记录响应时间
|
||
responseTime := time.Since(startTime)
|
||
result.ResponseTime = responseTime
|
||
|
||
// 非流式记录前4K
|
||
if !reqObj.Stream {
|
||
head, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||
logger.LogToFile(fmt.Sprintf("===== 响应头部 ID: %s =====\n状态码: %d\n前4K响应体:\n%s\n", requestID, resp.StatusCode, string(head)))
|
||
resp.Body = io.NopCloser(io.MultiReader(bytes.NewReader(head), resp.Body))
|
||
}
|
||
|
||
if reqObj.Stream {
|
||
return b.handleStreamResponse(resp, result, startTime)
|
||
}
|
||
return b.handleNormalResponse(resp, result)
|
||
}
|
||
|
||
// runConcurrencyStep 执行单个并发步骤
|
||
func (b *Benchmark) runConcurrencyStep(concurrencyStep int) ([]Result, error) {
|
||
var results []Result
|
||
var resultsMutex sync.Mutex
|
||
|
||
// 创建上下文,用于控制测试时间
|
||
duration := b.config.Concurrency.DurationPerStep
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(duration)*time.Second)
|
||
defer cancel()
|
||
|
||
// 启动 worker goroutine
|
||
var wg sync.WaitGroup
|
||
for i := 0; i < concurrencyStep; i++ {
|
||
wg.Add(1)
|
||
go func(workerID int) {
|
||
defer wg.Done()
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
default:
|
||
res := b.executeRequest(concurrencyStep)
|
||
resultsMutex.Lock()
|
||
results = append(results, res)
|
||
resultsMutex.Unlock()
|
||
// 按泊松分布间隔
|
||
interval := b.rand.ExpFloat64() / b.config.PoissonLambda
|
||
time.Sleep(time.Duration(interval * float64(time.Second)))
|
||
}
|
||
}
|
||
}(i)
|
||
}
|
||
|
||
wg.Wait()
|
||
return results, nil
|
||
}
|
||
|
||
// handleStreamResponse 处理流式响应
|
||
func (b *Benchmark) handleStreamResponse(resp *http.Response, result Result, startTime time.Time) Result {
|
||
var builder strings.Builder
|
||
// 聚合思维链
|
||
var reasoningBuilder strings.Builder
|
||
tokenCount := 0
|
||
firstTokenReceived := false
|
||
var firstTokenTime time.Duration
|
||
|
||
reader := bufio.NewReaderSize(resp.Body, 4096)
|
||
responseHeaderTime := time.Now()
|
||
logger.Debug("接收到响应头,网络请求耗时: %v", responseHeaderTime.Sub(startTime))
|
||
|
||
chunkIdx := 0
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
if err != nil {
|
||
if err != io.EOF {
|
||
logger.Warn("读取流式响应出错: %v", err)
|
||
}
|
||
break
|
||
}
|
||
if !strings.HasPrefix(line, "data: ") {
|
||
continue
|
||
}
|
||
data := strings.TrimPrefix(line, "data: ")
|
||
data = strings.TrimSuffix(data, "\n")
|
||
if data == "[DONE]" {
|
||
logger.Debug("收到流式响应结束标记")
|
||
break
|
||
}
|
||
var chunk map[string]interface{}
|
||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||
logger.Warn("JSON解析失败: %v", err)
|
||
continue
|
||
}
|
||
choices, ok := chunk["choices"].([]interface{})
|
||
if !ok || len(choices) == 0 {
|
||
continue
|
||
}
|
||
choiceIface := choices[0]
|
||
choiceMap, ok := choiceIface.(map[string]interface{})
|
||
if !ok {
|
||
continue
|
||
}
|
||
deltaIface, ok := choiceMap["delta"]
|
||
deltaMap, ok := deltaIface.(map[string]interface{})
|
||
if !ok {
|
||
continue
|
||
}
|
||
contentVal, ok := deltaMap["content"]
|
||
content, ok := contentVal.(string)
|
||
if !ok || content == "" {
|
||
continue
|
||
}
|
||
now := time.Now()
|
||
builder.WriteString(content)
|
||
tokenCount += len(strings.Split(content, " "))
|
||
chunkIdx++
|
||
logger.LogToFile(fmt.Sprintf("[Chunk %d] %s 内容: %s", chunkIdx, now.Format(time.RFC3339Nano), content))
|
||
// 初次令牌检测:推理模型优先考虑思维链内容
|
||
rcVal, rcOk := deltaMap["reasoning_content"]
|
||
rcStr, _ := rcVal.(string)
|
||
if !firstTokenReceived {
|
||
if rcOk && rcStr != "" {
|
||
firstTokenTime = now.Sub(responseHeaderTime)
|
||
firstTokenReceived = true
|
||
logger.Debug("收到首个思维链令牌,TTFT: %v", firstTokenTime)
|
||
} else if content != "" {
|
||
firstTokenTime = now.Sub(responseHeaderTime)
|
||
firstTokenReceived = true
|
||
logger.Debug("收到首个内容令牌,TTFT: %v", firstTokenTime)
|
||
}
|
||
}
|
||
// 聚合思维链字段
|
||
if rcOk && rcStr != "" {
|
||
reasoningBuilder.WriteString(rcStr)
|
||
}
|
||
}
|
||
responseContent := builder.String()
|
||
|
||
result.Response = responseContent
|
||
// CompletionTokens 和总Token
|
||
result.CompletionTokens = tokenCount
|
||
result.TotalTokens = result.PromptTokens + result.CompletionTokens
|
||
// 设置思维链内容
|
||
result.ReasoningContent = reasoningBuilder.String()
|
||
// 设置响应时间和首个令牌时间 (TTFT)
|
||
result.ResponseTime = time.Since(startTime)
|
||
if firstTokenReceived {
|
||
result.FirstTokenTime = firstTokenTime
|
||
} else {
|
||
// 未检测到首个令牌,使用响应时间作为 TTFT
|
||
result.FirstTokenTime = result.ResponseTime
|
||
}
|
||
return result
|
||
}
|
||
|
||
// handleNormalResponse 处理普通响应
|
||
func (b *Benchmark) handleNormalResponse(resp *http.Response, result Result) Result {
|
||
var response Response
|
||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||
result.Error = fmt.Errorf("解析响应失败: %w", err)
|
||
logger.Error("解析响应失败: %v", err)
|
||
return result
|
||
}
|
||
// 先获取响应内容
|
||
if len(response.Choices) > 0 {
|
||
choice := response.Choices[0]
|
||
result.Response = choice.Message.Content
|
||
// 推理模型思维链
|
||
result.ReasoningContent = choice.ReasoningContent
|
||
}
|
||
// 根据 Usage 判断是否兼容推理模型
|
||
if response.Usage.TotalTokens > 0 {
|
||
result.PromptTokens = response.Usage.PromptTokens
|
||
result.CompletionTokens = response.Usage.CompletionTokens
|
||
result.TotalTokens = response.Usage.TotalTokens
|
||
} else {
|
||
// 推理模型未返回 Usage,使用本地统计
|
||
result.PromptTokens = result.PromptTokenCount
|
||
result.CompletionTokens = b.generator.CountTokens(result.Response)
|
||
result.TotalTokens = result.PromptTokens + result.CompletionTokens
|
||
}
|
||
result.FirstTokenTime = result.ResponseTime
|
||
logger.Debug("非流式响应,使用响应时间作为TTFT: %v", result.FirstTokenTime)
|
||
processedLog := fmt.Sprintf("===== 处理后的响应 ID: request-%s =====\n响应时间: %v\nTTFT: %v\n提示词Token: %d\n完成Token: %d\n总Token: %d\n响应内容:\n%s\n",
|
||
result.Timestamp.Format("20060102-150405.000"), result.ResponseTime, result.FirstTokenTime,
|
||
result.PromptTokens, result.CompletionTokens, result.TotalTokens, result.Response)
|
||
logger.LogToFile(processedLog)
|
||
logger.Debug("请求完成 - 类型: %s, 提示词Token: %d, 响应Token: %d, 响应时间: %v",
|
||
result.PromptType, result.PromptTokens, result.CompletionTokens, result.ResponseTime)
|
||
return result
|
||
}
|
||
|
||
// ... (其他代码保持不变)
|