Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 38 additions & 42 deletions cmd/src/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,43 @@ Examples:
return err
}

var loginEndpointURL *url.URL
if cfg.configFilePath != "" {
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", cfg.configFilePath)
}

if flagSet.NArg() >= 1 {
arg := flagSet.Arg(0)
u, err := parseEndpoint(arg)
loginEndpointURL, err := parseEndpoint(arg)
if err != nil {
return cmderrors.Usage(fmt.Sprintf("invalid endpoint URL: %s", arg))
}
loginEndpointURL = u

hasEndpointURLConflict := cfg.endpointURL.String() != loginEndpointURL.String()

if hasEndpointURLConflict {
// If the default is configured it means SRC_ENDPOINT is not set
if cfg.usingDefaultEndpoint {
fmt.Fprintf(os.Stderr, "⚠️ Warning: No SRC_ENDPOINT is configured in the environment. Logging in using %q.\n", loginEndpointURL)
fmt.Fprintf(os.Stderr, "\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=%s\n\nNOTE: By default src will use %q if SRC_ENDPOINT is not set.\n", loginEndpointURL, SGDotComEndpoint)
} else {
fmt.Fprintf(os.Stderr, "⚠️ Warning: Logging into %s instead of the configured endpoint %s.\n", loginEndpointURL, cfg.endpointURL)
fmt.Fprintf(os.Stderr, "\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=%s\n\n", loginEndpointURL)
}
}

// An explicit endpoint on the CLI overrides the configured endpoint for this login.
cfg.endpointURL = loginEndpointURL
}

client := cfg.apiClient(apiFlags, io.Discard)

return loginCmd(context.Background(), loginParams{
cfg: cfg,
client: client,
out: os.Stdout,
apiFlags: apiFlags,
oauthClient: oauth.NewClient(oauth.DefaultClientID),
loginEndpointURL: loginEndpointURL,
cfg: cfg,
client: client,
out: os.Stdout,
apiFlags: apiFlags,
oauthClient: oauth.NewClient(oauth.DefaultClientID),
})
}

Expand All @@ -80,56 +98,34 @@ Examples:
}

type loginParams struct {
cfg *config
client api.Client
out io.Writer
apiFlags *api.Flags
oauthClient oauth.Client
loginEndpointURL *url.URL
cfg *config
client api.Client
out io.Writer
apiFlags *api.Flags
oauthClient oauth.Client
}

type loginFlow func(context.Context, loginParams) error

type loginFlowKind int

const (
loginFlowOAuth loginFlowKind = iota
loginFlowMissingAuth
loginFlowEndpointConflict
loginFlowValidate
)

func loginCmd(ctx context.Context, p loginParams) error {
if err := p.cfg.requireCIAccessToken(); err != nil {
return err
}

if p.cfg.configFilePath != "" {
fmt.Fprintln(p.out)
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.configFilePath)
}

_, flow := selectLoginFlow(p)
flow := selectLoginFlow(p)
if err := flow(ctx, p); err != nil {
return err
}
fmt.Fprintf(p.out, "\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=%s\n\n", p.cfg.endpointURL)
return nil
}

// selectLoginFlow decides what login flow to run based on configured AuthMode.
func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) {
if p.loginEndpointURL != nil && p.loginEndpointURL.String() != p.cfg.endpointURL.String() {
return loginFlowEndpointConflict, runEndpointConflictLogin
}
switch p.cfg.AuthMode() {
case AuthModeOAuth:
return loginFlowOAuth, runOAuthLogin
case AuthModeAccessToken:
return loginFlowValidate, runValidatedLogin
default:
return loginFlowMissingAuth, runMissingAuthLogin

func selectLoginFlow(p loginParams) loginFlow {
if p.cfg.AuthMode() == AuthModeAccessToken {
return runValidatedLogin
}
return runOAuthLogin
}

func printLoginProblem(out io.Writer, problem string) {
Expand Down
80 changes: 11 additions & 69 deletions cmd/src/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,22 @@ func mustParseURL(t *testing.T, raw string) *url.URL {
}

func TestLogin(t *testing.T) {
check := func(t *testing.T, cfg *config, endpointArgURL *url.URL) (output string, err error) {
check := func(t *testing.T, cfg *config) (output string, err error) {
t.Helper()

var out bytes.Buffer
err = loginCmd(context.Background(), loginParams{
cfg: cfg,
client: cfg.apiClient(nil, io.Discard),
out: &out,
oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")},
loginEndpointURL: endpointArgURL,
cfg: cfg,
client: cfg.apiClient(nil, io.Discard),
out: &out,
oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")},
})
return strings.TrimSpace(out.String()), err
}

t.Run("different endpoint in config vs. arg", func(t *testing.T) {
out, err := check(t, &config{endpointURL: &url.URL{Scheme: "https", Host: "example.com"}}, &url.URL{Scheme: "https", Host: "sourcegraph.example.com"})
if err == nil {
t.Fatal(err)
}
if !strings.Contains(out, "The configured endpoint is https://example.com, not https://sourcegraph.example.com.") {
t.Errorf("got output %q, want configured endpoint error", out)
}
})

t.Run("no access token triggers oauth flow", func(t *testing.T) {
u := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: u}, u)
out, err := check(t, &config{endpointURL: u})
if err == nil {
t.Fatal(err)
}
Expand All @@ -63,7 +52,7 @@ func TestLogin(t *testing.T) {

t.Run("CI requires access token", func(t *testing.T) {
u := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: u, inCI: true}, u)
out, err := check(t, &config{endpointURL: u, inCI: true})
if err != errCIAccessTokenRequired {
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
}
Expand All @@ -72,28 +61,14 @@ func TestLogin(t *testing.T) {
}
})

t.Run("warning when using config file", func(t *testing.T) {
endpoint := &url.URL{Scheme: "https", Host: "example.com"}
out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint)
if err != cmderrors.ExitCode1 {
t.Fatal(err)
}
if !strings.Contains(out, "Configuring src with a JSON file is deprecated") {
t.Errorf("got output %q, want deprecation warning", out)
}
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
t.Errorf("got output %q, want oauth failure output", out)
}
})

t.Run("invalid access token", func(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusUnauthorized)
}))
defer s.Close()

