147 lines
2.7 KiB
Go
147 lines
2.7 KiB
Go
package codec
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
|
|
"github.com/dlclark/regexp2"
|
|
)
|
|
|
|
type Codec struct {
|
|
vocabulary vocab
|
|
reverseVocabulary reverse
|
|
specialTokens map[string]uint
|
|
splitRegexp *regexp2.Regexp
|
|
name string
|
|
}
|
|
|
|
func (c *Codec) GetName() string {
|
|
return c.name
|
|
}
|
|
|
|
// Count returns the number of tokens in the input string.
|
|
func (c *Codec) Count(input string) (int, error) {
|
|
var count int
|
|
|
|
err := c.tokenize(input, func(_ uint) {
|
|
count++
|
|
})
|
|
|
|
return count, err
|
|
}
|
|
|
|
// Encode returns the token IDs and tokens for the input string.
|
|
func (c *Codec) Encode(input string) ([]uint, []string, error) {
|
|
|
|
var ids []uint
|
|
var tokens []string
|
|
|
|
err := c.tokenize(input, func(id uint) {
|
|
ids = append(ids, id)
|
|
tokens = append(tokens, c.reverseVocabulary[id])
|
|
})
|
|
|
|
return ids, tokens, err
|
|
}
|
|
|
|
func (c *Codec) tokenize(input string, yield func(uint)) error {
|
|
match, err := c.splitRegexp.FindStringMatch(input)
|
|
if err != nil {
|
|
return fmt.Errorf("error matching: %v", err)
|
|
}
|
|
for match != nil {
|
|
piece := match.String()
|
|
if id, ok := c.vocabulary[piece]; ok {
|
|
yield(id)
|
|
} else {
|
|
parts := c.mergePairs(piece)
|
|
|
|
for i := range len(parts) - 1 {
|
|
token := piece[parts[i].offset:parts[i+1].offset]
|
|
yield(c.vocabulary[token])
|
|
}
|
|
}
|
|
match, err = c.splitRegexp.FindNextMatch(match)
|
|
if err != nil {
|
|
return fmt.Errorf("error matching: %v", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Codec) Decode(tokens []uint) (string, error) {
|
|
if c.reverseVocabulary == nil {
|
|
c.reverseVocabulary = make(map[uint]string)
|
|
for k, v := range c.vocabulary {
|
|
c.reverseVocabulary[v] = k
|
|
}
|
|
}
|
|
|
|
var out string
|
|
for _, t := range tokens {
|
|
piece, ok := c.reverseVocabulary[t]
|
|
if !ok {
|
|
return "", fmt.Errorf("invalid token: %d", t)
|
|
}
|
|
out += piece
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// part is one segment boundary tracked by mergePairs.
type part struct {
	// offset is the byte index in the piece at which the segment starts.
	offset int
	// rank is the vocabulary rank of this segment merged with the one after
	// it, or math.MaxUint when that merge is not available.
	rank uint
}
|
|
|
|
func (c *Codec) mergePairs(piece string) []part {
|
|
parts := make([]part, len(piece)+1)
|
|
for i := range len(parts) {
|
|
parts[i] = part{i, math.MaxUint}
|
|
}
|
|
|
|
getRank := func(index, skip int) uint {
|
|
if index+skip+2 < len(parts) {
|
|
start := parts[index].offset
|
|
end := parts[index+skip+2].offset
|
|
if rank, ok := c.vocabulary[piece[start:end]]; ok {
|
|
return rank
|
|
}
|
|
}
|
|
return math.MaxUint
|
|
}
|
|
|
|
for i := 0; i < len(parts)-2; i++ {
|
|
parts[i].rank = getRank(i, 0)
|
|
}
|
|
|
|
for {
|
|
if len(parts) == 1 {
|
|
break
|
|
}
|
|
|
|
minRank := uint(math.MaxUint)
|
|
minIndex := 0
|
|
for i, p := range parts[:len(parts)-1] {
|
|
if p.rank < minRank {
|
|
minRank = p.rank
|
|
minIndex = i
|
|
}
|
|
}
|
|
|
|
if minRank == math.MaxUint {
|
|
break
|
|
}
|
|
|
|
parts[minIndex].rank = getRank(minIndex, 1)
|
|
|
|
if minIndex > 0 {
|
|
parts[minIndex-1].rank = getRank(minIndex-1, 1)
|
|
}
|
|
|
|
parts = append(parts[:minIndex+1], parts[minIndex+2:]...)
|
|
}
|
|
|
|
return parts
|
|
}
|