diff --git a/cmd/src/login.go b/cmd/src/login.go index 9818c245b6..889994af51 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -50,25 +50,43 @@ Examples: return err } - var loginEndpointURL *url.URL + if cfg.configFilePath != "" { + fmt.Fprintln(os.Stderr) + fmt.Fprintf(os.Stderr, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", cfg.configFilePath) + } + if flagSet.NArg() >= 1 { arg := flagSet.Arg(0) - u, err := parseEndpoint(arg) + loginEndpointURL, err := parseEndpoint(arg) if err != nil { return cmderrors.Usage(fmt.Sprintf("invalid endpoint URL: %s", arg)) } - loginEndpointURL = u + + hasEndpointURLConflict := cfg.endpointURL.String() != loginEndpointURL.String() + + if hasEndpointURLConflict { + // If the default is configured it means SRC_ENDPOINT is not set + if cfg.usingDefaultEndpoint { + fmt.Fprintf(os.Stderr, "⚠️ Warning: No SRC_ENDPOINT is configured in the environment. Logging in using %q.\n", loginEndpointURL) + fmt.Fprintf(os.Stderr, "\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=%s\n\nNOTE: By default src will use %q if SRC_ENDPOINT is not set.\n", loginEndpointURL, SGDotComEndpoint) + } else { + fmt.Fprintf(os.Stderr, "⚠️ Warning: Logging into %s instead of the configured endpoint %s.\n", loginEndpointURL, cfg.endpointURL) + fmt.Fprintf(os.Stderr, "\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=%s\n\n", loginEndpointURL) + } + } + + // An explicit endpoint on the CLI overrides the configured endpoint for this login. + cfg.endpointURL = loginEndpointURL } client := cfg.apiClient(apiFlags, io.Discard) return loginCmd(context.Background(), loginParams{ - cfg: cfg, - client: client, - out: os.Stdout, - apiFlags: apiFlags, - oauthClient: oauth.NewClient(oauth.DefaultClientID), - loginEndpointURL: loginEndpointURL, + cfg: cfg, + client: client, + out: os.Stdout, + apiFlags: apiFlags, + oauthClient: oauth.NewClient(oauth.DefaultClientID), }) } @@ -80,56 +98,34 @@ Examples: } type loginParams struct { - cfg *config - client api.Client - out io.Writer - apiFlags *api.Flags - oauthClient oauth.Client - loginEndpointURL *url.URL + cfg *config + client api.Client + out io.Writer + apiFlags *api.Flags + oauthClient oauth.Client } type loginFlow func(context.Context, loginParams) error -type loginFlowKind int - -const ( - loginFlowOAuth loginFlowKind = iota - loginFlowMissingAuth - loginFlowEndpointConflict - loginFlowValidate -) - func loginCmd(ctx context.Context, p loginParams) error { if err := p.cfg.requireCIAccessToken(); err != nil { return err } - if p.cfg.configFilePath != "" { - fmt.Fprintln(p.out) - fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath) - } - - _, flow := selectLoginFlow(p) + flow := selectLoginFlow(p) if err := flow(ctx, p); err != nil { return err } - fmt.Fprintf(p.out, "\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=%s\n\n", p.cfg.endpointURL) return nil } // selectLoginFlow decides what login flow to run based on configured AuthMode. -func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) { - if p.loginEndpointURL != nil && p.loginEndpointURL.String() != p.cfg.endpointURL.String() { - return loginFlowEndpointConflict, runEndpointConflictLogin - } - switch p.cfg.AuthMode() { - case AuthModeOAuth: - return loginFlowOAuth, runOAuthLogin - case AuthModeAccessToken: - return loginFlowValidate, runValidatedLogin - default: - return loginFlowMissingAuth, runMissingAuthLogin + +func selectLoginFlow(p loginParams) loginFlow { + if p.cfg.AuthMode() == AuthModeAccessToken { + return runValidatedLogin } + return runOAuthLogin } func printLoginProblem(out io.Writer, problem string) { diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 1250d1adb9..5dba8b464b 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -26,33 +26,22 @@ func mustParseURL(t *testing.T, raw string) *url.URL { } func TestLogin(t *testing.T) { - check := func(t *testing.T, cfg *config, endpointArgURL *url.URL) (output string, err error) { + check := func(t *testing.T, cfg *config) (output string, err error) { t.Helper() var out bytes.Buffer err = loginCmd(context.Background(), loginParams{ - cfg: cfg, - client: cfg.apiClient(nil, io.Discard), - out: &out, - oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")}, - loginEndpointURL: endpointArgURL, + cfg: cfg, + client: cfg.apiClient(nil, io.Discard), + out: &out, + oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")}, }) return strings.TrimSpace(out.String()), err } - t.Run("different endpoint in config vs. arg", func(t *testing.T) { - out, err := check(t, &config{endpointURL: &url.URL{Scheme: "https", Host: "example.com"}}, &url.URL{Scheme: "https", Host: "sourcegraph.example.com"}) - if err == nil { - t.Fatal(err) - } - if !strings.Contains(out, "The configured endpoint is https://example.com, not https://sourcegraph.example.com.") { - t.Errorf("got output %q, want configured endpoint error", out) - } - }) - t.Run("no access token triggers oauth flow", func(t *testing.T) { u := &url.URL{Scheme: "https", Host: "example.com"} - out, err := check(t, &config{endpointURL: u}, u) + out, err := check(t, &config{endpointURL: u}) if err == nil { t.Fatal(err) } @@ -63,7 +52,7 @@ func TestLogin(t *testing.T) { t.Run("CI requires access token", func(t *testing.T) { u := &url.URL{Scheme: "https", Host: "example.com"} - out, err := check(t, &config{endpointURL: u, inCI: true}, u) + out, err := check(t, &config{endpointURL: u, inCI: true}) if err != errCIAccessTokenRequired { t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired) } @@ -72,20 +61,6 @@ func TestLogin(t *testing.T) { } }) - t.Run("warning when using config file", func(t *testing.T) { - endpoint := &url.URL{Scheme: "https", Host: "example.com"} - out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint) - if err != cmderrors.ExitCode1 { - t.Fatal(err) - } - if !strings.Contains(out, "Configuring src with a JSON file is deprecated") { - t.Errorf("got output %q, want deprecation warning", out) - } - if !strings.Contains(out, "OAuth Device flow authentication failed:") { - t.Errorf("got output %q, want oauth failure output", out) - } - }) - t.Run("invalid access token", func(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusUnauthorized) @@ -93,7 +68,7 @@ func TestLogin(t *testing.T) { defer s.Close() u := mustParseURL(t, s.URL) - out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u) + out, err := check(t, &config{endpointURL: u, accessToken: "x"}) if err != cmderrors.ExitCode1 { t.Fatal(err) } @@ -111,11 +86,11 @@ func TestLogin(t *testing.T) { defer s.Close() u := mustParseURL(t, s.URL) - out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u) + out, err := check(t, &config{endpointURL: u, accessToken: "x"}) if err != nil { t.Fatal(err) } - wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=$ENDPOINT" + wantOut := "✔︎ Authenticated as alice on $ENDPOINT" wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL) if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) @@ -156,7 +131,7 @@ func TestLogin(t *testing.T) { t.Fatal("expected stored oauth token to avoid device flow") } gotOut := strings.TrimSpace(out.String()) - wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials\n\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=$ENDPOINT" + wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials" wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL) if gotOut != wantOut { t.Errorf("got output %q, want %q", gotOut, wantOut) @@ -192,39 +167,6 @@ func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenRes return nil, fmt.Errorf("unexpected call to Refresh") } -func TestSelectLoginFlow(t *testing.T) { - t.Run("uses oauth flow when no access token is configured", func(t *testing.T) { - params := loginParams{ - cfg: &config{endpointURL: mustParseURL(t, "https://example.com")}, - } - - if got, _ := selectLoginFlow(params); got != loginFlowOAuth { - t.Fatalf("flow = %v, want %v", got, loginFlowOAuth) - } - }) - - t.Run("uses endpoint conflict flow when auth exists for a different endpoint", func(t *testing.T) { - params := loginParams{ - cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"}, - loginEndpointURL: mustParseURL(t, "https://sourcegraph.example.com"), - } - - if got, _ := selectLoginFlow(params); got != loginFlowEndpointConflict { - t.Fatalf("flow = %v, want %v", got, loginFlowEndpointConflict) - } - }) - - t.Run("uses validation flow when auth exists for the selected endpoint", func(t *testing.T) { - params := loginParams{ - cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"}, - } - - if got, _ := selectLoginFlow(params); got != loginFlowValidate { - t.Fatalf("flow = %v, want %v", got, loginFlowValidate) - } - }) -} - func TestValidateBrowserURL(t *testing.T) { tests := []struct { name string diff --git a/cmd/src/login_validate.go b/cmd/src/login_validate.go index 9aa65cdcef..095ea7ab22 100644 --- a/cmd/src/login_validate.go +++ b/cmd/src/login_validate.go @@ -11,20 +11,6 @@ import ( "github.com/sourcegraph/src-cli/internal/cmderrors" ) -func runMissingAuthLogin(_ context.Context, p loginParams) error { - fmt.Fprintln(p.out) - printLoginProblem(p.out, "No access token is configured.") - fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL)) - return cmderrors.ExitCode1 -} - -func runEndpointConflictLogin(_ context.Context, p loginParams) error { - fmt.Fprintln(p.out) - printLoginProblem(p.out, fmt.Sprintf("The configured endpoint is %s, not %s.", p.cfg.endpointURL, p.loginEndpointURL)) - fmt.Fprintln(p.out, loginAccessTokenMessage(p.loginEndpointURL)) - return cmderrors.ExitCode1 -} - func runValidatedLogin(ctx context.Context, p loginParams) error { return validateCurrentUser(ctx, p.client, p.out, p.cfg.endpointURL) } diff --git a/cmd/src/main.go b/cmd/src/main.go index 68478c9276..fd217ba51f 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -19,6 +19,8 @@ import ( "github.com/sourcegraph/src-cli/internal/oauth" ) +const SGDotComEndpoint = "https://sourcegraph.com" + const usageText = `src is a tool that provides access to Sourcegraph instances. For more information, see https://github.com/sourcegraph/src-cli @@ -141,13 +143,14 @@ var cfg *config // config holds the resolved configuration used at runtime. type config struct { - accessToken string - additionalHeaders map[string]string - proxyURL *url.URL - proxyPath string - configFilePath string - endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig - inCI bool + accessToken string + additionalHeaders map[string]string + proxyURL *url.URL + proxyPath string + configFilePath string + endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig + usingDefaultEndpoint bool + inCI bool } // configFromFile holds the config as read from the config file, @@ -270,7 +273,8 @@ func readConfig() (*config, error) { endpointStr = envEndpoint } if endpointStr == "" { - endpointStr = "https://sourcegraph.com" + endpointStr = SGDotComEndpoint + cfg.usingDefaultEndpoint = true } if envProxy != "" { proxyStr = envProxy diff --git a/cmd/src/main_test.go b/cmd/src/main_test.go index c0b29822b0..0b23cb9938 100644 --- a/cmd/src/main_test.go +++ b/cmd/src/main_test.go @@ -51,7 +51,8 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - additionalHeaders: map[string]string{}, + usingDefaultEndpoint: true, + additionalHeaders: map[string]string{}, }, }, { @@ -149,8 +150,9 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - accessToken: "abc", - additionalHeaders: map[string]string{}, + usingDefaultEndpoint: true, + accessToken: "abc", + additionalHeaders: map[string]string{}, }, }, { @@ -173,8 +175,9 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - accessToken: "", - proxyPath: "", + usingDefaultEndpoint: true, + accessToken: "", + proxyPath: "", proxyURL: &url.URL{ Scheme: "https", Host: "proxy.com:8080", @@ -209,9 +212,10 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - proxyPath: socketPath, - proxyURL: nil, - additionalHeaders: map[string]string{}, + usingDefaultEndpoint: true, + proxyPath: socketPath, + proxyURL: nil, + additionalHeaders: map[string]string{}, }, }, { @@ -222,9 +226,10 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - proxyPath: socketPath, - proxyURL: nil, - additionalHeaders: map[string]string{}, + usingDefaultEndpoint: true, + proxyPath: socketPath, + proxyURL: nil, + additionalHeaders: map[string]string{}, }, }, { @@ -235,7 +240,8 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - proxyPath: "", + usingDefaultEndpoint: true, + proxyPath: "", proxyURL: &url.URL{ Scheme: "socks5", Host: "localhost:1080", @@ -251,7 +257,8 @@ func TestReadConfig(t *testing.T) { Scheme: "https", Host: "sourcegraph.com", }, - proxyPath: "", + usingDefaultEndpoint: true, + proxyPath: "", proxyURL: &url.URL{ Scheme: "socks5h", Host: "localhost:1080", @@ -331,9 +338,10 @@ func TestReadConfig(t *testing.T) { name: "CI does not require access token during config read", envCI: "1", want: &config{ - endpointURL: &url.URL{Scheme: "https", Host: "sourcegraph.com"}, - additionalHeaders: map[string]string{}, - inCI: true, + endpointURL: &url.URL{Scheme: "https", Host: "sourcegraph.com"}, + usingDefaultEndpoint: true, + additionalHeaders: map[string]string{}, + inCI: true, }, }, {