2025-04-21 18:19:09 +08:00

328 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package prompt
import (
"fmt"
"math/rand"
"strings"
"time"
"github.com/acmestudio/llm-api-benchmark-tool/config"
"github.com/acmestudio/llm-api-benchmark-tool/logger"
"github.com/pkoukk/tiktoken-go"
)
// 填充词,用于调整提示词长度
var fillerWords = []string{
"详细", "简要", "请", "麻烦", "帮忙", "认真", "仔细", "全面", "系统性地",
"清晰地", "准确地", "深入地", "专业地", "客观地", "科学地", "有条理地",
"具体地", "简洁地", "完整地", "精确地", "有逻辑地", "有说服力地",
}
// 长文本填充段落,用于扩展长文档提示词
var longFillerParagraphs = []string{
"请确保分析全面深入,考虑各个方面的影响和变化。需要包含具体的数据和案例来支持你的观点,同时也要有理论依据。",
"在回答中,请注意逻辑性和连贯性,确保各部分之间有合理的过渡。同时,请尽量使用专业术语,但也要确保非专业人士能够理解。",
"希望你能够从多个角度进行分析,包括但不限于社会学、经济学、政治学、心理学等视角。同时,也请考虑不同群体、不同地区的差异性。",
"在分析中,请注意历史背景和未来趋势,将当前情况放在更广阔的时空背景下考量。同时,也请关注国际比较,了解全球范围内的相似与差异。",
"请在回答中提供一些可行的建议或解决方案,这些建议应该是具体的、可操作的,并且考虑到实施过程中可能遇到的困难和阻力。",
"在分析问题时,请注意考虑各种可能的因素,包括社会、经济、政治、文化、技术等方面。同时,也请思考这些因素之间的相互作用和影响。",
"请在回答中使用清晰的结构,可以考虑使用小标题、要点列表等方式组织内容,使读者能够更容易地理解和把握核心内容。",
"在分析过程中,请注意区分事实和观点,确保事实准确无误,观点有理有据。同时,也请尽量避免个人偏见,保持客观中立的态度。",
"请确保你的回答既有理论深度,又有实践指导意义。理论分析应该有学术依据,实践建议应该考虑可行性和有效性。",
"在回答中,请注意时效性,考虑最新的发展和变化。同时,也请关注长期趋势和基本规律,避免被短期波动所干扰。",
}
// 替换词,用于模板中的占位符
var replacements = map[string][]string{
"country": {"中国", "美国", "日本", "英国", "法国", "德国", "俄罗斯", "加拿大", "澳大利亚", "巴西", "印度"},
"concept": {"人工智能", "机器学习", "区块链", "量子计算", "云计算", "边缘计算", "物联网", "大数据", "虚拟现实", "增强现实"},
"topic": {"全球变暖", "可持续发展", "数字化转型", "远程工作", "网络安全", "太空探索", "生物技术", "清洁能源", "智慧城市", "数字隐私"},
"event": {"工业革命", "互联网诞生", "冷战", "登月", "柏林墙倒塌", "9/11事件", "COVID-19疫情", "阿拉伯之春", "2008金融危机", "人类基因组计划"},
}
// Generator 提示词生成器
type Generator struct {
config *config.Config
tokenizer *tiktoken.Tiktoken
rand *rand.Rand
}
// NewGenerator 创建新的提示词生成器
func NewGenerator(cfg *config.Config) (*Generator, error) {
// 初始化分词器
logger.Debug("初始化分词器,模型: %s", cfg.Tokenizer.Model)
tkm, err := tiktoken.EncodingForModel(cfg.Tokenizer.Model)
if err != nil {
return nil, fmt.Errorf("初始化分词器失败: %w", err)
}
logger.Debug("分词器初始化成功")
// 初始化随机数生成器
source := rand.NewSource(time.Now().UnixNano())
logger.Debug("提示词模板配置 - 短咨询: %d个模板, 目标Token: %d",
len(cfg.PromptTemplates.Short.Templates), cfg.PromptTemplates.Short.TargetTokens)
logger.Debug("提示词模板配置 - 长文档: %d个模板, 目标Token: %d",
len(cfg.PromptTemplates.Long.Templates), cfg.PromptTemplates.Long.TargetTokens)
return &Generator{
config: cfg,
tokenizer: tkm,
rand: rand.New(source),
}, nil
}
// GeneratePrompt 生成指定类型的提示词
func (g *Generator) GeneratePrompt(promptType string) (string, error) {
var templateConfig config.PromptTypeConfig
// 根据类型选择模板配置
switch promptType {
case "short":
templateConfig = g.config.PromptTemplates.Short
case "long":
templateConfig = g.config.PromptTemplates.Long
default:
return "", fmt.Errorf("未知的提示词类型: %s", promptType)
}
// 检查模板是否存在
if len(templateConfig.Templates) == 0 {
return "", fmt.Errorf("提示词模板为空: %s", promptType)
}
// 随机选择一个模板
templateIndex := g.rand.Intn(len(templateConfig.Templates))
template := templateConfig.Templates[templateIndex]
logger.Debug("为%s类型选择模板 #%d: %s", promptType, templateIndex, template)
// 替换模板中的占位符
prompt := g.replacePlaceholders(template)
logger.Debug("替换占位符后: %s", prompt)
// 调整提示词长度至目标Token数
result, err := g.adjustPromptLength(prompt, templateConfig.TargetTokens)
if err != nil {
return "", err
}
// 计算最终Token数
tokens := g.tokenizer.Encode(result, nil, nil)
logger.Debug("生成%s类型提示词成功目标Token: %d, 实际Token: %d",
promptType, templateConfig.TargetTokens, len(tokens))
return result, nil
}
// replacePlaceholders 替换模板中的占位符
func (g *Generator) replacePlaceholders(template string) string {
result := template
// 替换所有占位符
for placeholder, options := range replacements {
placeholderTag := fmt.Sprintf("{%s}", placeholder)
if strings.Contains(result, placeholderTag) {
// 随机选择一个替换选项
replacementIndex := g.rand.Intn(len(options))
replacement := options[replacementIndex]
logger.Debug("替换占位符 %s -> %s", placeholderTag, replacement)
result = strings.Replace(result, placeholderTag, replacement, -1)
}
}
return result
}
// adjustPromptLength 调整提示词长度至目标Token数
func (g *Generator) adjustPromptLength(prompt string, targetTokens int) (string, error) {
// 计算当前Token数
tokens := g.tokenizer.Encode(prompt, nil, nil)
currentTokens := len(tokens)
// 计算目标范围±5%
minTokens := int(float64(targetTokens) * 0.95)
maxTokens := int(float64(targetTokens) * 1.05)
logger.Debug("调整提示词长度 - 当前: %d, 目标范围: %d-%d", currentTokens, minTokens, maxTokens)
// 如果当前Token数在目标范围内直接返回
if currentTokens >= minTokens && currentTokens <= maxTokens {
logger.Debug("提示词长度已在目标范围内,无需调整")
return prompt, nil
}
// 如果Token数过少添加填充词
if currentTokens < minTokens {
logger.Debug("提示词长度过短,需要扩展")
return g.extendPrompt(prompt, targetTokens)
}
// 如果Token数过多截断提示词
logger.Debug("提示词长度过长,需要截断")
return g.truncatePrompt(prompt, targetTokens)
}
// extendPrompt 通过添加填充词扩展提示词
func (g *Generator) extendPrompt(prompt string, targetTokens int) (string, error) {
result := prompt
// 计算当前Token数
tokens := g.tokenizer.Encode(result, nil, nil)
currentTokens := len(tokens)
// 计算目标范围
minTokens := int(float64(targetTokens) * 0.95)
// 根据目标token数选择不同的扩展策略
if targetTokens > 500 {
// 长文档提示词,使用段落填充
logger.Debug("使用段落填充扩展长文档提示词")
return g.extendPromptWithParagraphs(prompt, targetTokens)
}
// 短咨询提示词,使用填充词
logger.Debug("使用填充词扩展短咨询提示词")
// 添加填充词直到达到目标Token数
fillerCount := 0
for currentTokens < minTokens {
// 随机选择一个填充词
fillerIndex := g.rand.Intn(len(fillerWords))
filler := fillerWords[fillerIndex]
// 在提示词前添加填充词
result = filler + " " + result
fillerCount++
// 重新计算Token数
tokens = g.tokenizer.Encode(result, nil, nil)
currentTokens = len(tokens)
// 防止无限循环
if currentTokens > int(float64(targetTokens)*1.05) {
// 如果添加填充词后超过了最大Token数回退并尝试截断
logger.Debug("添加%d个填充词后超过最大Token数尝试截断", fillerCount)
return g.truncatePrompt(result, targetTokens)
}
if fillerCount > 100 {
logger.Warn("添加填充词超过100次尝试使用段落填充")
return g.extendPromptWithParagraphs(prompt, targetTokens)
}
}
logger.Debug("扩展提示词成功,添加了%d个填充词当前Token数: %d", fillerCount, currentTokens)
return result, nil
}
// extendPromptWithParagraphs 通过添加段落扩展长文档提示词
func (g *Generator) extendPromptWithParagraphs(prompt string, targetTokens int) (string, error) {
result := prompt
// 计算当前Token数
tokens := g.tokenizer.Encode(result, nil, nil)
currentTokens := len(tokens)
// 计算目标范围
minTokens := int(float64(targetTokens) * 0.95)
// 添加段落直到达到目标Token数
paragraphCount := 0
usedParagraphs := make(map[int]bool)
for currentTokens < minTokens && paragraphCount < 20 {
// 随机选择一个未使用的段落
var paragraphIndex int
for {
paragraphIndex = g.rand.Intn(len(longFillerParagraphs))
if !usedParagraphs[paragraphIndex] {
usedParagraphs[paragraphIndex] = true
break
}
// 如果所有段落都已使用,重置使用记录
if len(usedParagraphs) >= len(longFillerParagraphs) {
usedParagraphs = make(map[int]bool)
}
}
paragraph := longFillerParagraphs[paragraphIndex]
// 添加段落
if paragraphCount == 0 {
result = result + "\n\n" + paragraph
} else {
result = result + "\n\n" + paragraph
}
paragraphCount++
// 重新计算Token数
tokens = g.tokenizer.Encode(result, nil, nil)
currentTokens = len(tokens)
logger.Debug("添加第%d个段落后当前Token数: %d", paragraphCount, currentTokens)
// 如果已经达到目标范围,退出循环
if currentTokens >= minTokens {
break
}
}
// 如果添加了最大段落数仍未达到目标,给出警告
if currentTokens < minTokens {
logger.Warn("即使添加了%d个段落仍未达到目标Token数当前: %d, 目标: %d",
paragraphCount, currentTokens, targetTokens)
} else {
logger.Debug("扩展长文档提示词成功,添加了%d个段落当前Token数: %d", paragraphCount, currentTokens)
}
return result, nil
}
// truncatePrompt 通过截断缩短提示词
func (g *Generator) truncatePrompt(prompt string, targetTokens int) (string, error) {
// 计算目标范围
maxTokens := int(float64(targetTokens) * 1.05)
// 编码提示词
tokens := g.tokenizer.Encode(prompt, nil, nil)
// 如果Token数超过最大值截断
if len(tokens) > maxTokens {
// 保留目标Token数
originalLength := len(tokens)
tokens = tokens[:maxTokens]
// 解码截断后的Token
result := g.tokenizer.Decode(tokens)
logger.Debug("截断提示词成功,从%d个Token截断到%d个Token", originalLength, len(tokens))
return result, nil
}
return prompt, nil
}
// SelectPromptType 根据权重随机选择提示词类型
func (g *Generator) SelectPromptType() string {
// 计算累积权重
var cumulativeWeights []float64
var sum float64
for _, req := range g.config.Requests {
sum += req.Weight
cumulativeWeights = append(cumulativeWeights, sum)
}
// 生成随机数
r := g.rand.Float64() * sum
// 根据随机数选择提示词类型
for i, weight := range cumulativeWeights {
if r <= weight {
return g.config.Requests[i].Type
}
}
// 默认返回第一个类型
return g.config.Requests[0].Type
}