328 lines
12 KiB
Go
328 lines
12 KiB
Go
package prompt
|
||
|
||
import (
|
||
"fmt"
|
||
"math/rand"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/acmestudio/llm-api-benchmark-tool/config"
|
||
"github.com/acmestudio/llm-api-benchmark-tool/logger"
|
||
"github.com/pkoukk/tiktoken-go"
|
||
)
|
||
|
||
// 填充词,用于调整提示词长度
|
||
var fillerWords = []string{
|
||
"详细", "简要", "请", "麻烦", "帮忙", "认真", "仔细", "全面", "系统性地",
|
||
"清晰地", "准确地", "深入地", "专业地", "客观地", "科学地", "有条理地",
|
||
"具体地", "简洁地", "完整地", "精确地", "有逻辑地", "有说服力地",
|
||
}
|
||
|
||
// 长文本填充段落,用于扩展长文档提示词
|
||
var longFillerParagraphs = []string{
|
||
"请确保分析全面深入,考虑各个方面的影响和变化。需要包含具体的数据和案例来支持你的观点,同时也要有理论依据。",
|
||
"在回答中,请注意逻辑性和连贯性,确保各部分之间有合理的过渡。同时,请尽量使用专业术语,但也要确保非专业人士能够理解。",
|
||
"希望你能够从多个角度进行分析,包括但不限于社会学、经济学、政治学、心理学等视角。同时,也请考虑不同群体、不同地区的差异性。",
|
||
"在分析中,请注意历史背景和未来趋势,将当前情况放在更广阔的时空背景下考量。同时,也请关注国际比较,了解全球范围内的相似与差异。",
|
||
"请在回答中提供一些可行的建议或解决方案,这些建议应该是具体的、可操作的,并且考虑到实施过程中可能遇到的困难和阻力。",
|
||
"在分析问题时,请注意考虑各种可能的因素,包括社会、经济、政治、文化、技术等方面。同时,也请思考这些因素之间的相互作用和影响。",
|
||
"请在回答中使用清晰的结构,可以考虑使用小标题、要点列表等方式组织内容,使读者能够更容易地理解和把握核心内容。",
|
||
"在分析过程中,请注意区分事实和观点,确保事实准确无误,观点有理有据。同时,也请尽量避免个人偏见,保持客观中立的态度。",
|
||
"请确保你的回答既有理论深度,又有实践指导意义。理论分析应该有学术依据,实践建议应该考虑可行性和有效性。",
|
||
"在回答中,请注意时效性,考虑最新的发展和变化。同时,也请关注长期趋势和基本规律,避免被短期波动所干扰。",
|
||
}
|
||
|
||
// 替换词,用于模板中的占位符
|
||
var replacements = map[string][]string{
|
||
"country": {"中国", "美国", "日本", "英国", "法国", "德国", "俄罗斯", "加拿大", "澳大利亚", "巴西", "印度"},
|
||
"concept": {"人工智能", "机器学习", "区块链", "量子计算", "云计算", "边缘计算", "物联网", "大数据", "虚拟现实", "增强现实"},
|
||
"topic": {"全球变暖", "可持续发展", "数字化转型", "远程工作", "网络安全", "太空探索", "生物技术", "清洁能源", "智慧城市", "数字隐私"},
|
||
"event": {"工业革命", "互联网诞生", "冷战", "登月", "柏林墙倒塌", "9/11事件", "COVID-19疫情", "阿拉伯之春", "2008金融危机", "人类基因组计划"},
|
||
}
|
||
|
||
// Generator 提示词生成器
|
||
type Generator struct {
|
||
config *config.Config
|
||
tokenizer *tiktoken.Tiktoken
|
||
rand *rand.Rand
|
||
}
|
||
|
||
// NewGenerator 创建新的提示词生成器
|
||
func NewGenerator(cfg *config.Config) (*Generator, error) {
|
||
// 初始化分词器
|
||
logger.Debug("初始化分词器,模型: %s", cfg.Tokenizer.Model)
|
||
tkm, err := tiktoken.EncodingForModel(cfg.Tokenizer.Model)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("初始化分词器失败: %w", err)
|
||
}
|
||
logger.Debug("分词器初始化成功")
|
||
|
||
// 初始化随机数生成器
|
||
source := rand.NewSource(time.Now().UnixNano())
|
||
|
||
logger.Debug("提示词模板配置 - 短咨询: %d个模板, 目标Token: %d",
|
||
len(cfg.PromptTemplates.Short.Templates), cfg.PromptTemplates.Short.TargetTokens)
|
||
logger.Debug("提示词模板配置 - 长文档: %d个模板, 目标Token: %d",
|
||
len(cfg.PromptTemplates.Long.Templates), cfg.PromptTemplates.Long.TargetTokens)
|
||
|
||
return &Generator{
|
||
config: cfg,
|
||
tokenizer: tkm,
|
||
rand: rand.New(source),
|
||
}, nil
|
||
}
|
||
|
||
// GeneratePrompt 生成指定类型的提示词
|
||
func (g *Generator) GeneratePrompt(promptType string) (string, error) {
|
||
var templateConfig config.PromptTypeConfig
|
||
|
||
// 根据类型选择模板配置
|
||
switch promptType {
|
||
case "short":
|
||
templateConfig = g.config.PromptTemplates.Short
|
||
case "long":
|
||
templateConfig = g.config.PromptTemplates.Long
|
||
default:
|
||
return "", fmt.Errorf("未知的提示词类型: %s", promptType)
|
||
}
|
||
|
||
// 检查模板是否存在
|
||
if len(templateConfig.Templates) == 0 {
|
||
return "", fmt.Errorf("提示词模板为空: %s", promptType)
|
||
}
|
||
|
||
// 随机选择一个模板
|
||
templateIndex := g.rand.Intn(len(templateConfig.Templates))
|
||
template := templateConfig.Templates[templateIndex]
|
||
logger.Debug("为%s类型选择模板 #%d: %s", promptType, templateIndex, template)
|
||
|
||
// 替换模板中的占位符
|
||
prompt := g.replacePlaceholders(template)
|
||
logger.Debug("替换占位符后: %s", prompt)
|
||
|
||
// 调整提示词长度至目标Token数
|
||
result, err := g.adjustPromptLength(prompt, templateConfig.TargetTokens)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 计算最终Token数
|
||
tokens := g.tokenizer.Encode(result, nil, nil)
|
||
logger.Debug("生成%s类型提示词成功,目标Token: %d, 实际Token: %d",
|
||
promptType, templateConfig.TargetTokens, len(tokens))
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// replacePlaceholders 替换模板中的占位符
|
||
func (g *Generator) replacePlaceholders(template string) string {
|
||
result := template
|
||
|
||
// 替换所有占位符
|
||
for placeholder, options := range replacements {
|
||
placeholderTag := fmt.Sprintf("{%s}", placeholder)
|
||
if strings.Contains(result, placeholderTag) {
|
||
// 随机选择一个替换选项
|
||
replacementIndex := g.rand.Intn(len(options))
|
||
replacement := options[replacementIndex]
|
||
logger.Debug("替换占位符 %s -> %s", placeholderTag, replacement)
|
||
result = strings.Replace(result, placeholderTag, replacement, -1)
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// adjustPromptLength 调整提示词长度至目标Token数
|
||
func (g *Generator) adjustPromptLength(prompt string, targetTokens int) (string, error) {
|
||
// 计算当前Token数
|
||
tokens := g.tokenizer.Encode(prompt, nil, nil)
|
||
currentTokens := len(tokens)
|
||
|
||
// 计算目标范围(±5%)
|
||
minTokens := int(float64(targetTokens) * 0.95)
|
||
maxTokens := int(float64(targetTokens) * 1.05)
|
||
|
||
logger.Debug("调整提示词长度 - 当前: %d, 目标范围: %d-%d", currentTokens, minTokens, maxTokens)
|
||
|
||
// 如果当前Token数在目标范围内,直接返回
|
||
if currentTokens >= minTokens && currentTokens <= maxTokens {
|
||
logger.Debug("提示词长度已在目标范围内,无需调整")
|
||
return prompt, nil
|
||
}
|
||
|
||
// 如果Token数过少,添加填充词
|
||
if currentTokens < minTokens {
|
||
logger.Debug("提示词长度过短,需要扩展")
|
||
return g.extendPrompt(prompt, targetTokens)
|
||
}
|
||
|
||
// 如果Token数过多,截断提示词
|
||
logger.Debug("提示词长度过长,需要截断")
|
||
return g.truncatePrompt(prompt, targetTokens)
|
||
}
|
||
|
||
// extendPrompt 通过添加填充词扩展提示词
|
||
func (g *Generator) extendPrompt(prompt string, targetTokens int) (string, error) {
|
||
result := prompt
|
||
|
||
// 计算当前Token数
|
||
tokens := g.tokenizer.Encode(result, nil, nil)
|
||
currentTokens := len(tokens)
|
||
|
||
// 计算目标范围
|
||
minTokens := int(float64(targetTokens) * 0.95)
|
||
|
||
// 根据目标token数选择不同的扩展策略
|
||
if targetTokens > 500 {
|
||
// 长文档提示词,使用段落填充
|
||
logger.Debug("使用段落填充扩展长文档提示词")
|
||
return g.extendPromptWithParagraphs(prompt, targetTokens)
|
||
}
|
||
|
||
// 短咨询提示词,使用填充词
|
||
logger.Debug("使用填充词扩展短咨询提示词")
|
||
|
||
// 添加填充词,直到达到目标Token数
|
||
fillerCount := 0
|
||
for currentTokens < minTokens {
|
||
// 随机选择一个填充词
|
||
fillerIndex := g.rand.Intn(len(fillerWords))
|
||
filler := fillerWords[fillerIndex]
|
||
|
||
// 在提示词前添加填充词
|
||
result = filler + " " + result
|
||
fillerCount++
|
||
|
||
// 重新计算Token数
|
||
tokens = g.tokenizer.Encode(result, nil, nil)
|
||
currentTokens = len(tokens)
|
||
|
||
// 防止无限循环
|
||
if currentTokens > int(float64(targetTokens)*1.05) {
|
||
// 如果添加填充词后超过了最大Token数,回退并尝试截断
|
||
logger.Debug("添加%d个填充词后超过最大Token数,尝试截断", fillerCount)
|
||
return g.truncatePrompt(result, targetTokens)
|
||
}
|
||
|
||
if fillerCount > 100 {
|
||
logger.Warn("添加填充词超过100次,尝试使用段落填充")
|
||
return g.extendPromptWithParagraphs(prompt, targetTokens)
|
||
}
|
||
}
|
||
|
||
logger.Debug("扩展提示词成功,添加了%d个填充词,当前Token数: %d", fillerCount, currentTokens)
|
||
return result, nil
|
||
}
|
||
|
||
// extendPromptWithParagraphs 通过添加段落扩展长文档提示词
|
||
func (g *Generator) extendPromptWithParagraphs(prompt string, targetTokens int) (string, error) {
|
||
result := prompt
|
||
|
||
// 计算当前Token数
|
||
tokens := g.tokenizer.Encode(result, nil, nil)
|
||
currentTokens := len(tokens)
|
||
|
||
// 计算目标范围
|
||
minTokens := int(float64(targetTokens) * 0.95)
|
||
|
||
// 添加段落,直到达到目标Token数
|
||
paragraphCount := 0
|
||
usedParagraphs := make(map[int]bool)
|
||
|
||
for currentTokens < minTokens && paragraphCount < 20 {
|
||
// 随机选择一个未使用的段落
|
||
var paragraphIndex int
|
||
for {
|
||
paragraphIndex = g.rand.Intn(len(longFillerParagraphs))
|
||
if !usedParagraphs[paragraphIndex] {
|
||
usedParagraphs[paragraphIndex] = true
|
||
break
|
||
}
|
||
// 如果所有段落都已使用,重置使用记录
|
||
if len(usedParagraphs) >= len(longFillerParagraphs) {
|
||
usedParagraphs = make(map[int]bool)
|
||
}
|
||
}
|
||
|
||
paragraph := longFillerParagraphs[paragraphIndex]
|
||
|
||
// 添加段落
|
||
if paragraphCount == 0 {
|
||
result = result + "\n\n" + paragraph
|
||
} else {
|
||
result = result + "\n\n" + paragraph
|
||
}
|
||
paragraphCount++
|
||
|
||
// 重新计算Token数
|
||
tokens = g.tokenizer.Encode(result, nil, nil)
|
||
currentTokens = len(tokens)
|
||
|
||
logger.Debug("添加第%d个段落后,当前Token数: %d", paragraphCount, currentTokens)
|
||
|
||
// 如果已经达到目标范围,退出循环
|
||
if currentTokens >= minTokens {
|
||
break
|
||
}
|
||
}
|
||
|
||
// 如果添加了最大段落数仍未达到目标,给出警告
|
||
if currentTokens < minTokens {
|
||
logger.Warn("即使添加了%d个段落,仍未达到目标Token数(当前: %d, 目标: %d)",
|
||
paragraphCount, currentTokens, targetTokens)
|
||
} else {
|
||
logger.Debug("扩展长文档提示词成功,添加了%d个段落,当前Token数: %d", paragraphCount, currentTokens)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// truncatePrompt 通过截断缩短提示词
|
||
func (g *Generator) truncatePrompt(prompt string, targetTokens int) (string, error) {
|
||
// 计算目标范围
|
||
maxTokens := int(float64(targetTokens) * 1.05)
|
||
|
||
// 编码提示词
|
||
tokens := g.tokenizer.Encode(prompt, nil, nil)
|
||
|
||
// 如果Token数超过最大值,截断
|
||
if len(tokens) > maxTokens {
|
||
// 保留目标Token数
|
||
originalLength := len(tokens)
|
||
tokens = tokens[:maxTokens]
|
||
|
||
// 解码截断后的Token
|
||
result := g.tokenizer.Decode(tokens)
|
||
logger.Debug("截断提示词成功,从%d个Token截断到%d个Token", originalLength, len(tokens))
|
||
return result, nil
|
||
}
|
||
|
||
return prompt, nil
|
||
}
|
||
|
||
// SelectPromptType 根据权重随机选择提示词类型
|
||
func (g *Generator) SelectPromptType() string {
|
||
// 计算累积权重
|
||
var cumulativeWeights []float64
|
||
var sum float64
|
||
|
||
for _, req := range g.config.Requests {
|
||
sum += req.Weight
|
||
cumulativeWeights = append(cumulativeWeights, sum)
|
||
}
|
||
|
||
// 生成随机数
|
||
r := g.rand.Float64() * sum
|
||
|
||
// 根据随机数选择提示词类型
|
||
for i, weight := range cumulativeWeights {
|
||
if r <= weight {
|
||
return g.config.Requests[i].Type
|
||
}
|
||
}
|
||
|
||
// 默认返回第一个类型
|
||
return g.config.Requests[0].Type
|
||
}
|