package prompt import ( "fmt" "math/rand" "strings" "time" "github.com/tiktoken-go/tokenizer" "mingor/llm-api-benchmark-tool/internal/config" ) // PromptGenerator 生成短/长提示词 // 使用 tiktoken-go tokenizer 计算 token 数 type PromptGenerator struct { cfg *config.Config tok *tokenizer.Tokenizer } // NewPromptGenerator 初始化生成器,model 用于 tokenizer func NewPromptGenerator(model string, cfg *config.Config) (*PromptGenerator, error) { tok, err := tokenizer.NewTokenizer(model) if err != nil { return nil, err } rand.Seed(time.Now().UnixNano()) return &PromptGenerator{cfg: cfg, tok: tok}, nil } // Generate 根据类型(tp)和占位值(values)生成提示词 // tp 可为 "short" 或 "long" 等配置中的键 func (g *PromptGenerator) Generate(tp string, values map[string]string) (string, error) { pt, ok := g.cfg.PromptTemplates[tp] if !ok { return "", fmt.Errorf("unknown prompt type: %s", tp) } if len(pt.Templates) == 0 { return "", fmt.Errorf("no templates for type: %s", tp) } tpl := pt.Templates[rand.Intn(len(pt.Templates))] prompt := fillTemplate(tpl, values) // 计算 token 数 toks, err := g.tok.Encode(prompt) if err != nil { return "", err } target := pt.TargetTokens tol := int(float64(target) * 0.05) cnt := len(toks) if cnt < target-tol || cnt > target+tol { // 长度不在范围,用户可根据需要调整 } return prompt, nil } // CountTokens 返回文本的 token 数 func (g *PromptGenerator) CountTokens(text string) (int, error) { toks, err := g.tok.Encode(text) if err != nil { return 0, err } return len(toks), nil } // fillTemplate 用 map[string]string 替换模板中的 {key} func fillTemplate(tpl string, values map[string]string) string { out := tpl for k, v := range values { out = strings.ReplaceAll(out, "{"+k+"}", v) } return out }