ezshare/server/interceptors/auth.go

90 lines
2.4 KiB
Go
Raw Normal View History

2021-12-05 13:55:18 +00:00
package interceptors
import (
"context"
2022-01-13 17:40:15 +00:00
"git.t-juice.club/torjus/ezshare/certs"
"git.t-juice.club/torjus/ezshare/pb"
"git.t-juice.club/torjus/ezshare/store"
2021-12-06 06:55:30 +00:00
"go.uber.org/zap"
2021-12-05 13:55:18 +00:00
"google.golang.org/grpc"
2021-12-06 17:04:51 +00:00
"google.golang.org/grpc/codes"
2021-12-05 13:55:18 +00:00
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
2021-12-06 17:04:51 +00:00
"google.golang.org/grpc/status"
2021-12-05 13:55:18 +00:00
)
// ContextKey is the type of the keys this package stores in a request
// context. A dedicated named type keeps them from colliding with plain string
// keys set by other packages.
type ContextKey string

var (
	// ContextKeyRole is the context key under which the interceptor stores the
	// caller's pb.User_Role (see RoleFromContext).
	ContextKeyRole ContextKey = "role"
	// ContextKeyUserID is the context key under which the interceptor stores
	// the authenticated user's ID (see UserIDFromContext).
	ContextKeyUserID ContextKey = "userid"
)
2021-12-05 13:55:18 +00:00
2021-12-06 17:04:51 +00:00
func NewAuthInterceptor(s store.UserStore, certSvc *certs.CertService, logger *zap.SugaredLogger) grpc.UnaryServerInterceptor {
2021-12-06 05:53:49 +00:00
// TODO: Verify that cert is signed by our ca
2021-12-05 13:55:18 +00:00
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
2021-12-06 17:04:51 +00:00
// Login doesn't need valid cert
if info.FullMethod == "/ezshare.UserService/Login" {
return handler(ctx, req)
}
2021-12-05 13:55:18 +00:00
p, ok := peer.FromContext(ctx)
if ok {
tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
if ok {
if len(tlsInfo.State.PeerCertificates) == 1 {
cert := tlsInfo.State.PeerCertificates[0]
2021-12-06 17:04:51 +00:00
// Check if valid
id, err := certSvc.VerifyClient(cert.Raw)
if err != nil {
logger.Infow("Rejected client due to invalid cert", "error", "err", "remote_addr", p.Addr.String(), "method", info.FullMethod)
return nil, status.Error(codes.Unauthenticated, "invalid client certificate")
}
2021-12-05 13:55:18 +00:00
user, err := s.GetUser(id)
if err == nil {
newCtx := context.WithValue(ctx, ContextKeyRole, user.UserRole)
2021-12-06 05:53:49 +00:00
newCtx = context.WithValue(newCtx, ContextKeyUserID, user.Id)
2021-12-06 17:04:51 +00:00
logger.Debugw("Authenticated user.", "username", user.Username, "role", user.UserRole.String(), "method", info.FullMethod)
2021-12-05 13:55:18 +00:00
return handler(newCtx, req)
}
}
}
}
newCtx := context.WithValue(ctx, ContextKeyRole, pb.User_UNKNOWN)
return handler(newCtx, req)
}
}
// RoleFromContext returns the pb.User_Role stored in ctx under ContextKeyRole.
// It returns pb.User_UNKNOWN when no value is stored or the stored value is
// not a pb.User_Role.
func RoleFromContext(ctx context.Context) pb.User_Role {
	switch stored := ctx.Value(ContextKeyRole).(type) {
	case pb.User_Role:
		return stored
	default:
		// Covers both a missing key (nil) and a value of an unexpected type.
		return pb.User_UNKNOWN
	}
}
2021-12-06 05:53:49 +00:00
// UserIDFromContext returns the user ID stored in ctx under ContextKeyUserID.
// It returns "" when no ID is stored or the stored value is not a string.
func UserIDFromContext(ctx context.Context) string {
	// A type assertion on a nil interface yields ok == false, so a missing
	// key takes the same path as a wrongly-typed value.
	if userID, ok := ctx.Value(ContextKeyUserID).(string); ok {
		return userID
	}
	return ""
}
2021-12-07 05:51:14 +00:00
func RoleAtLeast(ctx context.Context, role pb.User_Role) bool {
ctxRole := RoleFromContext(ctx)
2021-12-08 04:42:25 +00:00
return ctxRole >= role
2021-12-07 05:51:14 +00:00
}