u := mustParseURL(t, s.URL)
out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u)
out, err := check(t, &config{endpointURL: u, accessToken: "x"})
if err != cmderrors.ExitCode1 {
t.Fatal(err)
}
Expand All @@ -111,11 +86,11 @@ func TestLogin(t *testing.T) {
defer s.Close()

u := mustParseURL(t, s.URL)
out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u)
out, err := check(t, &config{endpointURL: u, accessToken: "x"})
if err != nil {
t.Fatal(err)
}
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=$ENDPOINT"
wantOut := "✔︎ Authenticated as alice on $ENDPOINT"
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
if out != wantOut {
t.Errorf("got output %q, want %q", out, wantOut)
Expand Down Expand Up @@ -156,7 +131,7 @@ func TestLogin(t *testing.T) {
t.Fatal("expected stored oauth token to avoid device flow")
}
gotOut := strings.TrimSpace(out.String())
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials\n\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=$ENDPOINT"
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials"
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
if gotOut != wantOut {
t.Errorf("got output %q, want %q", gotOut, wantOut)
Expand Down Expand Up @@ -192,39 +167,6 @@ func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenRes
return nil, fmt.Errorf("unexpected call to Refresh")
}

func TestSelectLoginFlow(t *testing.T) {
t.Run("uses oauth flow when no access token is configured", func(t *testing.T) {
params := loginParams{
cfg: &config{endpointURL: mustParseURL(t, "https://example.com")},
}

if got, _ := selectLoginFlow(params); got != loginFlowOAuth {
t.Fatalf("flow = %v, want %v", got, loginFlowOAuth)
}
})

t.Run("uses endpoint conflict flow when auth exists for a different endpoint", func(t *testing.T) {
params := loginParams{
cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"},
loginEndpointURL: mustParseURL(t, "https://sourcegraph.example.com"),
}

if got, _ := selectLoginFlow(params); got != loginFlowEndpointConflict {
t.Fatalf("flow = %v, want %v", got, loginFlowEndpointConflict)
}
})

t.Run("uses validation flow when auth exists for the selected endpoint", func(t *testing.T) {
params := loginParams{
cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"},
}

if got, _ := selectLoginFlow(params); got != loginFlowValidate {
t.Fatalf("flow = %v, want %v", got, loginFlowValidate)
}
})
}

func TestValidateBrowserURL(t *testing.T) {
tests := []struct {
name string
Expand Down
14 changes: 0 additions & 14 deletions cmd/src/login_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,6 @@ import (
"github.com/sourcegraph/src-cli/internal/cmderrors"
)

func runMissingAuthLogin(_ context.Context, p loginParams) error {
fmt.Fprintln(p.out)
printLoginProblem(p.out, "No access token is configured.")
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
return cmderrors.ExitCode1
}

func runEndpointConflictLogin(_ context.Context, p loginParams) error {
fmt.Fprintln(p.out)
printLoginProblem(p.out, fmt.Sprintf("The configured endpoint is %s, not %s.", p.cfg.endpointURL, p.loginEndpointURL))
fmt.Fprintln(p.out, loginAccessTokenMessage(p.loginEndpointURL))
return cmderrors.ExitCode1
}

func runValidatedLogin(ctx context.Context, p loginParams) error {
return validateCurrentUser(ctx, p.client, p.out, p.cfg.endpointURL)
}
Expand Down
20 changes: 12 additions & 8 deletions cmd/src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"github.com/sourcegraph/src-cli/internal/oauth"
)

const SGDotComEndpoint = "https://sourcegraph.com"

const usageText = `src is a tool that provides access to Sourcegraph instances.
For more information, see https://github.com/sourcegraph/src-cli

Expand Down Expand Up @@ -141,13 +143,14 @@ var cfg *config

// config holds the resolved configuration used at runtime.
type config struct {
accessToken string
additionalHeaders map[string]string
proxyURL *url.URL
proxyPath string
configFilePath string
endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig
inCI bool
accessToken string
additionalHeaders map[string]string
proxyURL *url.URL
proxyPath string
configFilePath string
endpointURL *url.URL // always non-nil; defaults to https://sourcegraph.com via readConfig
usingDefaultEndpoint bool
inCI bool
}

// configFromFile holds the config as read from the config file,
Expand Down Expand Up @@ -270,7 +273,8 @@ func readConfig() (*config, error) {
endpointStr = envEndpoint
}
if endpointStr == "" {
endpointStr = "https://sourcegraph.com"
endpointStr = SGDotComEndpoint
cfg.usingDefaultEndpoint = true
}
if envProxy != "" {
proxyStr = envProxy
Expand Down
Loading
Loading