// Package test exercises the prompt generator exposed by the project's
// pkg/tokenizer package by checking how many tokens the generated prompts
// actually contain.
package test

import (
	"fmt"
	"math"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/tiktoken-go/tokenizer" // Assuming this is the correct import path

	token "llm-api-benchmark-tool/pkg/tokenizer" // Target package (to be created)
)

const (
	// tolerance is the maximum allowed relative deviation from the target
	// token count (0.05 == 5%).
	tolerance = 0.05
	// modelName selects the tokenizer used to count tokens in generated prompts.
	modelName = "gpt-4" // Or another common model for tiktoken
)

// countTokens returns the number of tokens text encodes to under the
// tokenizer associated with model. It fails the test immediately if no codec
// can be resolved for model or if encoding fails.
func countTokens(t *testing.T, text string, model string) int {
	t.Helper()

	// Resolve the codec from the model name rather than hard-coding an
	// encoding, so the model argument actually controls the count.
	codec, err := tokenizer.ForModel(tokenizer.Model(model))
	require.NoError(t, err, "Failed to get tokenizer codec")

	ids, _, err := codec.Encode(text)
	require.NoError(t, err, "Failed to encode text")

	return len(ids)
}

// tokenBounds returns the inclusive window [lo, hi] of acceptable token counts
// for target. The slack is rounded down so the window never admits a deviation
// larger than tolerance: a target of 50 yields 48..52 and a target of 1000
// yields 950..1050.
func tokenBounds(target int) (lo, hi int) {
	slack := int(math.Floor(float64(target) * tolerance))
	return target - slack, target + slack
}

// checkGeneratedPrompt generates a prompt for targetTokens and asserts that
// its measured token count falls inside the tolerance window around the
// target. It is the shared body of the per-size tests below.
func checkGeneratedPrompt(t *testing.T, targetTokens int) {
	t.Helper()

	minTokens, maxTokens := tokenBounds(targetTokens)

	// Assuming GeneratePrompt takes the target count and returns (prompt, error).
	prompt, err := token.GeneratePrompt(targetTokens)
	require.NoError(t, err, fmt.Sprintf("GeneratePrompt should not return an error for %d tokens", targetTokens))
	require.NotEmpty(t, prompt, "Generated prompt should not be empty")

	actualTokens := countTokens(t, prompt, modelName)

	// assert (not require) so that both bound violations are reported and the
	// measured count is still logged on failure.
	assert.GreaterOrEqual(t, actualTokens, minTokens,
		fmt.Sprintf("Token count %d should be >= %d (target %d)", actualTokens, minTokens, targetTokens))
	assert.LessOrEqual(t, actualTokens, maxTokens,
		fmt.Sprintf("Token count %d should be <= %d (target %d)", actualTokens, maxTokens, targetTokens))

	t.Logf("Generated %d-token prompt (target: %d±%.0f%%): %d tokens",
		targetTokens, targetTokens, tolerance*100, actualTokens)
}

// TestGeneratePrompt50Tokens checks that a small (50-token) prompt lands
// within the tolerance window.
func TestGeneratePrompt50Tokens(t *testing.T) {
	checkGeneratedPrompt(t, 50)
}

// TestGeneratePrompt1000Tokens checks that a large (1000-token) prompt lands
// within the tolerance window.
func TestGeneratePrompt1000Tokens(t *testing.T) {
	checkGeneratedPrompt(t, 1000)
}

// TODO: Add test for invalid target token count (e.g., 0 or negative) if applicable