package honeypot

import (
	"context"
	"io"
	"net"
	"os"
	"time"

	gossh "golang.org/x/crypto/ssh"

	"github.uio.no/torjus/apiary/config"

	"github.com/gliderlabs/ssh"
	"github.com/google/uuid"
	"github.uio.no/torjus/apiary/honeypot/store"
	"github.uio.no/torjus/apiary/models"
	"go.uber.org/zap"
)

// HoneypotServer wraps an SSH server whose password handler records every
// login attempt it sees and never grants access.
type HoneypotServer struct {
	// attemptStore receives each login attempt captured by passwordHandler.
	attemptStore      store.LoginAttemptStore
	// attemptsCallbacks are invoked, in registration order, with each captured
	// login attempt. Register them with AddLoginCallback.
	attemptsCallbacks []func(l models.LoginAttempt)

	// sshServer is the underlying gliderlabs SSH server configured by
	// NewHoneypotServer.
	sshServer *ssh.Server

	// Logger is used for all log output. NewHoneypotServer sets it to a no-op
	// logger; replace it to get actual output.
	Logger *zap.SugaredLogger
}

// NewHoneypotServer builds a HoneypotServer that will listen on
// cfg.ListenAddr and hand every captured login attempt to store.
//
// The returned server logs to a no-op logger until its Logger field is
// replaced. When cfg.HostKeyPath is non-empty, the file at that path is read,
// parsed as a private key and installed as a host key on the SSH server; an
// error reading or parsing the key is returned and no server is created. When
// cfg.HostKeyPath is empty, no host key is added here.
func NewHoneypotServer(cfg config.HoneypotConfig, store store.LoginAttemptStore) (*HoneypotServer, error) {
	hs := &HoneypotServer{
		attemptStore: store,
		Logger:       zap.NewNop().Sugar(),
	}

	hs.sshServer = &ssh.Server{
		Addr:            cfg.ListenAddr,
		PasswordHandler: hs.passwordHandler,
		ConnCallback:    hs.connCallback,
		Handler:         handler,
		// Version string presented to connecting clients.
		Version: "OpenSSH_7.4p1 Debian-10+deb9u6",
	}

	if cfg.HostKeyPath != "" {
		// os.ReadFile opens, reads and closes the file in one call, so the
		// descriptor is released on every path (the earlier Open+ReadAll
		// sequence never closed it).
		pemBytes, err := os.ReadFile(cfg.HostKeyPath)
		if err != nil {
			return nil, err
		}
		signer, err := gossh.ParsePrivateKey(pemBytes)
		if err != nil {
			return nil, err
		}

		hs.sshServer.AddHostKey(signer)
	}

	return hs, nil
}

// ListenAndServe starts the honeypot by delegating to the underlying SSH
// server's ListenAndServe and returns whatever error that call reports.
func (hs *HoneypotServer) ListenAndServe() error {
	srv := hs.sshServer
	return srv.ListenAndServe()
}

// Shutdown stops the honeypot by forwarding ctx to the underlying SSH
// server's Shutdown and returns the error from that call.
func (hs *HoneypotServer) Shutdown(ctx context.Context) error {
	srv := hs.sshServer
	return srv.Shutdown(ctx)
}

// AddLoginCallback registers c to be called with every login attempt the
// password handler captures. Callbacks run in the order they were added.
// No locking is performed here.
func (hs *HoneypotServer) AddLoginCallback(c func(l models.LoginAttempt)) {
	callbacks := append(hs.attemptsCallbacks, c)
	hs.attemptsCallbacks = callbacks
}

// passwordHandler is the SSH password-authentication callback. It records the
// attempted credentials and always returns false, so no login is ever
// accepted.
//
// For each attempt it builds a models.LoginAttempt from the connection
// context, fills in the country via hs.LookupCountry, logs it, saves it to the
// attempt store and then calls every registered login callback. A failure to
// save is logged but does not stop the callbacks from running. If the
// connection UUID is missing from ctx, the attempt is not recorded at all.
func (hs *HoneypotServer) passwordHandler(ctx ssh.Context, password string) bool {
	// connCallback stores the connection ID under the "uuid" key.
	sessUUID, ok := ctx.Value("uuid").(uuid.UUID)
	if !ok {
		hs.Logger.Warn("Unable to get session UUID")
		return false
	}

	la := models.LoginAttempt{
		Date:             time.Now(),
		RemoteIP:         ipFromAddr(ctx.RemoteAddr().String()),
		Username:         ctx.User(),
		Password:         password,
		SSHClientVersion: ctx.ClientVersion(),
		ConnectionUUID:   sessUUID,
	}
	la.Country = hs.LookupCountry(la.RemoteIP)

	hs.Logger.Infow("Login attempt",
		"remote_ip", la.RemoteIP.String(),
		"username", la.Username,
		"password", la.Password)

	if err := hs.attemptStore.AddAttempt(&la); err != nil {
		// Include the underlying error; previously it was dropped, which made
		// store failures impossible to diagnose from the logs.
		hs.Logger.Warnw("Error adding attempt to store", "error", err)
	}

	for _, cFunc := range hs.attemptsCallbacks {
		cFunc(la)
	}

	// Never authenticate: the honeypot only collects credentials.
	return false
}

// connCallback is invoked by the SSH server for each new connection. It wraps
// conn in a throttled connection, stores that connection's ID in ctx under the
// "uuid" key (passwordHandler reads it back from there), and returns the
// wrapped connection for the server to use.
func (hs *HoneypotServer) connCallback(ctx ssh.Context, conn net.Conn) net.Conn {
	throttledConn := newThrottledConn(conn)
	ctx.SetValue("uuid", throttledConn.ID)
	// NOTE(review): the unit of this limit is defined by throttledConn.SetSpeed,
	// which is not visible here — presumably bytes per second; confirm.
	throttledConn.SetSpeed(2048)
	return throttledConn
}

// handler is the session handler installed on the SSH server. It writes a
// fake root shell prompt to the session and then ends it with exit status 1.
func handler(session ssh.Session) {
	const fakePrompt = "[root@hostname ~]#"

	// Write and exit errors are deliberately ignored: nothing useful can be
	// done with them for a session that is being torn down anyway.
	_, _ = io.WriteString(session, fakePrompt)
	_ = session.Exit(1)
}

// ipFromAddr extracts the IP address from a network address string.
//
// addr is normally in "host:port" or "[host]:port" form. If it cannot be split
// into host and port, the whole string is tried as a bare IP. An IPv6 zone
// suffix (e.g. the "%eth0" in "fe80::1%eth0") is dropped before parsing,
// because net.ParseIP does not accept zoned literals. The result is nil when
// no IP address can be parsed.
func ipFromAddr(addr string) net.IP {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		// No port (or otherwise unsplittable): fall back to the raw input.
		host = addr
	}

	// Strip an IPv6 zone identifier, if any.
	for i := 0; i < len(host); i++ {
		if host[i] == '%' {
			host = host[:i]
			break
		}
	}

	return net.ParseIP(host)
}