diff --git a/cmd/simpleauth/main.go b/cmd/simpleauth/main.go index f037ceb..e9c5070 100644 --- a/cmd/simpleauth/main.go +++ b/cmd/simpleauth/main.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "net/http" + "net/url" "os" "path" "strings" @@ -23,6 +24,19 @@ var secret []byte = make([]byte, 256) var lifespan time.Duration var cryptedPasswords map[string]string var loginHtml []byte +var verbose bool + +func debugln(v ...any) { + if verbose { + log.Println(v...) + } +} + +func debugf(fmt string, v ...any) { + if verbose { + log.Printf(fmt, v...) + } +} func authenticationValid(username, password string) bool { c := crypt.SHA256.New() @@ -35,19 +49,32 @@ func authenticationValid(username, password string) bool { } func usernameIfAuthenticated(req *http.Request) string { - if cookie, err := req.Cookie(CookieName); err == nil { - t, _ := token.ParseString(cookie.Value) - if t.Valid(secret) { - return t.Username - } - } - - authUsername, authPassword, ok := req.BasicAuth() - if ok { - if authenticationValid(authUsername, authPassword) { + if authUsername, authPassword, ok := req.BasicAuth(); ok { + valid := authenticationValid(authUsername, authPassword) + debugf("basic auth valid:%v username:%v", valid, authUsername) + if valid { return authUsername } + } else { + debugf("no basic auth") + } + + ncookies := 0 + for i, cookie := range req.Cookies() { + if cookie.Name != CookieName { + continue + } + t, _ := token.ParseString(cookie.Value) + valid := t.Valid(secret) + debugf("cookie %d valid:%v username:%v", i, valid, t.Username) + if valid { + return t.Username + } + ncookies += 1 } + if ncookies == 0 { + debugf("no cookies") + } return "" } @@ -79,12 +106,23 @@ func rootHandler(w http.ResponseWriter, req *http.Request) { // which needs these headers to set the cookie and try again. } - // Log the request clientIP := req.Header.Get("X-Real-IP") if clientIP == "" { clientIP = req.RemoteAddr } - log.Println(clientIP, req.Method, req.URL, status, username) + forwardedMethod := req.Header.Get("X-Forwarded-Method") + forwardedURL := url.URL{ + Scheme: req.Header.Get("X-Forwarded-Proto"), + Host: req.Header.Get("X-Forwarded-Host"), + Path: req.Header.Get("X-Forwarded-Uri"), + User: url.UserPassword(username, ""), + } + + // Log the request + log.Printf("%s %s %s login:%v %s", + clientIP, forwardedMethod, forwardedURL.String(), + login, status, + ) w.Header().Set("Content-Type", "text/html") w.Header().Set("X-Simpleauth-Authentication", status) @@ -124,6 +162,12 @@ func main() { "web", "Path to HTML files", ) + flag.BoolVar( + &verbose, + "verbose", + false, + "Print verbose logs, for debugging", + ) flag.Parse() cryptedPasswords = make(map[string]string, 10) diff --git a/go.mod b/go.mod index 0d45ece..af89ba0 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,13 @@ module git.woozle.org/neale/simpleauth -go 1.13 +go 1.18 require ( github.com/GehirnInc/crypt v0.0.0-20200316065508-bb7000b8a962 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( github.com/kr/pretty v0.3.1 // indirect github.com/stretchr/testify v1.8.1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/pkg/acl/acl.go b/pkg/acl/acl.go new file mode 100644 index 0000000..273d013 --- /dev/null +++ b/pkg/acl/acl.go @@ -0,0 +1,46 @@ +package acl + +import ( + "io" + "log" + "net/http" + + "gopkg.in/yaml.v3" +) + +type ACL struct { + Rules []Rule +} + +func Read(r io.Reader) (*ACL, error) { + acl := ACL{} + ydec := yaml.NewDecoder(r) + if err := ydec.Decode(&acl); err != nil { + return nil, err + } + if err := acl.CompileURLs(); err != nil { + return nil, err + } + return &acl, nil +} + +// CompileURLs compiles regular expressions for all URLs. +func (acl *ACL) CompileURLs() error { + for i := range acl.Rules { + rule := &acl.Rules[i] + if err := rule.CompileURL(); err != nil { + return err + } + } + return nil +} + +func (acl *ACL) Match(req *http.Request) Action { + for _, rule := range acl.Rules { + log.Println(rule) + if rule.Match(req) { + return rule.Action + } + } + return Deny +} diff --git a/pkg/acl/acl_test.go b/pkg/acl/acl_test.go new file mode 100644 index 0000000..a8db98b --- /dev/null +++ b/pkg/acl/acl_test.go @@ -0,0 +1,89 @@ +package acl + +import ( + "net/http" + "net/url" + "os" + "testing" +) + +type testAcl struct { + t *testing.T + acl *ACL +} + +func readAcl(filename string) (*ACL, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + acl, err := Read(f) + if err != nil { + return nil, err + } + return acl, nil +} + +func (ta *testAcl) try(method string, URL string, expected Action) { + u, err := url.Parse(URL) + if err != nil { + ta.t.Errorf("Parsing %s: %v", URL, err) + } + req := &http.Request{ + Method: method, + URL: u, + } + action := ta.acl.Match(req) + if action != expected { + ta.t.Errorf("%s %s expected %v but got %v", method, URL, expected, action) + } +} + +func TestRegexen(t *testing.T) { + acl, err := readAcl("testdata/acl.yaml") + if err != nil { + t.Fatal(err) + } + + for i, rule := range acl.Rules { + if rule.urlRegexp == nil { + t.Errorf("Regexp not precompiled on rule %d", i) + } + } +} + +func TestUsers(t *testing.T) { + acl, err := readAcl("testdata/acl.yaml") + if err != nil { + t.Fatal(err) + } + + if acl.Rules[0].Users != nil { + t.Errorf("Rules[0].Users != nil") + } + if acl.Rules[1].Users == nil { + t.Errorf("Rules[0].Users == nil") + } +} + +func TestAclMatching(t *testing.T) { + acl, err := readAcl("testdata/acl.yaml") + if err != nil { + t.Fatal(err) + } + ta := testAcl{ + t: t, + acl: acl, + } + + ta.try("GET", "https://example.com/moo", Deny) + ta.try("GET", "https://example.com/blargh", Deny) + ta.try("GET", "https://example.com/public/moo", Public) + ta.try("BLARGH", "https://example.com/blargh", Public) + ta.try("GET", "https://example.com/only-alice/boog", Deny) + ta.try("GET", "https://alice:@example.com/only-alice/boog", Auth) + ta.try("GET", "https://alice:@example.com/bob/", Deny) + ta.try("GET", "https://bob:@example.com/bob/", Auth) +} diff --git a/pkg/acl/action.go b/pkg/acl/action.go new file mode 100644 index 0000000..5c86a4b --- /dev/null +++ b/pkg/acl/action.go @@ -0,0 +1,40 @@ +package acl + +import ( + "fmt" + "strings" +) + +type Action int + +const ( + Deny Action = iota + Auth + Public +) + +var actions = [...]string{ + "deny", + "auth", + "public", +} + +func (a Action) String() string { + return actions[a] +} + +func (a *Action) UnmarshalYAML(unmarshal func(any) error) error { + var val string + if err := unmarshal(&val); err != nil { + return err + } + val = strings.ToLower(val) + + for i, s := range actions { + if val == s { + *a = (Action)(i) + return nil + } + } + return fmt.Errorf("Unknown action type: %s", val) +} diff --git a/pkg/acl/action_test.go b/pkg/acl/action_test.go new file mode 100644 index 0000000..420e9d3 --- /dev/null +++ b/pkg/acl/action_test.go @@ -0,0 +1,38 @@ +package acl + +import ( + "testing" + + "gopkg.in/yaml.v3" +) + +func TestActions(t *testing.T) { + if Deny.String() != "deny" { + t.Error("Deny string wrong") + } + if Auth.String() != "auth" { + t.Error("Auth string wrong") + } + if Public.String() != "public" { + t.Error("Public string wrong") + } +} + +func TestYamlActions(t *testing.T) { + var out []Action + yamlDoc := "[Deny, Auth, Public, dEnY, pUBLiC]" + expected := []Action{Deny, Auth, Public, Deny, Public} + if err := yaml.Unmarshal([]byte(yamlDoc), &out); err != nil { + t.Fatal(err) + } + + if len(out) != len(expected) { + t.Error("Wrong length of unmarshalled yaml") + } + + for i, a := range out { + if expected[i] != a { + t.Errorf("Wrong value at position %d. Wanted %v, got %v", i, expected[i], a) + } + } +} diff --git a/pkg/acl/rule.go b/pkg/acl/rule.go new file mode 100644 index 0000000..607adba --- /dev/null +++ b/pkg/acl/rule.go @@ -0,0 +1,73 @@ +package acl + +import ( + "net/http" + "net/url" + "regexp" +) + +type Rule struct { + URL string + urlRegexp *regexp.Regexp + Users []string + Methods []string + Action Action +} + +// CompileURL compiles regular expressions for the URL. +// This is an startup optimization that speeds up rule processing. +func (r *Rule) CompileURL() error { + if re, err := regexp.Compile(r.URL); err != nil { + return err + } else { + r.urlRegexp = re + } + return nil +} + +// Match returns true if req is matched by the rule +func (r *Rule) Match(req *http.Request) bool { + if r.urlRegexp == nil { + // Womp womp. Things will be slow, because the compiled regex won't get cached. + r.CompileURL() + } + requestUser := req.URL.User.Username() + anonURL := url.URL(*req.URL) + anonURL.User = nil + found := r.urlRegexp.FindStringSubmatch(anonURL.String()) + if len(found) == 0 { + return false + } + + // Match any listed method + methodMatch := (len(r.Methods) == 0) + for _, method := range r.Methods { + if method == req.Method { + methodMatch = true + } + } + if !methodMatch { + return false + } + + // If they used (?P), + // make sure that matches the username in the request URL + userIndex := r.urlRegexp.SubexpIndex("user") + if (userIndex != -1) && (found[userIndex] != requestUser) { + return false + } + + // Match any listed user + userMatch := (len(r.Users) == 0) + for _, user := range r.Users { + if user == requestUser { + userMatch = true + } + } + if !userMatch { + // If no user match + return false + } + + return true +} diff --git a/pkg/acl/testdata/acl.yaml b/pkg/acl/testdata/acl.yaml new file mode 100644 index 0000000..23dfa01 --- /dev/null +++ b/pkg/acl/testdata/acl.yaml @@ -0,0 +1,22 @@ +groups: + - &any [.] + - &all + - alice + - bob + - carol +rules: + - url: ^https://example.org/public/ + action: public + - url: ^https://example.com/private/ + users: *any + action: auth + - url: ^https://example.com/blargh + methods: + - BLARGH + action: public + - url: ^https://example.com/only-alice/ + users: + - alice + action: auth + - url: ^https://example.com/(?P)/ + action: auth diff --git a/pkg/acl/url.go b/pkg/acl/url.go new file mode 100644 index 0000000..495aa72 --- /dev/null +++ b/pkg/acl/url.go @@ -0,0 +1,17 @@ +package acl + +import "net/url" + +type URL struct { + *url.URL +} + +func (u *URL) UnmarshalYAML(unmarshal func(any) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + nu, err := url.Parse(s) + u.URL = nu + return err +} diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index 03e19c0..2e1c3d7 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -30,8 +30,7 @@ func TestExpired(t *testing.T) { username := "rodney" token := New(secret, username, time.Now().Add(-10*time.Second)) - if token.Valid(secret) { - t.Error("Expired token still valid") - } + if token.Valid(secret) { + t.Error("Expired token still valid") + } } -