Commit b9fd2573 authored by ale's avatar ale

Merge branch 'multi-2fa' into 'master'

Return all supported 2FA mechanisms in the authentication response

See merge request !6
parents e29f41a8 dab5b7f1
Pipeline #2832 passed with stages
in 1 minute and 43 seconds
......@@ -58,7 +58,6 @@ type Request struct {
U2FAppID string
U2FResponse *u2f.SignResponse
DeviceInfo *DeviceInfo
//Extra map[string]string
}
func (r *Request) EncodeToMap(m map[string]string, prefix string) {
......@@ -98,6 +97,25 @@ type UserInfo struct {
Groups []string
}
func encodeStringList(m map[string]string, prefix string, l []string) {
for i, elem := range l {
m[fmt.Sprintf("%s.%d.", prefix, i)] = elem
}
}
func decodeStringList(m map[string]string, prefix string) (out []string) {
i := 0
for {
s, ok := m[fmt.Sprintf("%s.%d.", prefix, i)]
if !ok {
break
}
out = append(out, s)
i++
}
return
}
func (u *UserInfo) EncodeToMap(m map[string]string, prefix string) {
if u.Email != "" {
m[prefix+"email"] = u.Email
......@@ -105,24 +123,14 @@ func (u *UserInfo) EncodeToMap(m map[string]string, prefix string) {
if u.Shard != "" {
m[prefix+"shard"] = u.Shard
}
for i, g := range u.Groups {
m[fmt.Sprintf("%sgroup.%d.", prefix, i)] = g
}
encodeStringList(m, prefix+"group", u.Groups)
}
func decodeUserInfoFromMap(m map[string]string, prefix string) *UserInfo {
u := UserInfo{
Email: m[prefix+"email"],
Shard: m[prefix+"shard"],
}
i := 0
for {
s, ok := m[fmt.Sprintf("%sgroup.%d.", prefix, i)]
if !ok {
break
}
u.Groups = append(u.Groups, s)
i++
Email: m[prefix+"email"],
Shard: m[prefix+"shard"],
Groups: decodeStringList(m, prefix+"group"),
}
if u.Email == "" && u.Shard == "" && len(u.Groups) == 0 {
return nil
......@@ -133,16 +141,49 @@ func decodeUserInfoFromMap(m map[string]string, prefix string) *UserInfo {
// Response to an authentication request.
type Response struct {
Status Status
TFAMethod TFAMethod
TFAMethods []TFAMethod
U2FSignRequest *u2f.WebSignRequest
UserInfo *UserInfo
}
// Has2FAMethod checks for the presence of a two-factor authentication
// method in the Response.
func (r *Response) Has2FAMethod(needle TFAMethod) bool {
for _, m := range r.TFAMethods {
if m == needle {
return true
}
}
return false
}
func encodeTFAMethodList(m map[string]string, prefix string, l []TFAMethod) {
if len(l) == 0 {
return
}
tmp := make([]string, 0, len(l))
for _, el := range l {
tmp = append(tmp, string(el))
}
encodeStringList(m, prefix, tmp)
}
func decodeTFAMethodList(m map[string]string, prefix string) []TFAMethod {
l := decodeStringList(m, prefix)
if len(l) == 0 {
return nil
}
out := make([]TFAMethod, 0, len(l))
for _, el := range l {
out = append(out, TFAMethod(el))
}
return out
}
func (r *Response) EncodeToMap(m map[string]string, prefix string) {
m[prefix+"status"] = r.Status.String()
m[prefix+"2fa_method"] = string(r.TFAMethod)
encodeTFAMethodList(m, prefix+"2fa_methods", r.TFAMethods)
if r.U2FSignRequest != nil {
// External type.
encodeU2FSignRequestToMap(r.U2FSignRequest, m, prefix+"u2f_req.")
}
if r.UserInfo != nil {
......@@ -152,7 +193,7 @@ func (r *Response) EncodeToMap(m map[string]string, prefix string) {
func (r *Response) DecodeFromMap(m map[string]string, prefix string) {
r.Status = parseAuthStatus(m[prefix+"status"])
r.TFAMethod = TFAMethod(m[prefix+"2fa_method"])
r.TFAMethods = decodeTFAMethodList(m, prefix+"2fa_methods")
r.U2FSignRequest = decodeU2FSignRequestFromMap(m, prefix+"u2f_req.")
r.UserInfo = decodeUserInfoFromMap(m, prefix+"user.")
}
......
package auth
import (
"encoding/json"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/tstranex/u2f"
)
......@@ -40,10 +39,8 @@ func TestProtocol_SerializeRequest(t *testing.T) {
if err != nil {
t.Fatal("Decode():", err)
}
if !reflect.DeepEqual(req, &req2) {
d1, _ := json.MarshalIndent(req, "", " ")
d2, _ := json.MarshalIndent(req2, "", " ")
t.Errorf("decode results differ: %+v vs %+v", string(d1), string(d2))
if diffs := cmp.Diff(req, &req2); diffs != "" {
t.Errorf("decode results differ:\n%s", diffs)
}
}
......@@ -51,8 +48,8 @@ func TestProtocol_SerializeResponse(t *testing.T) {
c := &kvCodec{}
resp := &Response{
Status: StatusInsufficientCredentials,
TFAMethod: TFAMethodU2F,
Status: StatusInsufficientCredentials,
TFAMethods: []TFAMethod{TFAMethodU2F},
U2FSignRequest: &u2f.WebSignRequest{
AppID: "https://some-app-id",
Challenge: "u2fChallenge",
......@@ -83,9 +80,7 @@ func TestProtocol_SerializeResponse(t *testing.T) {
if err != nil {
t.Fatal("Decode():", err)
}
if !reflect.DeepEqual(resp, &resp2) {
d1, _ := json.MarshalIndent(resp, "", " ")
d2, _ := json.MarshalIndent(resp2, "", " ")
t.Errorf("decode results differ: %+v vs %+v", string(d1), string(d2))
if diffs := cmp.Diff(resp, &resp2); diffs != "" {
t.Errorf("decode results differ: %s", diffs)
}
}
......@@ -435,15 +435,18 @@ func (s *Server) authenticateUserWith2FA(user *backend.User, req *auth.Request)
resp := &auth.Response{
Status: auth.StatusInsufficientCredentials,
}
// Two-factor mechanisms are returned in order of
// decreasing preference, so start with U2F.
if req.U2FAppID != "" && user.HasU2F() {
resp.TFAMethod = auth.TFAMethodU2F
resp.TFAMethods = append(resp.TFAMethods, auth.TFAMethodU2F)
signReq, err := s.u2fSignRequest(user, req.U2FAppID)
if err != nil {
return nil, err
}
resp.U2FSignRequest = signReq
} else if user.HasOTP() {
resp.TFAMethod = auth.TFAMethodOTP
}
if user.HasOTP() {
resp.TFAMethods = append(resp.TFAMethods, auth.TFAMethodOTP)
}
return resp, nil
}
......
......@@ -171,8 +171,8 @@ func runAuthenticationTest(t *testing.T, client client.Client) {
if resp.Status != td.expectedStatus {
t.Errorf("authentication error: s=interactive u=%s p=%s, expected=%v got=%v", td.username, td.password, td.expectedStatus, resp.Status)
}
if resp.TFAMethod != td.expectedTFAMethod {
t.Errorf("mismatch in TFAMethod hint in authentication response: s=interactive u=%s p=%s, expected=%v got=%v", td.username, td.password, td.expectedTFAMethod, resp.TFAMethod)
if td.expectedTFAMethod != auth.TFAMethodNone && !resp.Has2FAMethod(td.expectedTFAMethod) {
t.Errorf("mismatch in TFAMethod hint in authentication response: s=interactive u=%s p=%s, expected=%v got=%v", td.username, td.password, td.expectedTFAMethod, resp.TFAMethods)
}
}
}
......
Copyright (c) 2017 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
This diff is collapsed.
// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
// +build !debug
package diff
var debug debugger
type debugger struct{}
func (debugger) Begin(_, _ int, f EqualFunc, _, _ *EditScript) EqualFunc {
return f
}
func (debugger) Update() {}
func (debugger) Finish() {}
// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
// +build debug
package diff
import (
"fmt"
"strings"
"sync"
"time"
)
// The algorithm can be seen running in real-time by enabling debugging:
// go test -tags=debug -v
//
// Example output:
// === RUN TestDifference/#34
// ┌───────────────────────────────┐
// │ \ · · · · · · · · · · · · · · │
// │ · # · · · · · · · · · · · · · │
// │ · \ · · · · · · · · · · · · · │
// │ · · \ · · · · · · · · · · · · │
// │ · · · X # · · · · · · · · · · │
// │ · · · # \ · · · · · · · · · · │
// │ · · · · · # # · · · · · · · · │
// │ · · · · · # \ · · · · · · · · │
// │ · · · · · · · \ · · · · · · · │
// │ · · · · · · · · \ · · · · · · │
// │ · · · · · · · · · \ · · · · · │
// │ · · · · · · · · · · \ · · # · │
// │ · · · · · · · · · · · \ # # · │
// │ · · · · · · · · · · · # # # · │
// │ · · · · · · · · · · # # # # · │
// │ · · · · · · · · · # # # # # · │
// │ · · · · · · · · · · · · · · \ │
// └───────────────────────────────┘
// [.Y..M.XY......YXYXY.|]
//
// The grid represents the edit-graph where the horizontal axis represents
// list X and the vertical axis represents list Y. The start of the two lists
// is the top-left, while the ends are the bottom-right. The '·' represents
// an unexplored node in the graph. The '\' indicates that the two symbols
// from list X and Y are equal. The 'X' indicates that two symbols are similar
// (but not exactly equal) to each other. The '#' indicates that the two symbols
// are different (and not similar). The algorithm traverses this graph trying to
// make the paths starting in the top-left and the bottom-right connect.
//
// The series of '.', 'X', 'Y', and 'M' characters at the bottom represents
// the currently established path from the forward and reverse searches,
// separated by a '|' character.
const (
updateDelay = 100 * time.Millisecond
finishDelay = 500 * time.Millisecond
ansiTerminal = true // ANSI escape codes used to move terminal cursor
)
var debug debugger
type debugger struct {
sync.Mutex
p1, p2 EditScript
fwdPath, revPath *EditScript
grid []byte
lines int
}
func (dbg *debugger) Begin(nx, ny int, f EqualFunc, p1, p2 *EditScript) EqualFunc {
dbg.Lock()
dbg.fwdPath, dbg.revPath = p1, p2
top := "┌─" + strings.Repeat("──", nx) + "┐\n"
row := "│ " + strings.Repeat("· ", nx) + "│\n"
btm := "└─" + strings.Repeat("──", nx) + "┘\n"
dbg.grid = []byte(top + strings.Repeat(row, ny) + btm)
dbg.lines = strings.Count(dbg.String(), "\n")
fmt.Print(dbg)
// Wrap the EqualFunc so that we can intercept each result.
return func(ix, iy int) (r Result) {
cell := dbg.grid[len(top)+iy*len(row):][len("│ ")+len("· ")*ix:][:len("·")]
for i := range cell {
cell[i] = 0 // Zero out the multiple bytes of UTF-8 middle-dot
}
switch r = f(ix, iy); {
case r.Equal():
cell[0] = '\\'
case r.Similar():
cell[0] = 'X'
default:
cell[0] = '#'
}
return
}
}
func (dbg *debugger) Update() {
dbg.print(updateDelay)
}
func (dbg *debugger) Finish() {
dbg.print(finishDelay)
dbg.Unlock()
}
func (dbg *debugger) String() string {
dbg.p1, dbg.p2 = *dbg.fwdPath, dbg.p2[:0]
for i := len(*dbg.revPath) - 1; i >= 0; i-- {
dbg.p2 = append(dbg.p2, (*dbg.revPath)[i])
}
return fmt.Sprintf("%s[%v|%v]\n\n", dbg.grid, dbg.p1, dbg.p2)
}
func (dbg *debugger) print(d time.Duration) {
if ansiTerminal {
fmt.Printf("\x1b[%dA", dbg.lines) // Reset terminal cursor
}
fmt.Print(dbg)
time.Sleep(d)
}
This diff is collapsed.
// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
// Package function identifies function types.
package function
import "reflect"
type funcType int
const (
_ funcType = iota
ttbFunc // func(T, T) bool
tibFunc // func(T, I) bool
trFunc // func(T) R
Equal = ttbFunc // func(T, T) bool
EqualAssignable = tibFunc // func(T, I) bool; encapsulates func(T, T) bool
Transformer = trFunc // func(T) R
ValueFilter = ttbFunc // func(T, T) bool
Less = ttbFunc // func(T, T) bool
)
var boolType = reflect.TypeOf(true)
// IsType reports whether the reflect.Type is of the specified function type.
func IsType(t reflect.Type, ft funcType) bool {
if t == nil || t.Kind() != reflect.Func || t.IsVariadic() {
return false
}
ni, no := t.NumIn(), t.NumOut()
switch ft {
case ttbFunc: // func(T, T) bool
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == boolType {
return true
}
case tibFunc: // func(T, I) bool
if ni == 2 && no == 1 && t.In(0).AssignableTo(t.In(1)) && t.Out(0) == boolType {
return true
}
case trFunc: // func(T) R
if ni == 1 && no == 1 {
return true
}
}
return false
}
// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
// Package value provides functionality for reflect.Value types.
package value
import (
"fmt"
"reflect"
"strconv"
"strings"
"unicode"
)
var stringerIface = reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
// Format formats the value v as a string.
//
// This is similar to fmt.Sprintf("%+v", v) except this:
// * Prints the type unless it can be elided
// * Avoids printing struct fields that are zero
// * Prints a nil-slice as being nil, not empty
// * Prints map entries in deterministic order
func Format(v reflect.Value, conf FormatConfig) string {
conf.printType = true
conf.followPointers = true
conf.realPointers = true
return formatAny(v, conf, visited{})
}
type FormatConfig struct {
UseStringer bool // Should the String method be used if available?
printType bool // Should we print the type before the value?
PrintPrimitiveType bool // Should we print the type of primitives?
followPointers bool // Should we recursively follow pointers?
realPointers bool // Should we print the real address of pointers?
}
func formatAny(v reflect.Value, conf FormatConfig, m visited) string {
// TODO: Should this be a multi-line printout in certain situations?
if !v.IsValid() {
return "<non-existent>"
}
if conf.UseStringer && v.Type().Implements(stringerIface) && v.CanInterface() {
if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface) && v.IsNil() {
return "<nil>"
}
const stringerPrefix = "s" // Indicates that the String method was used
s := v.Interface().(fmt.Stringer).String()
return stringerPrefix + formatString(s)
}
switch v.Kind() {
case reflect.Bool:
return formatPrimitive(v.Type(), v.Bool(), conf)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return formatPrimitive(v.Type(), v.Int(), conf)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if v.Type().PkgPath() == "" || v.Kind() == reflect.Uintptr {
// Unnamed uints are usually bytes or words, so use hexadecimal.
return formatPrimitive(v.Type(), formatHex(v.Uint()), conf)
}
return formatPrimitive(v.Type(), v.Uint(), conf)
case reflect.Float32, reflect.Float64:
return formatPrimitive(v.Type(), v.Float(), conf)
case reflect.Complex64, reflect.Complex128:
return formatPrimitive(v.Type(), v.Complex(), conf)
case reflect.String:
return formatPrimitive(v.Type(), formatString(v.String()), conf)
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
return formatPointer(v, conf)
case reflect.Ptr:
if v.IsNil() {
if conf.printType {
return fmt.Sprintf("(%v)(nil)", v.Type())
}
return "<nil>"
}
if m.Visit(v) || !conf.followPointers {
return formatPointer(v, conf)
}
return "&" + formatAny(v.Elem(), conf, m)
case reflect.Interface:
if v.IsNil() {
if conf.printType {
return fmt.Sprintf("%v(nil)", v.Type())
}
return "<nil>"
}
return formatAny(v.Elem(), conf, m)
case reflect.Slice:
if v.IsNil() {
if conf.printType {
return fmt.Sprintf("%v(nil)", v.Type())
}
return "<nil>"
}
fallthrough
case reflect.Array:
var ss []string
subConf := conf
subConf.printType = v.Type().Elem().Kind() == reflect.Interface
for i := 0; i < v.Len(); i++ {
vi := v.Index(i)
if vi.CanAddr() { // Check for recursive elements
p := vi.Addr()
if m.Visit(p) {
subConf := conf
subConf.printType = true
ss = append(ss, "*"+formatPointer(p, subConf))
continue
}
}
ss = append(ss, formatAny(vi, subConf, m))
}
s := fmt.Sprintf("{%s}", strings.Join(ss, ", "))
if conf.printType {
return v.Type().String() + s
}
return s
case reflect.Map:
if v.IsNil() {
if conf.printType {
return fmt.Sprintf("%v(nil)", v.Type())
}
return "<nil>"
}
if m.Visit(v) {
return formatPointer(v, conf)
}
var ss []string
keyConf, valConf := conf, conf
keyConf.printType = v.Type().Key().Kind() == reflect.Interface
keyConf.followPointers = false
valConf.printType = v.Type().Elem().Kind() == reflect.Interface
for _, k := range SortKeys(v.MapKeys()) {
sk := formatAny(k, keyConf, m)
sv := formatAny(v.MapIndex(k), valConf, m)
ss = append(ss, fmt.Sprintf("%s: %s", sk, sv))
}
s := fmt.Sprintf("{%s}", strings.Join(ss, ", "))
if conf.printType {
return v.Type().String() + s
}
return s
case reflect.Struct:
var ss []string
subConf := conf
subConf.printType = true
for i := 0; i < v.NumField(); i++ {
vv := v.Field(i)
if isZero(vv) {
continue // Elide zero value fields
}
name := v.Type().Field(i).Name
subConf.UseStringer = conf.UseStringer
s := formatAny(vv, subConf, m)
ss = append(ss, fmt.Sprintf("%s: %s", name, s))
}
s := fmt.Sprintf("{%s}", strings.Join(ss, ", "))
if conf.printType {
return v.Type().String() + s
}
return s
default:
panic(fmt.Sprintf("%v kind not handled", v.Kind()))
}
}
func formatString(s string) string {
// Use quoted string if it the same length as a raw string literal.
// Otherwise, attempt to use the raw string form.
qs := strconv.Quote(s)
if len(qs) == 1+len(s)+1 {
return qs
}
// Disallow newlines to ensure output is a single line.
// Only allow printable runes for readability purposes.
rawInvalid := func(r rune) bool {
return r == '`' || r == '\n' || !unicode.IsPrint(r)
}
if strings.IndexFunc(s, rawInvalid) < 0 {
return "`" + s + "`"
}
return qs
}
func formatPrimitive(t reflect.Type, v interface{}, conf FormatConfig) string {
if conf.printType && (conf.PrintPrimitiveType || t.PkgPath() != "") {
return fmt.Sprintf("%v(%v)", t, v)
}
return fmt.Sprintf("%v", v)
}
func formatPointer(v reflect.Value, conf FormatConfig) string {
p := v.Pointer()
if !conf.realPointers {
p = 0 // For deterministic printing purposes
}
s := formatHex(uint64(p))
if conf.printType {
return fmt.Sprintf("(%v)(%s)", v.Type(), s)
}
return s
}
func formatHex(u uint64) string {
var f string
switch {
case u <= 0xff:
f = "0x%02x"
case u <= 0xffff:
f = "0x%04x"
case u <= 0xffffff:
f = "0x%06x"
case u <= 0xffffffff:
f = "0x%08x"
case u <= 0xffffffffff:
f = "0x%010x"
case u <= 0xffffffffffff:
f = "0x%012x"
case u <= 0xffffffffffffff:
f = "0x%014x"
case u <= 0xffffffffffffffff:
f = "0x%016x"
}
return fmt.Sprintf(f, u)
}
// isZero reports whether v is the zero value.
// This does not rely on Interface and so can be used on unexported fields.
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return v.Bool() == false
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Complex64, reflect.Complex128:
return v.Complex() == 0
case reflect.String:
return v.String() == ""
case reflect.UnsafePointer:
return v.Pointer() == 0
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
return v.IsNil()
case reflect.Array:
for i := 0; i < v.Len(); i++ {
if !isZero(v.Index(i)) {
return false
}
}
return true
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
if !isZero(v.Field(i)) {
return false
}