Skip to content
This repository has been archived by the owner on Jan 24, 2019. It is now read-only.

Commit

Permalink
Github provider: use login as user
Browse files Browse the repository at this point in the history
- Save both user and email in session state:
    Encoding/decoding methods save both email and user
    field in session state, for use cases when User is not derived from
    email's local-parth, like for GitHub provider.

    For retrocompatibility, if no user is obtained by the provider,
    (e.g. User is an empty string) the encoding/decoding methods fall back
    to the previous behavior and use the email's local-part

    Updated also related tests and added two more tests to show behavior
    when session contains a non-empty user value.

- Added first basic GitHub provider tests

- Added GetUserName method to Provider interface
    The new GetUserName method is intended to return the User
    value when this is not the email's local-part.

    Added also the default implementation to provider_default.go

- Added call to GetUserName in redeemCode

    the new GetUserName method is used in redeemCode
    to get SessionState User value.

    For backward compatibility, if GetUserName error is
    "not implemented", the error is ignored.

- Added GetUserName method and tests to github provider.
  • Loading branch information
clobrano committed Nov 20, 2017
1 parent 6ddbb2c commit 731fa9f
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 43 deletions.
7 changes: 7 additions & 0 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s)
}

if s.User == "" {
s.User, err = p.provider.GetUserName(s)
if err != nil && err.Error() == "not implemented" {
err = nil
}
}
return
}

Expand Down
47 changes: 45 additions & 2 deletions providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
} else {
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
}

log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)

if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
Expand All @@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {

return "", nil
}

func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
var user struct {
Login string `json:"login"`
Email string `json:"email"`
}

endpoint := &url.URL{
Scheme: p.ValidateURL.Scheme,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
}

req, err := http.NewRequest("GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}

req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}

body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}

if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}

log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)

if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}

return user.Login, nil
}
146 changes: 146 additions & 0 deletions providers/github_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package providers

import (
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func testGitHubProvider(hostname string) *GitHubProvider {
p := NewGitHubProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""})
if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
}

func testGitHubBackend(payload string) *httptest.Server {
pathToQueryMap := map[string]string{
"/user": "",
"/user/emails": "",
}

return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
query, ok := pathToQueryMap[url.Path]
if !ok {
w.WriteHeader(404)
} else if url.RawQuery != query {
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}

func TestGitHubProviderDefaults(t *testing.T) {
p := testGitHubProvider("")
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://github.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://github.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.github.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "user:email", p.Data().Scope)
}

func TestGitHubProviderOverrides(t *testing.T) {
p := NewGitHubProvider(
&ProviderData{
LoginURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/authorize"},
RedeemURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/access_token"},
ValidateURL: &url.URL{
Scheme: "https",
Host: "api.example.com",
Path: "/"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://example.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.example.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}

func TestGitHubProviderGetEmailAddress(t *testing.T) {
b := testGitHubBackend(`[ {"email": "[email protected]", "primary": true} ]`)
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", email)
}

// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitHubBackend("unused payload")
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend("{\"foo\": \"bar\"}")
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}

func TestGitHubProviderGetUserName(t *testing.T) {
b := testGitHubBackend(`{"email": "[email protected]", "login": "mbland"}`)
defer b.Close()

bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)

session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session)
assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email)
}
5 changes: 5 additions & 0 deletions providers/provider_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}

// GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}

// ValidateGroup validates that the provided email exists in the configured provider
// email group(s).
func (p *ProviderData) ValidateGroup(email string) bool {
Expand Down
1 change: 1 addition & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
type Provider interface {
Data() *ProviderData
GetEmailAddress(*SessionState) (string, error)
GetUserName(*SessionState) (string, error)
Redeem(string, string) (*SessionState, error)
ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool
Expand Down
76 changes: 40 additions & 36 deletions providers/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool {
}

func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.userOrEmail())
o := fmt.Sprintf("Session{%s", s.accountInfo())
if s.AccessToken != "" {
o += " token:true"
}
Expand All @@ -40,17 +40,13 @@ func (s *SessionState) String() string {

func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" {
return s.userOrEmail(), nil
return s.accountInfo(), nil
}
return s.EncryptedString(c)
}

func (s *SessionState) userOrEmail() string {
u := s.User
if s.Email != "" {
u = s.Email
}
return u
func (s *SessionState) accountInfo() string {
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
}

func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
Expand All @@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
}
a := s.AccessToken
if a != "" {
a, err = c.Encrypt(a)
if err != nil {
if a, err = c.Encrypt(a); err != nil {
return "", err
}
}
r := s.RefreshToken
if r != "" {
r, err = c.Encrypt(r)
if err != nil {
if r, err = c.Encrypt(r); err != nil {
return "", err
}
}
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
}

func decodeSessionStatePlain(v string) (s *SessionState, err error) {
chunks := strings.Split(v, " ")
if len(chunks) != 2 {
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
}

email := strings.TrimPrefix(chunks[0], "email:")
user := strings.TrimPrefix(chunks[1], "user:")
if user == "" {
user = strings.Split(email, "@")[0]
}

return &SessionState{User: user, Email: email}, nil
}

func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
chunks := strings.Split(v, "|")
if len(chunks) == 1 {
if strings.Contains(chunks[0], "@") {
u := strings.Split(v, "@")[0]
return &SessionState{Email: v, User: u}, nil
}
return &SessionState{User: v}, nil
if c == nil {
return decodeSessionStatePlain(v)
}

chunks := strings.Split(v, "|")
if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
return
}

s = &SessionState{}
if c != nil && chunks[1] != "" {
s.AccessToken, err = c.Decrypt(chunks[1])
if err != nil {
sessionState, err := decodeSessionStatePlain(chunks[0])
if err != nil {
return nil, err
}

if chunks[1] != "" {
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
return nil, err
}
}
if c != nil && chunks[3] != "" {
s.RefreshToken, err = c.Decrypt(chunks[3])
if err != nil {

ts, _ := strconv.Atoi(chunks[2])
sessionState.ExpiresOn = time.Unix(int64(ts), 0)

if chunks[3] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
return nil, err
}
}
if u := chunks[0]; strings.Contains(u, "@") {
s.Email = u
s.User = strings.Split(u, "@")[0]
} else {
s.User = u
}
ts, _ := strconv.Atoi(chunks[2])
s.ExpiresOn = time.Unix(int64(ts), 0)
return

return sessionState, nil
}
Loading

0 comments on commit 731fa9f

Please sign in to comment.