llm-api-benchmark-tool/test/tokenizer_test.go

63 lines
2.6 KiB
Go

package test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tiktoken-go/tokenizer" // Assuming this is the correct import path
token "llm-api-benchmark-tool/pkg/tokenizer" // Target package (to be created)
)
// Settings shared by the prompt-generation tests in this file.
const (
	// modelName is the model name handed to countTokens when measuring prompts.
	modelName = "gpt-4"
	// tolerance is the permitted relative deviation from the target token
	// count; 0.05 means 5%.
	tolerance = 0.05
)
// countTokens returns how many tokens text encodes to under the tiktoken
// encoding associated with model (for example "gpt-4").
//
// It fails the calling test immediately if no codec can be resolved for model
// or if encoding fails. t.Helper is called so failures point at the caller.
func countTokens(t *testing.T, text string, model string) int {
	t.Helper()

	// Resolve the codec from the model name instead of hard-coding an encoding,
	// so the model argument actually determines how tokens are counted.
	codec, err := tokenizer.ForModel(tokenizer.Model(model))
	require.NoError(t, err, "Failed to get tokenizer codec for model %q", model)

	// Only the token IDs matter here; the per-token strings are discarded.
	ids, _, err := codec.Encode(text)
	require.NoError(t, err, "Failed to encode text")
	return len(ids)
}
// assertPromptWithinTolerance generates a prompt for targetTokens using
// token.GeneratePrompt and asserts that its measured token count (via
// countTokens with modelName) lies within ±tolerance of the target.
//
// Generation errors and empty prompts abort the test; an out-of-range count
// is reported but lets the test continue.
func assertPromptWithinTolerance(t *testing.T, targetTokens int) {
	t.Helper()

	// Use one symmetric integer slack for both bounds. Truncating
	// target*(1-tolerance) and target*(1+tolerance) separately rounds the
	// lower bound down (47 for a 50-token target, i.e. 6% under), so the
	// accepted band was wider than the stated tolerance on the low side.
	slack := int(float64(targetTokens) * tolerance)
	minTokens := targetTokens - slack
	maxTokens := targetTokens + slack
	label := fmt.Sprintf("%d-token prompt", targetTokens)

	prompt, err := token.GeneratePrompt(targetTokens)
	require.NoError(t, err, "GeneratePrompt should not return an error for %d tokens", targetTokens)
	require.NotEmpty(t, prompt, "Generated prompt should not be empty")

	actualTokens := countTokens(t, prompt, modelName)
	assert.GreaterOrEqual(t, actualTokens, minTokens, "%s: token count %d should be >= %d", label, actualTokens, minTokens)
	assert.LessOrEqual(t, actualTokens, maxTokens, "%s: token count %d should be <= %d", label, actualTokens, maxTokens)
	t.Logf("Generated %s (target: %d±%.0f%%): %d tokens", label, targetTokens, tolerance*100, actualTokens)
}

// TestGeneratePrompt50Tokens checks that a short (50-token) prompt is
// generated within tolerance.
func TestGeneratePrompt50Tokens(t *testing.T) {
	assertPromptWithinTolerance(t, 50)
}

// TestGeneratePrompt1000Tokens checks that a long (1000-token) prompt is
// generated within tolerance.
func TestGeneratePrompt1000Tokens(t *testing.T) {
	assertPromptWithinTolerance(t, 1000)
}
// TODO: Add test for invalid target token count (e.g., 0 or negative) if applicable