feat: tokenizer-aware Tracker.CountTokens/CountMessages replaces EstimateMessages in compaction
This commit is contained in:
@@ -2,6 +2,7 @@ package context
|
||||
|
||||
import (
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tokenizer"
|
||||
)
|
||||
|
||||
// TokenState indicates how close to the context limit we are.
|
||||
@@ -40,6 +41,8 @@ type Tracker struct {
|
||||
// Configurable buffers
|
||||
autocompactBuffer int64
|
||||
warningBuffer int64
|
||||
|
||||
tok *tokenizer.Tokenizer
|
||||
}
|
||||
|
||||
func NewTracker(maxTokens int64) *Tracker {
|
||||
@@ -125,6 +128,48 @@ func (t *Tracker) PreEstimate(tokens int64) {
|
||||
t.current += tokens
|
||||
}
|
||||
|
||||
// SetTokenizer sets the tokenizer used for accurate token counting.
|
||||
func (t *Tracker) SetTokenizer(tok *tokenizer.Tokenizer) {
|
||||
t.tok = tok
|
||||
}
|
||||
|
||||
// CountTokens returns the token count for text using the configured tokenizer,
|
||||
// falling back to the len/4 heuristic if no tokenizer is set.
|
||||
func (t *Tracker) CountTokens(text string) int64 {
|
||||
if t.tok != nil {
|
||||
return int64(t.tok.Count(text))
|
||||
}
|
||||
return EstimateTokens(text)
|
||||
}
|
||||
|
||||
// CountMessages returns the token count for a message slice.
|
||||
func (t *Tracker) CountMessages(msgs []message.Message) int64 {
|
||||
var total int64
|
||||
for _, msg := range msgs {
|
||||
for _, c := range msg.Content {
|
||||
switch c.Type {
|
||||
case message.ContentText:
|
||||
total += t.CountTokens(c.Text)
|
||||
case message.ContentToolCall:
|
||||
total += 50
|
||||
if c.ToolCall != nil {
|
||||
total += t.CountTokens(string(c.ToolCall.Arguments))
|
||||
}
|
||||
case message.ContentToolResult:
|
||||
if c.ToolResult != nil {
|
||||
total += t.CountTokens(c.ToolResult.Content)
|
||||
}
|
||||
case message.ContentThinking:
|
||||
if c.Thinking != nil {
|
||||
total += t.CountTokens(c.Thinking.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
total += 4 // per-message overhead (role, separators)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// EstimateTokens returns a rough token estimate for a text string.
|
||||
// Heuristic: ~4 characters per token for English text.
|
||||
func EstimateTokens(text string) int64 {
|
||||
|
||||
29
internal/context/tracker_tokenizer_test.go
Normal file
29
internal/context/tracker_tokenizer_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package context_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tokenizer"
|
||||
)
|
||||
|
||||
func TestTracker_CountTokensWithTokenizer(t *testing.T) {
|
||||
tok := tokenizer.New("cl100k_base")
|
||||
tr := gnomactx.NewTracker(100000)
|
||||
tr.SetTokenizer(tok)
|
||||
|
||||
n := tr.CountTokens("Hello world")
|
||||
// tiktoken gives 2; heuristic gives (11+3)/4 = 3
|
||||
if n < 1 || n > 5 {
|
||||
t.Errorf("unexpected count: %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTracker_CountTokensNilTokenizerFallsBack(t *testing.T) {
|
||||
tr := gnomactx.NewTracker(100000)
|
||||
// nil tokenizer — should use heuristic
|
||||
n := tr.CountTokens("Hello world")
|
||||
if n <= 0 {
|
||||
t.Errorf("expected positive count, got %d", n)
|
||||
}
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func (w *Window) doCompact(force bool) (bool, error) {
|
||||
|
||||
// Re-estimate tokens from actual message content rather than using a
|
||||
// message-count ratio (which is unrelated to token count).
|
||||
w.tracker.Set(EstimateMessages(compacted))
|
||||
w.tracker.Set(w.tracker.CountMessages(compacted))
|
||||
|
||||
w.logger.Info("compaction complete",
|
||||
"messages_before", originalLen,
|
||||
|
||||
Reference in New Issue
Block a user