feat: tiktoken tokenizer — accurate BPE token counting with provider-aware encoding
This commit is contained in:
2
go.mod
2
go.mod
@@ -39,6 +39,7 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.18.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
@@ -47,6 +48,7 @@ require (
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -101,6 +101,8 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
|
||||
64
internal/tokenizer/tokenizer.go
Normal file
64
internal/tokenizer/tokenizer.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
tiktoken "github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// Tokenizer counts tokens using a tiktoken BPE encoding.
// Falls back to len/4 heuristic if the encoding fails to load.
//
// The encoding is loaded lazily on first use (see getEncoding); a
// Tokenizer must not be copied after first use because it embeds
// sync.Mutex and sync.Once.
type Tokenizer struct {
	encoding string              // tiktoken encoding name, e.g. "cl100k_base"
	enc      *tiktoken.Tiktoken  // cached encoding; nil until loaded, stays nil on load failure
	mu       sync.Mutex          // guards enc and loaded
	loaded   bool                // true once a load attempt (success or failure) has happened
	warnOnce sync.Once           // logs the load-failure warning at most once
}
|
||||
|
||||
// New creates a Tokenizer for the given tiktoken encoding name (e.g. "cl100k_base").
|
||||
func New(encoding string) *Tokenizer {
|
||||
return &Tokenizer{encoding: encoding}
|
||||
}
|
||||
|
||||
// ForProvider returns a Tokenizer appropriate for the named provider.
|
||||
func ForProvider(providerName string) *Tokenizer {
|
||||
switch providerName {
|
||||
case "anthropic", "openai":
|
||||
return New("cl100k_base")
|
||||
default:
|
||||
// mistral, google, ollama, llamacpp, unknown
|
||||
return New("o200k_base")
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of tokens for text using the configured encoding.
|
||||
// Falls back to len(text)/4 if encoding is unavailable.
|
||||
func (t *Tokenizer) Count(text string) int {
|
||||
if enc := t.getEncoding(); enc != nil {
|
||||
tokens := enc.Encode(text, nil, nil)
|
||||
return len(tokens)
|
||||
}
|
||||
// heuristic fallback
|
||||
return (len(text) + 3) / 4
|
||||
}
|
||||
|
||||
func (t *Tokenizer) getEncoding() *tiktoken.Tiktoken {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.loaded {
|
||||
return t.enc // may be nil if failed
|
||||
}
|
||||
t.loaded = true
|
||||
enc, err := tiktoken.GetEncoding(t.encoding)
|
||||
if err != nil {
|
||||
t.warnOnce.Do(func() {
|
||||
slog.Warn("tiktoken encoding unavailable, falling back to heuristic",
|
||||
"encoding", t.encoding, "error", err)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
t.enc = enc
|
||||
return enc
|
||||
}
|
||||
47
internal/tokenizer/tokenizer_test.go
Normal file
47
internal/tokenizer/tokenizer_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package tokenizer_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/tokenizer"
|
||||
)
|
||||
|
||||
func TestTokenizer_CountKnownText(t *testing.T) {
|
||||
tok := tokenizer.New("cl100k_base")
|
||||
|
||||
// "Hello world" is 2 tokens in cl100k_base
|
||||
n := tok.Count("Hello world")
|
||||
if n < 1 || n > 5 {
|
||||
t.Errorf("unexpected token count for 'Hello world': %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenizer_FallbackOnBadEncoding(t *testing.T) {
|
||||
tok := tokenizer.New("nonexistent_encoding_xyz")
|
||||
// Must not panic; falls back to heuristic
|
||||
n := tok.Count("some text here")
|
||||
if n <= 0 {
|
||||
t.Errorf("expected positive count, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForProvider_KnownProviders(t *testing.T) {
|
||||
cases := []string{"anthropic", "openai", "mistral", "google", "ollama", "llamacpp", "unknown"}
|
||||
for _, prov := range cases {
|
||||
tok := tokenizer.ForProvider(prov)
|
||||
n := tok.Count("test input")
|
||||
if n <= 0 {
|
||||
t.Errorf("provider %q: expected positive count, got %d", prov, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenizer_CodeCountsReasonably(t *testing.T) {
|
||||
tok := tokenizer.New("cl100k_base")
|
||||
code := `func main() { fmt.Println("hello") }`
|
||||
n := tok.Count(code)
|
||||
// Should be between 5 and 20 tokens for this snippet
|
||||
if n < 5 || n > 20 {
|
||||
t.Errorf("code token count out of expected range: %d", n)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user