authentricity/internal/webui/jwt.go

141 lines
3.3 KiB
Go

package webui
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwe"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/lestrrat-go/jwx/v2/jwt/openid"
"go.e43.eu/authentricity/internal/models"
"go.uber.org/zap"
)
// tokenCtxKey is the context key under which tokenValidationMiddleware stores
// the decoded user token, and from which getUserToken reads it back. Using an
// unexported struct type means no other package can construct the same key.
type tokenCtxKey struct{}
// getUserToken returns the token that tokenValidationMiddleware attached to
// ctx, or nil when no valid token was attached (unauthenticated request).
func getUserToken(ctx context.Context) openid.Token {
	// When the key is missing or holds a value of another type, the comma-ok
	// assertion leaves tok as the zero value (a nil interface), so it can be
	// returned unconditionally.
	tok, _ := ctx.Value(tokenCtxKey{}).(openid.Token)
	zap.S().Debugf("getUserToken %+v", tok)
	return tok
}
// requireLogin reports whether the request carries an authenticated user
// token. When it does not, the client is sent a 302 Found redirect to /login,
// with the originally requested URL (query-escaped) in the "next" parameter,
// and false is returned; the caller should then stop handling the request.
func requireLogin(w http.ResponseWriter, r *http.Request) bool {
	authenticated := getUserToken(r.Context()) != nil
	if !authenticated {
		loginURL := "/login?next=" + url.QueryEscape(r.URL.String())
		http.Redirect(w, r, loginURL, http.StatusFound)
	}
	return authenticated
}
// buildTokenForUser assembles (but does not serialize or encrypt) an OpenID
// token describing user. The token has:
//   - a fresh random UUID as its JWT ID,
//   - issued-at and not-before set to the current time,
//   - an expiration one calendar day later,
//   - the user's UUID as subject,
//   - preferred username, email and name, each only when non-empty,
//   - an "authentricity.groups" claim listing the user's group IDs as strings.
//
// It returns a wrapped error if the user's groups cannot be loaded from the
// store.
func (s *Service) buildTokenForUser(
	ctx context.Context,
	user *models.UserRecord,
) (jwt.Token, error) {
	// Fetch groups first so no token-building work is wasted on failure.
	// NOTE(review): the second return value of GetUserGroups is ignored, as in
	// the original code; confirm that is intended.
	groups, _, err := s.store.GetUserGroups(ctx, user.UUID)
	if err != nil {
		return nil, fmt.Errorf("loading groups for user %s: %w", user.UUID.String(), err)
	}
	groupStrs := make([]string, len(groups))
	for i, gid := range groups {
		groupStrs[i] = gid.String()
	}

	// Take a single timestamp so iat and nbf are identical and exp is exactly
	// one day after them, rather than drifting across separate time.Now() calls.
	now := time.Now()
	bld := openid.NewBuilder().
		JwtID(uuid.New().String()).
		IssuedAt(now).
		NotBefore(now).
		Expiration(now.AddDate(0, 0, 1)).
		Subject(user.UUID.String())
	// The builder mutates itself in place, so the returned *Builder can be
	// discarded for these optional claims.
	if user.UserName != "" {
		bld.PreferredUsername(user.UserName)
	}
	if user.EmailAddress != "" {
		bld.Email(user.EmailAddress)
	}
	if user.RealName != "" {
		bld.Name(user.RealName)
	}
	bld.Claim("authentricity.groups", groupStrs)
	return bld.Build()
}
// serializeCookieToken encrypts tok as a JWE under s.cookieKey, using the key's
// own algorithm for key management. If the key carries an "enc" parameter, that
// content-encryption algorithm is requested explicitly; otherwise the jwx
// library's default applies. The resulting bytes are what
// tokenValidationMiddleware later decrypts with the same key.
func (s *Service) serializeCookieToken(tok jwt.Token) ([]byte, error) {
	key := s.cookieKey
	encryptOpts := []jwt.EncryptOption{jwt.WithKey(key.Algorithm(), key)}

	rawEnc, hasEnc := key.Get(jwe.ContentEncryptionKey)
	if hasEnc {
		var contentAlg jwa.ContentEncryptionAlgorithm
		if err := contentAlg.Accept(rawEnc); err != nil {
			return nil, fmt.Errorf("Parsing 'enc' field of key: %w", err)
		}
		encryptOpts = append(encryptOpts,
			jwt.WithEncryptOption(jwe.WithContentEncryption(contentAlg)))
	}

	serializer := jwt.NewSerializer().Encrypt(encryptOpts...)
	return serializer.Serialize(tok)
}
// buildTokenCookie wraps serialized token bytes in the session cookie named
// s.tokenCookie. The cookie is scoped to path "/" on s.cookieDomain, is
// HttpOnly, and is Secure according to s.cookieSecure. maxAge follows
// net/http's Cookie.MaxAge semantics:
//   - 0 omits the Max-Age attribute,
//   - a negative value deletes the cookie,
//   - a positive value is a lifetime in seconds.
func (s *Service) buildTokenCookie(data []byte, maxAge int) http.Cookie {
	var ck http.Cookie
	ck.Name = s.tokenCookie
	ck.Value = string(data)
	ck.Path = "/"
	ck.Domain = s.cookieDomain
	ck.MaxAge = maxAge
	ck.Secure = s.cookieSecure
	ck.HttpOnly = true
	return ck
}
// tokenValidationMiddleware wraps next so that, when the request carries a
// valid token cookie (s.tokenCookie), the decoded openid.Token is stored in the
// request context under tokenCtxKey{}, where getUserToken can retrieve it.
//
// The cookie value is a JWE encrypted under s.cookieKey (see
// serializeCookieToken). JWE content encryption is authenticated, so the inner
// JWT is parsed without a separate signature check (jwt.WithVerify(false)). Its
// time-based claims are then checked with jwt.Validate.
//
// The cookie is client-controlled. A value that cannot be decrypted, parsed or
// validated (for example a stale cookie after a key change, or a tampered one)
// is logged and the request continues unauthenticated. Failing the request
// instead would show an error page on every URL, including /login, until the
// cookie expired. The cookie value itself is never logged: it is a bearer
// credential.
func (s *Service) tokenValidationMiddleware(next http.Handler) http.Handler {
	// decode turns an encrypted cookie value back into a validated token.
	decode := func(value string) (jwt.Token, error) {
		body, err := jwe.Decrypt([]byte(value), jwe.WithKey(s.cookieKey.Algorithm(), s.cookieKey))
		if err != nil {
			return nil, fmt.Errorf("decrypting token: %w", err)
		}
		tok, err := jwt.Parse(body, jwt.WithVerify(false), jwt.WithToken(openid.New()))
		if err != nil {
			return nil, fmt.Errorf("parsing token: %w", err)
		}
		if err := jwt.Validate(tok); err != nil {
			return nil, fmt.Errorf("validating token: %w", err)
		}
		return tok, nil
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ck, err := r.Cookie(s.tokenCookie)
		switch {
		case err == http.ErrNoCookie:
			// No session cookie: serve the request anonymously.
		case err != nil:
			zap.L().Error("Error fetching cookie", zap.Error(err))
			s.renderError(w)
			return
		default:
			tok, decodeErr := decode(ck.Value)
			if decodeErr != nil {
				zap.L().Warn("Ignoring invalid token cookie", zap.Error(decodeErr))
				break
			}
			r = r.WithContext(context.WithValue(r.Context(), tokenCtxKey{}, tok))
		}
		next.ServeHTTP(w, r)
	})
}