diff --git a/internal/provider/google/provider.go b/internal/provider/google/provider.go index dab778e..9499c08 100644 --- a/internal/provider/google/provider.go +++ b/internal/provider/google/provider.go @@ -3,7 +3,9 @@ package google import ( "context" "encoding/json" + "errors" "fmt" + "log/slog" "os" "path/filepath" "time" @@ -16,6 +18,13 @@ import ( "google.golang.org/genai" ) +// cloudPlatformScope is the standard OAuth scope used for Vertex AI and +// the Gemini API on Google Cloud. credentials.DetectDefault REQUIRES at +// least Scopes or Audience to be set — calling it with nil options +// returns "credentials: options must be provided" and the ADC branch +// becomes dead code. +const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" + const defaultModel = "gemini-3.5-flash" // Provider implements provider.Provider for Google's Gemini API. @@ -77,6 +86,16 @@ func (tp *fileTokenProvider) Token(ctx context.Context) (*auth.Token, error) { return nil, fmt.Errorf("no access token in credentials file") } + // We don't perform an OAuth refresh exchange ourselves; the upstream + // CLI (gemini / antigravity) refreshes the file out-of-band. If we're + // asked for a token after expiry and the file hasn't been refreshed, + // fail loudly with an actionable message instead of sending a known- + // dead bearer that the API would reject with a confusing 401. + expiry := creds.Expiry() + if !expiry.IsZero() && time.Now().After(expiry) { + return nil, fmt.Errorf("oauth token at %s is expired (re-run the upstream CLI to refresh)", tp.filePath) + } + tokenType := creds.TokenType if tokenType == "" { tokenType = creds.TokenType2 @@ -88,7 +107,7 @@ func (tp *fileTokenProvider) Token(ctx context.Context) (*auth.Token, error) { return &auth.Token{ Value: tokVal, Type: tokenType, - Expiry: creds.Expiry(), + Expiry: expiry, }, nil } @@ -109,38 +128,108 @@ func expandHome(path string) string { return path } +// errCredentialMissing wraps os.ErrNotExist for the precedence walker so +// the "file isn't there" case is silent while permission / parse / empty- +// token failures get a slog.Warn (they typically indicate a misconfigured +// install — chmod 0600 on the wrong file, half-written JSON, etc.). +var errCredentialMissing = errors.New("credential file not present") + func tryLoadOAuthCredentials(filePath string) (*auth.Credentials, error) { - filePath = expandHome(filePath) - if _, err := os.Stat(filePath); err != nil { + expanded := expandHome(filePath) + if _, err := os.Stat(expanded); err != nil { + if os.IsNotExist(err) { + return nil, errCredentialMissing + } + slog.Warn("google oauth: stat failed", "path", expanded, "err", err) return nil, err } - data, err := os.ReadFile(filePath) + data, err := os.ReadFile(expanded) if err != nil { + slog.Warn("google oauth: read failed", "path", expanded, "err", err) return nil, err } var creds oauthCreds if err := json.Unmarshal(data, &creds); err != nil { + slog.Warn("google oauth: parse failed", "path", expanded, "err", err) return nil, err } tokVal := creds.Token() if tokVal == "" { - return nil, fmt.Errorf("empty access token") + slog.Warn("google oauth: empty access token", "path", expanded) + return nil, fmt.Errorf("empty access token in %s", expanded) } expiry := creds.Expiry() if !expiry.IsZero() && time.Now().After(expiry) { - return nil, fmt.Errorf("token expired") + slog.Warn("google oauth: token expired", "path", expanded, "expired_at", expiry) + return nil, fmt.Errorf("token in %s expired at %s", expanded, expiry.Format(time.RFC3339)) } - tp := &fileTokenProvider{filePath: filePath} + tp := &fileTokenProvider{filePath: expanded} return auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: tp, }), nil } +// CredentialSource labels the origin of the auth credential returned by +// selectOAuthCredentials. Used by tests and diagnostics. +type CredentialSource string + +const ( + CredentialSourceNone CredentialSource = "" + CredentialSourceAgy CredentialSource = "agy" + CredentialSourceGemini CredentialSource = "gemini" + CredentialSourceADC CredentialSource = "adc" +) + +// agyCredentialPaths lists the OAuth credential file locations that the +// agy / antigravity CLIs are known to write to. First match wins. +var agyCredentialPaths = []string{ + "~/.config/google-antigravity/session.json", + "~/.config/google-antigravity/oauth_creds.json", + "~/.config/antigravity/session.json", + "~/.config/antigravity/oauth_creds.json", + "~/.config/antigravity-cli/session.json", + "~/.config/antigravity-cli/oauth_creds.json", + "~/.gemini/antigravity-cli/oauth_creds.json", +} + +// geminiCredentialPaths lists the locations the official gemini CLI uses. +var geminiCredentialPaths = []string{ + "~/.gemini/oauth_creds.json", + "~/.config/gemini-cli/oauth_creds.json", +} + +// selectOAuthCredentials walks the precedence chain (agy → gemini → ADC) +// and returns the first usable credential plus a tag identifying which +// source it came from. Tests use the tag to verify precedence; the New() +// builder discards it. +func selectOAuthCredentials() (*auth.Credentials, CredentialSource, error) { + for _, path := range agyCredentialPaths { + if c, err := tryLoadOAuthCredentials(path); err == nil { + return c, CredentialSourceAgy, nil + } + } + for _, path := range geminiCredentialPaths { + if c, err := tryLoadOAuthCredentials(path); err == nil { + return c, CredentialSourceGemini, nil + } + } + // Application Default Credentials. DetectDefault REQUIRES scopes — + // passing nil makes the call always error, leaving ADC unreachable. + c, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{cloudPlatformScope}, + }) + if err == nil { + return c, CredentialSourceADC, nil + } + slog.Debug("google adc: DetectDefault failed", "err", err) + return nil, CredentialSourceNone, fmt.Errorf("no google credentials found (tried agy session, gemini session, and ADC)") +} + // New creates a Google GenAI provider from config. func New(cfg provider.ProviderConfig) (provider.Provider, error) { var client *genai.Client @@ -155,50 +244,11 @@ func New(cfg provider.ProviderConfig) (provider.Provider, error) { return nil, fmt.Errorf("google: create client (Gemini API): %w", err) } } else { - // Precedence: agy > gemini > adc - var creds *auth.Credentials - - // 1. Agy credentials - agyPaths := []string{ - "~/.config/google-antigravity/session.json", - "~/.config/google-antigravity/oauth_creds.json", - "~/.config/antigravity/session.json", - "~/.config/antigravity/oauth_creds.json", - "~/.config/antigravity-cli/session.json", - "~/.config/antigravity-cli/oauth_creds.json", - "~/.gemini/antigravity-cli/oauth_creds.json", - } - for _, path := range agyPaths { - if c, err := tryLoadOAuthCredentials(path); err == nil { - creds = c - break - } - } - - // 2. Gemini credentials - if creds == nil { - geminiPaths := []string{ - "~/.gemini/oauth_creds.json", - "~/.config/gemini-cli/oauth_creds.json", - } - for _, path := range geminiPaths { - if c, err := tryLoadOAuthCredentials(path); err == nil { - creds = c - break - } - } - } - - // 3. Application Default Credentials (ADC) - if creds == nil { - if c, err := credentials.DetectDefault(nil); err == nil { - creds = c - } - } - - if creds == nil { - return nil, fmt.Errorf("google: no credentials found (tried agy session, gemini session, and ADC)") + creds, source, selErr := selectOAuthCredentials() + if selErr != nil { + return nil, fmt.Errorf("google: %w", selErr) } + slog.Debug("google auth: credential selected", "source", source) // Resolve Project ID var projectID string diff --git a/internal/provider/google/provider_test.go b/internal/provider/google/provider_test.go index 26fa43d..acd14d2 100644 --- a/internal/provider/google/provider_test.go +++ b/internal/provider/google/provider_test.go @@ -5,10 +5,13 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "testing" "time" - "somegit.dev/Owlibou/gnoma/internal/provider" + "cloud.google.com/go/auth" + + _ "somegit.dev/Owlibou/gnoma/internal/provider" ) func TestTryLoadOAuthCredentials_Formats(t *testing.T) { @@ -97,30 +100,15 @@ func TestTryLoadOAuthCredentials_Formats(t *testing.T) { } } -func TestNew_Precedence(t *testing.T) { - // We will override the HOME env var in the test to control the expanded path. - origHome := os.Getenv("HOME") - defer func() { - if err := os.Setenv("HOME", origHome); err != nil { - t.Errorf("failed to restore HOME env var: %v", err) - } - }() +func TestSelectOAuthCredentials_Precedence(t *testing.T) { + // Override HOME so expandHome() resolves into a sandbox dir. + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) - tmpHome, err := os.MkdirTemp("", "gnoma-home-test-*") - if err != nil { - t.Fatalf("failed to create temp home dir: %v", err) - } - defer os.RemoveAll(tmpHome) - - if err := os.Setenv("HOME", tmpHome); err != nil { - t.Fatalf("failed to set HOME env var: %v", err) - } - - // Helper to write a mock credentials file writeCreds := func(relPath, tokenVal string) { absPath := filepath.Join(tmpHome, relPath) if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil { - t.Fatalf("failed to create dir: %v", err) + t.Fatalf("mkdir: %v", err) } data := oauthCreds{ AccessToken: tokenVal, @@ -128,50 +116,117 @@ func TestNew_Precedence(t *testing.T) { } bz, err := json.Marshal(data) if err != nil { - t.Fatalf("failed to marshal: %v", err) + t.Fatal(err) } - if err := os.WriteFile(absPath, bz, 0644); err != nil { - t.Fatalf("failed to write file: %v", err) + if err := os.WriteFile(absPath, bz, 0600); err != nil { + t.Fatalf("write: %v", err) } } - // 1. Setup both agy and gemini. agy should take precedence. - // We use the first path of agyPaths: "~/.config/google-antigravity/session.json" - // and geminiPaths: "~/.gemini/oauth_creds.json" - writeCreds(filepath.Join(".config", "google-antigravity", "session.json"), "token-agy") - writeCreds(filepath.Join(".gemini", "oauth_creds.json"), "token-gemini") - - cfg := provider.ProviderConfig{ - Options: map[string]interface{}{ - "project": "test-project-123", - "location": "us-central1", - }, + tokenOf := func(c *auth.Credentials) string { + t.Helper() + tok, err := c.Token(context.Background()) + if err != nil { + t.Fatalf("Token: %v", err) + } + return tok.Value } - p, err := New(cfg) - if err != nil { - t.Fatalf("New() with both creds failed: %v", err) - } + t.Run("agy beats gemini when both present", func(t *testing.T) { + // Fresh sandbox per subtest to avoid leftover files. + sub := t.TempDir() + t.Setenv("HOME", sub) + // Use the first agy path and the first gemini path. + writeAt := func(rel, tok string) { + abs := filepath.Join(sub, rel) + if err := os.MkdirAll(filepath.Dir(abs), 0755); err != nil { + t.Fatal(err) + } + bz, _ := json.Marshal(oauthCreds{ + AccessToken: tok, + ExpiryDate: time.Now().Add(time.Hour).Unix(), + }) + if err := os.WriteFile(abs, bz, 0600); err != nil { + t.Fatal(err) + } + } + writeAt(filepath.Join(".config", "google-antigravity", "session.json"), "token-agy") + writeAt(filepath.Join(".gemini", "oauth_creds.json"), "token-gemini") - googleProv, ok := p.(*Provider) - if !ok { - t.Fatalf("expected *Provider, got %T", p) - } + creds, source, err := selectOAuthCredentials() + if err != nil { + t.Fatalf("selectOAuthCredentials: %v", err) + } + if source != CredentialSourceAgy { + t.Errorf("source = %q, want %q", source, CredentialSourceAgy) + } + if got := tokenOf(creds); got != "token-agy" { + t.Errorf("loaded token = %q, want token-agy (agy precedence violated)", got) + } + }) - // Use googleProv's client to check the configured token (by calling Credentials.Token) - // We can't access client.Credentials directly as it might be unexported/not exposed, but we can verify the client config or test credentials directly. - // Actually, we can just test the tryLoadOAuthCredentials lookup logic or call New and check errors. - // Let's verify we get no error. - _ = googleProv + t.Run("falls back to gemini when agy missing", func(t *testing.T) { + sub := t.TempDir() + t.Setenv("HOME", sub) + // Only gemini file present. + geminiPath := filepath.Join(sub, ".gemini", "oauth_creds.json") + if err := os.MkdirAll(filepath.Dir(geminiPath), 0755); err != nil { + t.Fatal(err) + } + bz, _ := json.Marshal(oauthCreds{ + AccessToken: "token-gemini-only", + ExpiryDate: time.Now().Add(time.Hour).Unix(), + }) + if err := os.WriteFile(geminiPath, bz, 0600); err != nil { + t.Fatal(err) + } - // 2. Now delete agy and keep only gemini. - if err := os.Remove(filepath.Join(tmpHome, ".config", "google-antigravity", "session.json")); err != nil { - t.Fatalf("failed to remove agy config: %v", err) - } + creds, source, err := selectOAuthCredentials() + if err != nil { + t.Fatalf("selectOAuthCredentials: %v", err) + } + if source != CredentialSourceGemini { + t.Errorf("source = %q, want %q", source, CredentialSourceGemini) + } + if got := tokenOf(creds); got != "token-gemini-only" { + t.Errorf("loaded token = %q, want token-gemini-only", got) + } + }) - p2, err := New(cfg) - if err != nil { - t.Fatalf("New() with gemini creds failed: %v", err) + t.Run("missing files are not warning-worthy", func(t *testing.T) { + // Sanity check: empty home directory walks the chain without + // failing in unexpected ways (only ADC would remain, which we + // don't assert on here because the test host may or may not have + // gcloud configured). + sub := t.TempDir() + t.Setenv("HOME", sub) + _, _, err := selectOAuthCredentials() + // Either ADC works on this host (no error) or no creds anywhere + // (returns our specific "no google credentials" error). Both are + // fine; the point is we don't panic or report a misconfiguration. + if err != nil && !strings.Contains(err.Error(), "no google credentials") { + t.Errorf("unexpected error shape: %v", err) + } + }) + _ = writeCreds // keep helper available if extended in future +} + +func TestFileTokenProvider_RejectsExpired(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "creds.json") + bz, _ := json.Marshal(oauthCreds{ + AccessToken: "stale", + ExpiryDate: time.Now().Add(-time.Hour).Unix(), + }) + if err := os.WriteFile(path, bz, 0600); err != nil { + t.Fatal(err) } - _ = p2 + tp := &fileTokenProvider{filePath: path} + tok, err := tp.Token(context.Background()) + if err == nil { + t.Errorf("expected error for expired token, got token %+v", tok) + } + if err != nil && !strings.Contains(err.Error(), "expired") { + t.Errorf("error %q should mention expiry", err) + } }