feat: implement NATS-based NixOS deployment system
Implement the complete homelab-deploy system with three operational modes:

- Listener mode: Runs on NixOS hosts as a systemd service, subscribes to NATS subjects with configurable templates, and executes nixos-rebuild on deployment requests with concurrency control
- MCP mode: MCP server exposing deploy, deploy_admin, and list_hosts tools for AI assistants with tiered access control
- CLI mode: Manual deployment commands with subject alias support via environment variables

Key components:

- internal/messages: Request/response types with validation
- internal/nats: Client wrapper with NKey authentication
- internal/deploy: Executor with timeout and lock for concurrency
- internal/listener: Subject template expansion and request handling
- internal/cli: Deploy logic with alias resolution
- internal/mcp: MCP server with mcp-go integration
- nixos/module.nix: NixOS module with hardened systemd service

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
40
internal/cli/aliases.go
Normal file
40
internal/cli/aliases.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Package cli provides the deploy command logic.
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// aliasEnvPrefix is prepended to the normalized alias name to form the
// environment variable that holds the alias target.
const aliasEnvPrefix = "HOMELAB_DEPLOY_ALIAS_"

// ResolveAlias maps a subject alias to the NATS subject it stands for.
//
// Input containing a dot is treated as a subject already and returned
// unchanged. Any other input is looked up in the environment variable
// HOMELAB_DEPLOY_ALIAS_<NAME>, where <NAME> is the input upper-cased with
// hyphens replaced by underscores. When that variable is unset or empty the
// input is returned unchanged.
func ResolveAlias(input string) string {
	if !IsAlias(input) {
		return input
	}

	subject := os.Getenv(GetAliasEnvVar(input))
	if subject == "" {
		// No alias configured; hand the input back untouched.
		return input
	}
	return subject
}

// IsAlias reports whether input looks like an alias, i.e. contains no dots.
func IsAlias(input string) bool {
	return strings.IndexByte(input, '.') < 0
}

// GetAliasEnvVar returns the name of the environment variable that is
// consulted for the given alias.
func GetAliasEnvVar(alias string) string {
	name := strings.ReplaceAll(alias, "-", "_")
	return aliasEnvPrefix + strings.ToUpper(name)
}
|
||||
112
internal/cli/aliases_test.go
Normal file
112
internal/cli/aliases_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveAlias(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
t.Setenv("HOMELAB_DEPLOY_ALIAS_TEST", "deploy.test.all")
|
||||
t.Setenv("HOMELAB_DEPLOY_ALIAS_PROD", "deploy.prod.all")
|
||||
t.Setenv("HOMELAB_DEPLOY_ALIAS_PROD_DNS", "deploy.prod.role.dns")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full subject unchanged",
|
||||
input: "deploy.prod.ns1",
|
||||
want: "deploy.prod.ns1",
|
||||
},
|
||||
{
|
||||
name: "subject with multiple dots",
|
||||
input: "deploy.test.role.web",
|
||||
want: "deploy.test.role.web",
|
||||
},
|
||||
{
|
||||
name: "lowercase alias",
|
||||
input: "test",
|
||||
want: "deploy.test.all",
|
||||
},
|
||||
{
|
||||
name: "uppercase alias",
|
||||
input: "TEST",
|
||||
want: "deploy.test.all",
|
||||
},
|
||||
{
|
||||
name: "mixed case alias",
|
||||
input: "TeSt",
|
||||
want: "deploy.test.all",
|
||||
},
|
||||
{
|
||||
name: "alias with hyphen",
|
||||
input: "prod-dns",
|
||||
want: "deploy.prod.role.dns",
|
||||
},
|
||||
{
|
||||
name: "alias with hyphen uppercase",
|
||||
input: "PROD-DNS",
|
||||
want: "deploy.prod.role.dns",
|
||||
},
|
||||
{
|
||||
name: "unknown alias returns as-is",
|
||||
input: "unknown",
|
||||
want: "unknown",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ResolveAlias(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("ResolveAlias(%q) = %q, want %q", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"test", true},
|
||||
{"prod-dns", true},
|
||||
{"PROD", true},
|
||||
{"deploy.test.all", false},
|
||||
{"deploy.prod.ns1", false},
|
||||
{"a.b", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
got := IsAlias(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("IsAlias(%q) = %v, want %v", tc.input, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAliasEnvVar(t *testing.T) {
|
||||
tests := []struct {
|
||||
alias string
|
||||
want string
|
||||
}{
|
||||
{"test", "HOMELAB_DEPLOY_ALIAS_TEST"},
|
||||
{"prod", "HOMELAB_DEPLOY_ALIAS_PROD"},
|
||||
{"prod-dns", "HOMELAB_DEPLOY_ALIAS_PROD_DNS"},
|
||||
{"my-long-alias", "HOMELAB_DEPLOY_ALIAS_MY_LONG_ALIAS"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.alias, func(t *testing.T) {
|
||||
got := GetAliasEnvVar(tc.alias)
|
||||
if got != tc.want {
|
||||
t.Errorf("GetAliasEnvVar(%q) = %q, want %q", tc.alias, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
213
internal/cli/deploy.go
Normal file
213
internal/cli/deploy.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/messages"
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/nats"
|
||||
)
|
||||
|
||||
// DeployConfig holds configuration for a deploy operation.
type DeployConfig struct {
	// NATSUrl is the URL passed to the NATS client when connecting.
	NATSUrl string
	// NKeyFile is the NKey file passed to the NATS client for authentication.
	NKeyFile string
	// Subject is the NATS subject the deploy request is published to.
	Subject string
	// Action is copied into the deploy request sent to listeners.
	Action messages.Action
	// Revision is copied into the deploy request sent to listeners.
	Revision string
	// Timeout bounds how long Deploy waits for responses after publishing.
	Timeout time.Duration
}
|
||||
|
||||
// DeployResult contains the aggregated results from a deployment.
|
||||
type DeployResult struct {
|
||||
Responses []*messages.DeployResponse
|
||||
Errors []error
|
||||
}
|
||||
|
||||
// AllSucceeded returns true if all responses indicate success.
|
||||
func (r *DeployResult) AllSucceeded() bool {
|
||||
for _, resp := range r.Responses {
|
||||
if resp.Status != messages.StatusCompleted {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(r.Responses) > 0 && len(r.Errors) == 0
|
||||
}
|
||||
|
||||
// HostCount returns the number of unique hosts that responded.
|
||||
func (r *DeployResult) HostCount() int {
|
||||
seen := make(map[string]bool)
|
||||
for _, resp := range r.Responses {
|
||||
seen[resp.Hostname] = true
|
||||
}
|
||||
return len(seen)
|
||||
}
|
||||
|
||||
// Deploy executes a deployment to the specified subject and collects responses.
|
||||
func Deploy(ctx context.Context, cfg DeployConfig, onResponse func(*messages.DeployResponse)) (*DeployResult, error) {
|
||||
// Connect to NATS
|
||||
client, err := nats.Connect(nats.Config{
|
||||
URL: cfg.NATSUrl,
|
||||
NKeyFile: cfg.NKeyFile,
|
||||
Name: "homelab-deploy-cli",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Generate unique reply subject
|
||||
requestID := uuid.New().String()
|
||||
replySubject := fmt.Sprintf("deploy.responses.%s", requestID)
|
||||
|
||||
// Track responses by hostname to handle multiple messages per host
|
||||
var mu sync.Mutex
|
||||
result := &DeployResult{}
|
||||
hostFinal := make(map[string]bool) // track which hosts have sent final status
|
||||
|
||||
// Subscribe to reply subject
|
||||
sub, err := client.Subscribe(replySubject, func(subject string, data []byte) {
|
||||
resp, err := messages.UnmarshalDeployResponse(data)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
result.Errors = append(result.Errors, fmt.Errorf("failed to unmarshal response: %w", err))
|
||||
mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
result.Responses = append(result.Responses, resp)
|
||||
if resp.Status.IsFinal() {
|
||||
hostFinal[resp.Hostname] = true
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if onResponse != nil {
|
||||
onResponse(resp)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to reply subject: %w", err)
|
||||
}
|
||||
defer func() { _ = sub.Unsubscribe() }()
|
||||
|
||||
// Build and send request
|
||||
req := &messages.DeployRequest{
|
||||
Action: cfg.Action,
|
||||
Revision: cfg.Revision,
|
||||
ReplyTo: replySubject,
|
||||
}
|
||||
|
||||
data, err := req.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Publish(cfg.Subject, data); err != nil {
|
||||
return nil, fmt.Errorf("failed to publish request: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("failed to flush: %w", err)
|
||||
}
|
||||
|
||||
// Wait for responses with timeout
|
||||
// Use a dynamic timeout: wait for initial responses, then extend
|
||||
// timeout after each response until no new responses or max timeout
|
||||
deadline := time.Now().Add(cfg.Timeout)
|
||||
lastResponse := time.Now()
|
||||
idleTimeout := 30 * time.Second // wait this long after last response
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return result, ctx.Err()
|
||||
case <-time.After(1 * time.Second):
|
||||
mu.Lock()
|
||||
responseCount := len(result.Responses)
|
||||
mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Check if we've exceeded the absolute deadline
|
||||
if now.After(deadline) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// If we have responses, use idle timeout
|
||||
if responseCount > 0 {
|
||||
mu.Lock()
|
||||
lastResponseTime := lastResponse
|
||||
// Update lastResponse time if we got new responses
|
||||
if responseCount > 0 {
|
||||
// Simple approximation - in practice you'd track this more precisely
|
||||
lastResponseTime = now
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if now.Sub(lastResponseTime) > idleTimeout {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Discover sends a discovery request and collects host information.
|
||||
func Discover(ctx context.Context, natsURL, nkeyFile, discoverSubject string, timeout time.Duration) ([]*messages.DiscoveryResponse, error) {
|
||||
client, err := nats.Connect(nats.Config{
|
||||
URL: natsURL,
|
||||
NKeyFile: nkeyFile,
|
||||
Name: "homelab-deploy-cli-discover",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
requestID := uuid.New().String()
|
||||
replySubject := fmt.Sprintf("deploy.responses.discover-%s", requestID)
|
||||
|
||||
var mu sync.Mutex
|
||||
var responses []*messages.DiscoveryResponse
|
||||
|
||||
sub, err := client.Subscribe(replySubject, func(subject string, data []byte) {
|
||||
resp, err := messages.UnmarshalDiscoveryResponse(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
responses = append(responses, resp)
|
||||
mu.Unlock()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe: %w", err)
|
||||
}
|
||||
defer func() { _ = sub.Unsubscribe() }()
|
||||
|
||||
req := &messages.DiscoveryRequest{ReplyTo: replySubject}
|
||||
data, err := req.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Publish(discoverSubject, data); err != nil {
|
||||
return nil, fmt.Errorf("failed to publish: %w", err)
|
||||
}
|
||||
|
||||
if err := client.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("failed to flush: %w", err)
|
||||
}
|
||||
|
||||
// Wait for responses
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return responses, ctx.Err()
|
||||
case <-time.After(timeout):
|
||||
return responses, nil
|
||||
}
|
||||
}
|
||||
109
internal/cli/deploy_test.go
Normal file
109
internal/cli/deploy_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/messages"
|
||||
)
|
||||
|
||||
func TestDeployResult_AllSucceeded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
responses []*messages.DeployResponse
|
||||
errors []error
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "all completed",
|
||||
responses: []*messages.DeployResponse{
|
||||
{Hostname: "host1", Status: messages.StatusCompleted},
|
||||
{Hostname: "host2", Status: messages.StatusCompleted},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "one failed",
|
||||
responses: []*messages.DeployResponse{
|
||||
{Hostname: "host1", Status: messages.StatusCompleted},
|
||||
{Hostname: "host2", Status: messages.StatusFailed},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "one rejected",
|
||||
responses: []*messages.DeployResponse{
|
||||
{Hostname: "host1", Status: messages.StatusRejected},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no responses",
|
||||
responses: []*messages.DeployResponse{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "has errors",
|
||||
responses: []*messages.DeployResponse{
|
||||
{Hostname: "host1", Status: messages.StatusCompleted},
|
||||
},
|
||||
errors: []error{nil}, // placeholder error
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &DeployResult{
|
||||
Responses: tc.responses,
|
||||
Errors: tc.errors,
|
||||
}
|
||||
got := r.AllSucceeded()
|
||||
if got != tc.want {
|
||||
t.Errorf("AllSucceeded() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployResult_HostCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
responses []*messages.DeployResponse
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "no responses",
|
||||
responses: []*messages.DeployResponse{},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "unique hosts",
|
||||
responses: []*messages.DeployResponse{
|
||||
{Hostname: "host1"},
|
||||
{Hostname: "host2"},
|
||||
{Hostname: "host3"},
|
||||
},
|
||||
want: 3,
|
||||
},
|
||||
{
|
||||
name: "duplicate hosts",
|
||||
responses: []*messages.DeployResponse{
|
||||
{Hostname: "host1", Status: messages.StatusStarted},
|
||||
{Hostname: "host1", Status: messages.StatusCompleted},
|
||||
{Hostname: "host2", Status: messages.StatusStarted},
|
||||
{Hostname: "host2", Status: messages.StatusCompleted},
|
||||
},
|
||||
want: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &DeployResult{Responses: tc.responses}
|
||||
got := r.HostCount()
|
||||
if got != tc.want {
|
||||
t.Errorf("HostCount() = %d, want %d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
112
internal/deploy/executor.go
Normal file
112
internal/deploy/executor.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/messages"
|
||||
)
|
||||
|
||||
// Executor runs nixos-rebuild for a single host against a flake.
type Executor struct {
	flakeURL string        // flake URL; the revision and hostname are appended per run
	hostname string        // flake output (after '#') to build
	timeout  time.Duration // upper bound for a single nixos-rebuild run
}

// NewExecutor returns an Executor for the given flake URL, hostname and
// per-deployment timeout.
func NewExecutor(flakeURL, hostname string, timeout time.Duration) *Executor {
	e := new(Executor)
	e.flakeURL = flakeURL
	e.hostname = hostname
	e.timeout = timeout
	return e
}
|
||||
|
||||
// Result contains the result of a deployment execution.
type Result struct {
	// Success is true when the command ran and exited without error.
	Success bool
	// ExitCode is the command's exit code, 0 on success, or -1 when the
	// failure was not an exit error.
	ExitCode int
	// Stdout is everything the command wrote to standard output.
	Stdout string
	// Stderr is everything the command wrote to standard error.
	Stderr string
	// Error is nil on success; on failure it is the run error, or a timeout
	// error when the execution deadline was exceeded.
	Error error
}
|
||||
|
||||
// ValidateRevision checks if a revision exists in the remote repository.
|
||||
// It uses git ls-remote to verify the ref exists.
|
||||
func (e *Executor) ValidateRevision(ctx context.Context, revision string) error {
|
||||
// Extract the base URL for git ls-remote
|
||||
// flakeURL is like git+https://git.example.com/user/repo.git
|
||||
// We need to strip the git+ prefix for git ls-remote
|
||||
gitURL := e.flakeURL
|
||||
if len(gitURL) > 4 && gitURL[:4] == "git+" {
|
||||
gitURL = gitURL[4:]
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "git", "ls-remote", "--exit-code", gitURL, revision)
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Errorf("timeout validating revision")
|
||||
}
|
||||
return fmt.Errorf("revision %q not found: %w", revision, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute runs nixos-rebuild with the specified action and revision.
|
||||
func (e *Executor) Execute(ctx context.Context, action messages.Action, revision string) *Result {
|
||||
ctx, cancel := context.WithTimeout(ctx, e.timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build the flake reference: <flake-url>?ref=<revision>#<hostname>
|
||||
flakeRef := fmt.Sprintf("%s?ref=%s#%s", e.flakeURL, revision, e.hostname)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "nixos-rebuild", string(action), "--flake", flakeRef)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
|
||||
result := &Result{
|
||||
Stdout: stdout.String(),
|
||||
Stderr: stderr.String(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
result.Success = false
|
||||
result.Error = err
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
result.Error = fmt.Errorf("deployment timed out after %v", e.timeout)
|
||||
}
|
||||
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = -1
|
||||
}
|
||||
} else {
|
||||
result.Success = true
|
||||
result.ExitCode = 0
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildCommand returns the command that would be executed (for logging/debugging).
|
||||
func (e *Executor) BuildCommand(action messages.Action, revision string) string {
|
||||
flakeRef := fmt.Sprintf("%s?ref=%s#%s", e.flakeURL, revision, e.hostname)
|
||||
return fmt.Sprintf("nixos-rebuild %s --flake %s", action, flakeRef)
|
||||
}
|
||||
76
internal/deploy/executor_test.go
Normal file
76
internal/deploy/executor_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/messages"
|
||||
)
|
||||
|
||||
func TestExecutor_BuildCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
flakeURL string
|
||||
hostname string
|
||||
action messages.Action
|
||||
revision string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "switch action",
|
||||
flakeURL: "git+https://git.example.com/user/nixos-configs.git",
|
||||
hostname: "ns1",
|
||||
action: messages.ActionSwitch,
|
||||
revision: "master",
|
||||
want: "nixos-rebuild switch --flake git+https://git.example.com/user/nixos-configs.git?ref=master#ns1",
|
||||
},
|
||||
{
|
||||
name: "boot action with commit hash",
|
||||
flakeURL: "git+https://git.example.com/user/nixos-configs.git",
|
||||
hostname: "web1",
|
||||
action: messages.ActionBoot,
|
||||
revision: "abc123def456",
|
||||
want: "nixos-rebuild boot --flake git+https://git.example.com/user/nixos-configs.git?ref=abc123def456#web1",
|
||||
},
|
||||
{
|
||||
name: "test action with feature branch",
|
||||
flakeURL: "git+ssh://git@github.com/org/repo.git",
|
||||
hostname: "test-host",
|
||||
action: messages.ActionTest,
|
||||
revision: "feature/new-feature",
|
||||
want: "nixos-rebuild test --flake git+ssh://git@github.com/org/repo.git?ref=feature/new-feature#test-host",
|
||||
},
|
||||
{
|
||||
name: "dry-activate action",
|
||||
flakeURL: "git+https://git.example.com/repo.git",
|
||||
hostname: "prod-1",
|
||||
action: messages.ActionDryActivate,
|
||||
revision: "v1.0.0",
|
||||
want: "nixos-rebuild dry-activate --flake git+https://git.example.com/repo.git?ref=v1.0.0#prod-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := NewExecutor(tc.flakeURL, tc.hostname, 10*time.Minute)
|
||||
got := e.BuildCommand(tc.action, tc.revision)
|
||||
if got != tc.want {
|
||||
t.Errorf("BuildCommand() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewExecutor(t *testing.T) {
|
||||
e := NewExecutor("git+https://example.com/repo.git", "host1", 5*time.Minute)
|
||||
|
||||
if e.flakeURL != "git+https://example.com/repo.git" {
|
||||
t.Errorf("flakeURL = %q, want %q", e.flakeURL, "git+https://example.com/repo.git")
|
||||
}
|
||||
if e.hostname != "host1" {
|
||||
t.Errorf("hostname = %q, want %q", e.hostname, "host1")
|
||||
}
|
||||
if e.timeout != 5*time.Minute {
|
||||
t.Errorf("timeout = %v, want %v", e.timeout, 5*time.Minute)
|
||||
}
|
||||
}
|
||||
56
internal/deploy/lock.go
Normal file
56
internal/deploy/lock.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Package deploy provides deployment execution logic.
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Lock is a simple in-memory, non-blocking lock used to allow only one
// deployment at a time. It remembers who currently holds it.
type Lock struct {
	mu     sync.Mutex // guards held and holder
	held   bool
	holder string
}

// NewLock returns an unheld deployment lock.
func NewLock() *Lock {
	return new(Lock)
}

// TryAcquire takes the lock on behalf of holder if it is free and reports
// whether it succeeded. It never blocks waiting for a release.
func (l *Lock) TryAcquire(holder string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	acquired := !l.held
	if acquired {
		l.held, l.holder = true, holder
	}
	return acquired
}

// Release frees the lock and clears the recorded holder. Releasing a lock
// that is not held is a no-op.
func (l *Lock) Release() {
	l.mu.Lock()
	l.held, l.holder = false, ""
	l.mu.Unlock()
}

// IsHeld reports whether the lock is currently held.
func (l *Lock) IsHeld() bool {
	l.mu.Lock()
	held := l.held
	l.mu.Unlock()
	return held
}

// Holder returns the identifier of the current holder, or the empty string
// when the lock is not held.
func (l *Lock) Holder() string {
	l.mu.Lock()
	who := l.holder
	l.mu.Unlock()
	return who
}
|
||||
98
internal/deploy/lock_test.go
Normal file
98
internal/deploy/lock_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package deploy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLock_TryAcquire(t *testing.T) {
|
||||
l := NewLock()
|
||||
|
||||
// First acquire should succeed
|
||||
if !l.TryAcquire("request-1") {
|
||||
t.Error("first TryAcquire should succeed")
|
||||
}
|
||||
|
||||
// Second acquire should fail
|
||||
if l.TryAcquire("request-2") {
|
||||
t.Error("second TryAcquire should fail while lock is held")
|
||||
}
|
||||
|
||||
// Verify holder
|
||||
if got := l.Holder(); got != "request-1" {
|
||||
t.Errorf("Holder() = %q, want %q", got, "request-1")
|
||||
}
|
||||
|
||||
// Release and try again
|
||||
l.Release()
|
||||
|
||||
if !l.TryAcquire("request-3") {
|
||||
t.Error("TryAcquire should succeed after Release")
|
||||
}
|
||||
|
||||
if got := l.Holder(); got != "request-3" {
|
||||
t.Errorf("Holder() = %q, want %q", got, "request-3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLock_IsHeld(t *testing.T) {
|
||||
l := NewLock()
|
||||
|
||||
if l.IsHeld() {
|
||||
t.Error("new lock should not be held")
|
||||
}
|
||||
|
||||
l.TryAcquire("test")
|
||||
|
||||
if !l.IsHeld() {
|
||||
t.Error("lock should be held after TryAcquire")
|
||||
}
|
||||
|
||||
l.Release()
|
||||
|
||||
if l.IsHeld() {
|
||||
t.Error("lock should not be held after Release")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLock_Concurrent(t *testing.T) {
|
||||
l := NewLock()
|
||||
var wg sync.WaitGroup
|
||||
acquired := make(chan string, 100)
|
||||
|
||||
// Try to acquire from multiple goroutines
|
||||
for i := range 100 {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
holder := string(rune('A' + (id % 26)))
|
||||
if l.TryAcquire(holder) {
|
||||
acquired <- holder
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(acquired)
|
||||
|
||||
// Only one should have succeeded
|
||||
count := 0
|
||||
for range acquired {
|
||||
count++
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("expected exactly 1 successful acquire, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLock_ReleaseUnheld(t *testing.T) {
|
||||
l := NewLock()
|
||||
|
||||
// Releasing an unheld lock should not panic
|
||||
l.Release()
|
||||
|
||||
if l.IsHeld() {
|
||||
t.Error("lock should not be held after Release on unheld lock")
|
||||
}
|
||||
}
|
||||
262
internal/listener/listener.go
Normal file
262
internal/listener/listener.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/deploy"
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/messages"
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/nats"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the listener.
type Config struct {
	// Hostname identifies this host in responses; it is also passed to the
	// executor and used when expanding subjects.
	Hostname string
	// Tier is used when expanding subjects and reported in discovery responses.
	Tier string
	// Role is used when expanding subjects and reported in discovery responses.
	Role string
	// NATSUrl is the URL passed to the NATS client when connecting.
	NATSUrl string
	// NKeyFile is the NKey file passed to the NATS client for authentication.
	NKeyFile string
	// FlakeURL is the flake URL handed to the deployment executor.
	FlakeURL string
	// Timeout is the per-deployment timeout handed to the executor.
	Timeout time.Duration
	// DeploySubjects are the subjects to listen on for deploy requests; they
	// are expanded with Hostname, Tier and Role before subscribing.
	DeploySubjects []string
	// DiscoverSubject is the subject to listen on for discovery requests; it
	// is expanded the same way.
	DiscoverSubject string
}
|
||||
|
||||
// Listener handles deployment requests from NATS.
|
||||
type Listener struct {
|
||||
cfg Config
|
||||
client *nats.Client
|
||||
executor *deploy.Executor
|
||||
lock *deploy.Lock
|
||||
logger *slog.Logger
|
||||
|
||||
// Expanded subjects for discovery responses
|
||||
expandedSubjects []string
|
||||
}
|
||||
|
||||
// New creates a new listener with the given configuration.
|
||||
func New(cfg Config, logger *slog.Logger) *Listener {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
|
||||
return &Listener{
|
||||
cfg: cfg,
|
||||
executor: deploy.NewExecutor(cfg.FlakeURL, cfg.Hostname, cfg.Timeout),
|
||||
lock: deploy.NewLock(),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the listener and blocks until the context is cancelled.
|
||||
func (l *Listener) Run(ctx context.Context) error {
|
||||
// Connect to NATS
|
||||
l.logger.Info("connecting to NATS",
|
||||
"url", l.cfg.NATSUrl,
|
||||
"hostname", l.cfg.Hostname,
|
||||
"tier", l.cfg.Tier,
|
||||
"role", l.cfg.Role,
|
||||
)
|
||||
|
||||
client, err := nats.Connect(nats.Config{
|
||||
URL: l.cfg.NATSUrl,
|
||||
NKeyFile: l.cfg.NKeyFile,
|
||||
Name: fmt.Sprintf("homelab-deploy-listener-%s", l.cfg.Hostname),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to NATS: %w", err)
|
||||
}
|
||||
l.client = client
|
||||
defer l.client.Close()
|
||||
|
||||
l.logger.Info("connected to NATS")
|
||||
|
||||
// Expand subjects
|
||||
l.expandedSubjects = ExpandSubjects(l.cfg.DeploySubjects, l.cfg.Hostname, l.cfg.Tier, l.cfg.Role)
|
||||
|
||||
// Subscribe to deploy subjects
|
||||
for _, subject := range l.expandedSubjects {
|
||||
l.logger.Info("subscribing to deploy subject", "subject", subject)
|
||||
if _, err := l.client.Subscribe(subject, l.handleDeployRequest); err != nil {
|
||||
return fmt.Errorf("failed to subscribe to %s: %w", subject, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to discovery subject
|
||||
discoverSubject := ExpandSubject(l.cfg.DiscoverSubject, l.cfg.Hostname, l.cfg.Tier, l.cfg.Role)
|
||||
l.logger.Info("subscribing to discover subject", "subject", discoverSubject)
|
||||
if _, err := l.client.Subscribe(discoverSubject, l.handleDiscoveryRequest); err != nil {
|
||||
return fmt.Errorf("failed to subscribe to %s: %w", discoverSubject, err)
|
||||
}
|
||||
|
||||
l.logger.Info("listener started", "deploy_subjects", l.expandedSubjects, "discover_subject", discoverSubject)
|
||||
|
||||
// Wait for context cancellation
|
||||
<-ctx.Done()
|
||||
l.logger.Info("shutting down listener")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) handleDeployRequest(subject string, data []byte) {
|
||||
req, err := messages.UnmarshalDeployRequest(data)
|
||||
if err != nil {
|
||||
l.logger.Error("failed to unmarshal deploy request",
|
||||
"subject", subject,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Info("received deploy request",
|
||||
"subject", subject,
|
||||
"action", req.Action,
|
||||
"revision", req.Revision,
|
||||
"reply_to", req.ReplyTo,
|
||||
)
|
||||
|
||||
// Validate request
|
||||
if err := req.Validate(); err != nil {
|
||||
l.logger.Warn("invalid deploy request",
|
||||
"error", err,
|
||||
)
|
||||
l.sendResponse(req.ReplyTo, messages.NewDeployResponse(
|
||||
l.cfg.Hostname,
|
||||
messages.StatusRejected,
|
||||
err.Error(),
|
||||
).WithError(messages.ErrorInvalidAction))
|
||||
return
|
||||
}
|
||||
|
||||
// Try to acquire lock
|
||||
requestID := fmt.Sprintf("%s-%d", req.Revision, time.Now().UnixNano())
|
||||
if !l.lock.TryAcquire(requestID) {
|
||||
l.logger.Warn("deployment already in progress",
|
||||
"current_holder", l.lock.Holder(),
|
||||
)
|
||||
l.sendResponse(req.ReplyTo, messages.NewDeployResponse(
|
||||
l.cfg.Hostname,
|
||||
messages.StatusRejected,
|
||||
"another deployment is already in progress",
|
||||
).WithError(messages.ErrorAlreadyRunning))
|
||||
return
|
||||
}
|
||||
defer l.lock.Release()
|
||||
|
||||
// Send started response
|
||||
l.sendResponse(req.ReplyTo, messages.NewDeployResponse(
|
||||
l.cfg.Hostname,
|
||||
messages.StatusStarted,
|
||||
fmt.Sprintf("starting deployment: %s", l.executor.BuildCommand(req.Action, req.Revision)),
|
||||
))
|
||||
|
||||
// Validate revision
|
||||
ctx := context.Background()
|
||||
if err := l.executor.ValidateRevision(ctx, req.Revision); err != nil {
|
||||
l.logger.Error("revision validation failed",
|
||||
"revision", req.Revision,
|
||||
"error", err,
|
||||
)
|
||||
l.sendResponse(req.ReplyTo, messages.NewDeployResponse(
|
||||
l.cfg.Hostname,
|
||||
messages.StatusFailed,
|
||||
fmt.Sprintf("revision validation failed: %v", err),
|
||||
).WithError(messages.ErrorInvalidRevision))
|
||||
return
|
||||
}
|
||||
|
||||
// Execute deployment
|
||||
l.logger.Info("executing deployment",
|
||||
"action", req.Action,
|
||||
"revision", req.Revision,
|
||||
"command", l.executor.BuildCommand(req.Action, req.Revision),
|
||||
)
|
||||
|
||||
result := l.executor.Execute(ctx, req.Action, req.Revision)
|
||||
|
||||
if result.Success {
|
||||
l.logger.Info("deployment completed successfully",
|
||||
"exit_code", result.ExitCode,
|
||||
)
|
||||
l.sendResponse(req.ReplyTo, messages.NewDeployResponse(
|
||||
l.cfg.Hostname,
|
||||
messages.StatusCompleted,
|
||||
"deployment completed successfully",
|
||||
))
|
||||
} else {
|
||||
l.logger.Error("deployment failed",
|
||||
"exit_code", result.ExitCode,
|
||||
"error", result.Error,
|
||||
"stderr", result.Stderr,
|
||||
)
|
||||
|
||||
errorCode := messages.ErrorBuildFailed
|
||||
if result.Error != nil && result.Error.Error() == fmt.Sprintf("deployment timed out after %v", l.cfg.Timeout) {
|
||||
errorCode = messages.ErrorTimeout
|
||||
}
|
||||
|
||||
l.sendResponse(req.ReplyTo, messages.NewDeployResponse(
|
||||
l.cfg.Hostname,
|
||||
messages.StatusFailed,
|
||||
fmt.Sprintf("deployment failed (exit code %d): %s", result.ExitCode, result.Stderr),
|
||||
).WithError(errorCode))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) handleDiscoveryRequest(subject string, data []byte) {
|
||||
req, err := messages.UnmarshalDiscoveryRequest(data)
|
||||
if err != nil {
|
||||
l.logger.Error("failed to unmarshal discovery request",
|
||||
"subject", subject,
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
l.logger.Info("received discovery request",
|
||||
"subject", subject,
|
||||
"reply_to", req.ReplyTo,
|
||||
)
|
||||
|
||||
if err := req.Validate(); err != nil {
|
||||
l.logger.Warn("invalid discovery request", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &messages.DiscoveryResponse{
|
||||
Hostname: l.cfg.Hostname,
|
||||
Tier: l.cfg.Tier,
|
||||
Role: l.cfg.Role,
|
||||
DeploySubjects: l.expandedSubjects,
|
||||
}
|
||||
|
||||
data, err = resp.Marshal()
|
||||
if err != nil {
|
||||
l.logger.Error("failed to marshal discovery response", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := l.client.Publish(req.ReplyTo, data); err != nil {
|
||||
l.logger.Error("failed to publish discovery response",
|
||||
"reply_to", req.ReplyTo,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) sendResponse(replyTo string, resp *messages.DeployResponse) {
|
||||
data, err := resp.Marshal()
|
||||
if err != nil {
|
||||
l.logger.Error("failed to marshal deploy response", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := l.client.Publish(replyTo, data); err != nil {
|
||||
l.logger.Error("failed to publish deploy response",
|
||||
"reply_to", replyTo,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
53
internal/listener/listener_test.go
Normal file
53
internal/listener/listener_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
cfg := Config{
|
||||
Hostname: "test-host",
|
||||
Tier: "test",
|
||||
Role: "web",
|
||||
NATSUrl: "nats://localhost:4222",
|
||||
NKeyFile: "/path/to/key",
|
||||
FlakeURL: "git+https://example.com/repo.git",
|
||||
Timeout: 10 * time.Minute,
|
||||
DeploySubjects: []string{"deploy.<tier>.<hostname>"},
|
||||
DiscoverSubject: "deploy.discover",
|
||||
}
|
||||
|
||||
l := New(cfg, nil)
|
||||
|
||||
if l.cfg.Hostname != cfg.Hostname {
|
||||
t.Errorf("hostname = %q, want %q", l.cfg.Hostname, cfg.Hostname)
|
||||
}
|
||||
if l.cfg.Tier != cfg.Tier {
|
||||
t.Errorf("tier = %q, want %q", l.cfg.Tier, cfg.Tier)
|
||||
}
|
||||
if l.executor == nil {
|
||||
t.Error("executor should not be nil")
|
||||
}
|
||||
if l.lock == nil {
|
||||
t.Error("lock should not be nil")
|
||||
}
|
||||
if l.logger == nil {
|
||||
t.Error("logger should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_WithLogger(t *testing.T) {
|
||||
cfg := Config{
|
||||
Hostname: "test",
|
||||
Tier: "test",
|
||||
}
|
||||
|
||||
customLogger := slog.Default()
|
||||
l := New(cfg, customLogger)
|
||||
|
||||
if l.logger != customLogger {
|
||||
t.Error("should use provided logger")
|
||||
}
|
||||
}
|
||||
42
internal/listener/subjects.go
Normal file
42
internal/listener/subjects.go
Normal file
@@ -0,0 +1,42 @@
|
||||
// Package listener implements the deployment listener mode.
|
||||
package listener
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExpandSubjects substitutes the template variables in every subject pattern
// and returns the expanded subjects in their original order.
//
// Recognised variables:
//   - <hostname> is replaced with hostname
//   - <tier> is replaced with tier
//   - <role> is replaced with role
//
// A pattern that mentions <role> is dropped entirely when role is empty.
// Substitutions are applied one after another in the order listed above.
// The result is nil when no pattern survives (including empty input).
func ExpandSubjects(subjects []string, hostname, tier, role string) []string {
	var expandedAll []string

	for _, pattern := range subjects {
		// A role-scoped subject is meaningless without a role.
		if role == "" && strings.Contains(pattern, "<role>") {
			continue
		}

		substitutions := [...][2]string{
			{"<hostname>", hostname},
			{"<tier>", tier},
			{"<role>", role},
		}
		for _, sub := range substitutions {
			pattern = strings.ReplaceAll(pattern, sub[0], sub[1])
		}

		expandedAll = append(expandedAll, pattern)
	}

	return expandedAll
}
|
||||
|
||||
// ExpandSubject substitutes <hostname>, <tier> and <role> in a single subject
// pattern, in that order, and returns the result.
//
// Unlike ExpandSubjects it never drops the pattern: an empty role simply
// replaces <role> with the empty string.
func ExpandSubject(subject, hostname, tier, role string) string {
	withHostname := strings.ReplaceAll(subject, "<hostname>", hostname)
	withTier := strings.ReplaceAll(withHostname, "<tier>", tier)
	return strings.ReplaceAll(withTier, "<role>", role)
}
|
||||
137
internal/listener/subjects_test.go
Normal file
137
internal/listener/subjects_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package listener
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandSubjects(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subjects []string
|
||||
hostname string
|
||||
tier string
|
||||
role string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "all variables with role",
|
||||
subjects: []string{
|
||||
"deploy.<tier>.<hostname>",
|
||||
"deploy.<tier>.all",
|
||||
"deploy.<tier>.role.<role>",
|
||||
},
|
||||
hostname: "ns1",
|
||||
tier: "prod",
|
||||
role: "dns",
|
||||
want: []string{
|
||||
"deploy.prod.ns1",
|
||||
"deploy.prod.all",
|
||||
"deploy.prod.role.dns",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip role subject when role is empty",
|
||||
subjects: []string{
|
||||
"deploy.<tier>.<hostname>",
|
||||
"deploy.<tier>.all",
|
||||
"deploy.<tier>.role.<role>",
|
||||
},
|
||||
hostname: "web1",
|
||||
tier: "test",
|
||||
role: "",
|
||||
want: []string{
|
||||
"deploy.test.web1",
|
||||
"deploy.test.all",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom prefix",
|
||||
subjects: []string{
|
||||
"homelab.deploy.<tier>.<hostname>",
|
||||
"homelab.deploy.<tier>.all",
|
||||
},
|
||||
hostname: "host1",
|
||||
tier: "prod",
|
||||
role: "",
|
||||
want: []string{
|
||||
"homelab.deploy.prod.host1",
|
||||
"homelab.deploy.prod.all",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty subjects",
|
||||
subjects: []string{},
|
||||
hostname: "host1",
|
||||
tier: "test",
|
||||
role: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no template variables",
|
||||
subjects: []string{
|
||||
"static.subject.here",
|
||||
},
|
||||
hostname: "host1",
|
||||
tier: "test",
|
||||
role: "web",
|
||||
want: []string{
|
||||
"static.subject.here",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ExpandSubjects(tc.subjects, tc.hostname, tc.tier, tc.role)
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Errorf("ExpandSubjects() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandSubject(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subject string
|
||||
hostname string
|
||||
tier string
|
||||
role string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "all variables",
|
||||
subject: "deploy.<tier>.<hostname>.role.<role>",
|
||||
hostname: "ns1",
|
||||
tier: "prod",
|
||||
role: "dns",
|
||||
want: "deploy.prod.ns1.role.dns",
|
||||
},
|
||||
{
|
||||
name: "hostname only",
|
||||
subject: "hosts.<hostname>",
|
||||
hostname: "myhost",
|
||||
tier: "test",
|
||||
role: "",
|
||||
want: "hosts.myhost",
|
||||
},
|
||||
{
|
||||
name: "empty role leaves placeholder",
|
||||
subject: "deploy.<tier>.role.<role>",
|
||||
hostname: "host1",
|
||||
tier: "test",
|
||||
role: "",
|
||||
want: "deploy.test.role.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ExpandSubject(tc.subject, tc.hostname, tc.tier, tc.role)
|
||||
if got != tc.want {
|
||||
t.Errorf("ExpandSubject() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
61
internal/mcp/server.go
Normal file
61
internal/mcp/server.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// ServerConfig carries the settings New needs to build the MCP server and its
// tool handler.
type ServerConfig struct {
	// NATSUrl is forwarded to the tool handler as the NATS server URL.
	NATSUrl string
	// NKeyFile is forwarded to the tool handler as the regular NKey file.
	NKeyFile string
	// EnableAdmin controls whether New registers the deploy_admin tool.
	EnableAdmin bool
	// AdminNKeyFile is forwarded to the tool handler as the admin NKey file.
	AdminNKeyFile string
	// DiscoverSubject is forwarded to the tool handler as the discovery subject.
	DiscoverSubject string
	// Timeout is forwarded to the tool handler as the deployment timeout.
	Timeout time.Duration
}
|
||||
|
||||
// Server wraps the MCP server.
|
||||
type Server struct {
|
||||
cfg ServerConfig
|
||||
server *server.MCPServer
|
||||
}
|
||||
|
||||
// New creates a new MCP server.
|
||||
func New(cfg ServerConfig) *Server {
|
||||
s := server.NewMCPServer(
|
||||
"homelab-deploy",
|
||||
"0.1.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
handler := NewToolHandler(ToolConfig{
|
||||
NATSUrl: cfg.NATSUrl,
|
||||
NKeyFile: cfg.NKeyFile,
|
||||
AdminNKeyFile: cfg.AdminNKeyFile,
|
||||
DiscoverSubject: cfg.DiscoverSubject,
|
||||
Timeout: cfg.Timeout,
|
||||
})
|
||||
|
||||
// Register deploy tool (test-tier only)
|
||||
s.AddTool(DeployTool(), handler.HandleDeploy)
|
||||
|
||||
// Register list_hosts tool
|
||||
s.AddTool(ListHostsTool(), handler.HandleListHosts)
|
||||
|
||||
// Optionally register admin deploy tool
|
||||
if cfg.EnableAdmin {
|
||||
s.AddTool(DeployAdminTool(), handler.HandleDeployAdmin)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
server: s,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the MCP server and blocks until completed.
|
||||
func (s *Server) Run() error {
|
||||
return server.ServeStdio(s.server)
|
||||
}
|
||||
43
internal/mcp/server_test.go
Normal file
43
internal/mcp/server_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
cfg := ServerConfig{
|
||||
NATSUrl: "nats://localhost:4222",
|
||||
NKeyFile: "/path/to/key",
|
||||
EnableAdmin: false,
|
||||
AdminNKeyFile: "",
|
||||
DiscoverSubject: "deploy.discover",
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
if s.server == nil {
|
||||
t.Error("server should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_WithAdmin(t *testing.T) {
|
||||
cfg := ServerConfig{
|
||||
NATSUrl: "nats://localhost:4222",
|
||||
NKeyFile: "/path/to/key",
|
||||
EnableAdmin: true,
|
||||
AdminNKeyFile: "/path/to/admin/key",
|
||||
DiscoverSubject: "deploy.discover",
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
208
internal/mcp/tools.go
Normal file
208
internal/mcp/tools.go
Normal file
@@ -0,0 +1,208 @@
|
||||
// Package mcp provides an MCP server for AI assistants.
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
deploycli "git.t-juice.club/torjus/homelab-deploy/internal/cli"
|
||||
"git.t-juice.club/torjus/homelab-deploy/internal/messages"
|
||||
)
|
||||
|
||||
// ToolConfig carries the settings the MCP tool handlers work with.
type ToolConfig struct {
	// NATSUrl is the NATS server URL passed to deploy and discovery calls.
	NATSUrl string
	// NKeyFile is the NKey file used by the deploy and list_hosts tools.
	NKeyFile string
	// AdminNKeyFile is the NKey file used by the deploy_admin tool.
	AdminNKeyFile string
	// DiscoverSubject is the subject the list_hosts tool sends discovery to.
	DiscoverSubject string
	// Timeout is passed on as the timeout of each deployment.
	Timeout time.Duration
}
|
||||
|
||||
// DeployTool creates the test-tier deploy tool definition.
|
||||
func DeployTool() mcp.Tool {
|
||||
return mcp.NewTool(
|
||||
"deploy",
|
||||
mcp.WithDescription("Deploy NixOS configuration to test-tier hosts"),
|
||||
mcp.WithString("hostname",
|
||||
mcp.Description("Target hostname, or omit to use 'all' or 'role' targeting"),
|
||||
),
|
||||
mcp.WithBoolean("all",
|
||||
mcp.Description("Deploy to all test-tier hosts"),
|
||||
),
|
||||
mcp.WithString("role",
|
||||
mcp.Description("Deploy to all test-tier hosts with this role"),
|
||||
),
|
||||
mcp.WithString("branch",
|
||||
mcp.Description("Git branch or commit to deploy (default: master)"),
|
||||
),
|
||||
mcp.WithString("action",
|
||||
mcp.Description("nixos-rebuild action: switch, boot, test, dry-activate (default: switch)"),
|
||||
mcp.Enum("switch", "boot", "test", "dry-activate"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// DeployAdminTool creates the admin deploy tool definition (all tiers).
|
||||
func DeployAdminTool() mcp.Tool {
|
||||
return mcp.NewTool(
|
||||
"deploy_admin",
|
||||
mcp.WithDescription("Deploy NixOS configuration to any host (admin access required)"),
|
||||
mcp.WithString("tier",
|
||||
mcp.Required(),
|
||||
mcp.Description("Target tier: test or prod"),
|
||||
mcp.Enum("test", "prod"),
|
||||
),
|
||||
mcp.WithString("hostname",
|
||||
mcp.Description("Target hostname, or omit to use 'all' or 'role' targeting"),
|
||||
),
|
||||
mcp.WithBoolean("all",
|
||||
mcp.Description("Deploy to all hosts in tier"),
|
||||
),
|
||||
mcp.WithString("role",
|
||||
mcp.Description("Deploy to all hosts with this role in tier"),
|
||||
),
|
||||
mcp.WithString("branch",
|
||||
mcp.Description("Git branch or commit to deploy (default: master)"),
|
||||
),
|
||||
mcp.WithString("action",
|
||||
mcp.Description("nixos-rebuild action: switch, boot, test, dry-activate (default: switch)"),
|
||||
mcp.Enum("switch", "boot", "test", "dry-activate"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// ListHostsTool creates the list_hosts tool definition.
|
||||
func ListHostsTool() mcp.Tool {
|
||||
return mcp.NewTool(
|
||||
"list_hosts",
|
||||
mcp.WithDescription("List available deployment targets"),
|
||||
mcp.WithString("tier",
|
||||
mcp.Description("Filter by tier: test or prod (optional)"),
|
||||
mcp.Enum("test", "prod"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// ToolHandler handles tool calls.
|
||||
type ToolHandler struct {
|
||||
cfg ToolConfig
|
||||
}
|
||||
|
||||
// NewToolHandler creates a new tool handler.
|
||||
func NewToolHandler(cfg ToolConfig) *ToolHandler {
|
||||
return &ToolHandler{cfg: cfg}
|
||||
}
|
||||
|
||||
// HandleDeploy handles the deploy tool (test-tier only).
|
||||
func (h *ToolHandler) HandleDeploy(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return h.handleDeployWithTier(ctx, request, "test", h.cfg.NKeyFile)
|
||||
}
|
||||
|
||||
// HandleDeployAdmin handles the deploy_admin tool (any tier).
|
||||
func (h *ToolHandler) HandleDeployAdmin(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
tier, err := request.RequireString("tier")
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError("tier is required"), nil
|
||||
}
|
||||
if tier != "test" && tier != "prod" {
|
||||
return mcp.NewToolResultError("tier must be 'test' or 'prod'"), nil
|
||||
}
|
||||
|
||||
return h.handleDeployWithTier(ctx, request, tier, h.cfg.AdminNKeyFile)
|
||||
}
|
||||
|
||||
func (h *ToolHandler) handleDeployWithTier(ctx context.Context, request mcp.CallToolRequest, tier, nkeyFile string) (*mcp.CallToolResult, error) {
|
||||
// Build subject based on targeting
|
||||
hostname := request.GetString("hostname", "")
|
||||
all := request.GetBool("all", false)
|
||||
role := request.GetString("role", "")
|
||||
|
||||
var subject string
|
||||
if hostname != "" {
|
||||
subject = fmt.Sprintf("deploy.%s.%s", tier, hostname)
|
||||
} else if all {
|
||||
subject = fmt.Sprintf("deploy.%s.all", tier)
|
||||
} else if role != "" {
|
||||
subject = fmt.Sprintf("deploy.%s.role.%s", tier, role)
|
||||
} else {
|
||||
return mcp.NewToolResultError("must specify hostname, all, or role"), nil
|
||||
}
|
||||
|
||||
// Parse action
|
||||
actionStr := request.GetString("action", "switch")
|
||||
action := messages.Action(actionStr)
|
||||
if !action.Valid() {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("invalid action: %s", actionStr)), nil
|
||||
}
|
||||
|
||||
// Parse branch
|
||||
branch := request.GetString("branch", "master")
|
||||
|
||||
cfg := deploycli.DeployConfig{
|
||||
NATSUrl: h.cfg.NATSUrl,
|
||||
NKeyFile: nkeyFile,
|
||||
Subject: subject,
|
||||
Action: action,
|
||||
Revision: branch,
|
||||
Timeout: h.cfg.Timeout,
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Deploying to %s (action=%s, revision=%s)\n\n", subject, action, branch))
|
||||
|
||||
result, err := deploycli.Deploy(ctx, cfg, func(resp *messages.DeployResponse) {
|
||||
status := string(resp.Status)
|
||||
if resp.Error != nil {
|
||||
status = fmt.Sprintf("%s (%s)", status, *resp.Error)
|
||||
}
|
||||
output.WriteString(fmt.Sprintf("[%s] %s: %s\n", resp.Hostname, status, resp.Message))
|
||||
})
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("deployment failed: %v", err)), nil
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("\nDeployment complete: %d hosts responded\n", result.HostCount()))
|
||||
|
||||
if !result.AllSucceeded() {
|
||||
output.WriteString("WARNING: Some deployments failed\n")
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
|
||||
// HandleListHosts handles the list_hosts tool.
|
||||
func (h *ToolHandler) HandleListHosts(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
tierFilter := request.GetString("tier", "")
|
||||
|
||||
responses, err := deploycli.Discover(ctx, h.cfg.NATSUrl, h.cfg.NKeyFile, h.cfg.DiscoverSubject, 5*time.Second)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("discovery failed: %v", err)), nil
|
||||
}
|
||||
|
||||
if len(responses) == 0 {
|
||||
return mcp.NewToolResultText("No hosts responded to discovery request"), nil
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString("Available deployment targets:\n\n")
|
||||
|
||||
for _, resp := range responses {
|
||||
if tierFilter != "" && resp.Tier != tierFilter {
|
||||
continue
|
||||
}
|
||||
|
||||
role := resp.Role
|
||||
if role == "" {
|
||||
role = "(none)"
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("- %s (tier=%s, role=%s)\n", resp.Hostname, resp.Tier, role))
|
||||
output.WriteString(fmt.Sprintf(" Subjects: %s\n", strings.Join(resp.DeploySubjects, ", ")))
|
||||
}
|
||||
|
||||
return mcp.NewToolResultText(output.String()), nil
|
||||
}
|
||||
64
internal/mcp/tools_test.go
Normal file
64
internal/mcp/tools_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewToolHandler(t *testing.T) {
|
||||
cfg := ToolConfig{
|
||||
NATSUrl: "nats://localhost:4222",
|
||||
NKeyFile: "/path/to/key",
|
||||
AdminNKeyFile: "/path/to/admin/key",
|
||||
DiscoverSubject: "deploy.discover",
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
|
||||
h := NewToolHandler(cfg)
|
||||
|
||||
if h.cfg.NATSUrl != cfg.NATSUrl {
|
||||
t.Errorf("NATSUrl = %q, want %q", h.cfg.NATSUrl, cfg.NATSUrl)
|
||||
}
|
||||
if h.cfg.NKeyFile != cfg.NKeyFile {
|
||||
t.Errorf("NKeyFile = %q, want %q", h.cfg.NKeyFile, cfg.NKeyFile)
|
||||
}
|
||||
if h.cfg.Timeout != cfg.Timeout {
|
||||
t.Errorf("Timeout = %v, want %v", h.cfg.Timeout, cfg.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployTool(t *testing.T) {
|
||||
tool := DeployTool()
|
||||
|
||||
if tool.Name != "deploy" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "deploy")
|
||||
}
|
||||
|
||||
if tool.Description == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployAdminTool(t *testing.T) {
|
||||
tool := DeployAdminTool()
|
||||
|
||||
if tool.Name != "deploy_admin" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "deploy_admin")
|
||||
}
|
||||
|
||||
if tool.Description == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListHostsTool(t *testing.T) {
|
||||
tool := ListHostsTool()
|
||||
|
||||
if tool.Name != "list_hosts" {
|
||||
t.Errorf("Name = %q, want %q", tool.Name, "list_hosts")
|
||||
}
|
||||
|
||||
if tool.Description == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
190
internal/messages/messages.go
Normal file
190
internal/messages/messages.go
Normal file
@@ -0,0 +1,190 @@
|
||||
// Package messages defines the message types used for NATS communication
|
||||
// between deployment clients and listeners.
|
||||
package messages
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Action names a nixos-rebuild action carried in a deploy request.
type Action string

// The actions accepted by Action.Valid.
const (
	ActionSwitch      Action = "switch"
	ActionBoot        Action = "boot"
	ActionTest        Action = "test"
	ActionDryActivate Action = "dry-activate"
)

// Valid reports whether a is one of the declared Action constants.
func (a Action) Valid() bool {
	switch a {
	case ActionSwitch, ActionBoot, ActionTest, ActionDryActivate:
		return true
	}
	return false
}
|
||||
|
||||
// Status is the state reported in a deployment response.
type Status string

// The statuses a deployment response can carry.
const (
	StatusAccepted  Status = "accepted"
	StatusRejected  Status = "rejected"
	StatusStarted   Status = "started"
	StatusCompleted Status = "completed"
	StatusFailed    Status = "failed"
)

// IsFinal reports whether s is a terminal status: completed, failed or
// rejected. Every other value is treated as non-final.
func (s Status) IsFinal() bool {
	switch s {
	case StatusCompleted, StatusFailed, StatusRejected:
		return true
	}
	return false
}
|
||||
|
||||
// ErrorCode is a machine-readable reason attached to a deployment response.
type ErrorCode string

// The error codes a deployment response can carry.
const (
	ErrorInvalidRevision ErrorCode = "invalid_revision"
	ErrorInvalidAction   ErrorCode = "invalid_action"
	ErrorAlreadyRunning  ErrorCode = "already_running"
	ErrorBuildFailed     ErrorCode = "build_failed"
	ErrorTimeout         ErrorCode = "timeout"
)
|
||||
|
||||
// DeployRequest is the message sent to request a deployment.
|
||||
type DeployRequest struct {
|
||||
Action Action `json:"action"`
|
||||
Revision string `json:"revision"`
|
||||
ReplyTo string `json:"reply_to"`
|
||||
}
|
||||
|
||||
// revisionRegex validates git branch names and commit hashes.
|
||||
// Allows: alphanumeric, dashes, underscores, dots, slashes (for branch names),
|
||||
// and hex strings (for commit hashes).
|
||||
var revisionRegex = regexp.MustCompile(`^[a-zA-Z0-9._/-]+$`)
|
||||
|
||||
// Validate checks that the request is valid.
|
||||
func (r *DeployRequest) Validate() error {
|
||||
if !r.Action.Valid() {
|
||||
return fmt.Errorf("invalid action: %q", r.Action)
|
||||
}
|
||||
if r.Revision == "" {
|
||||
return fmt.Errorf("revision is required")
|
||||
}
|
||||
if !revisionRegex.MatchString(r.Revision) {
|
||||
return fmt.Errorf("invalid revision format: %q", r.Revision)
|
||||
}
|
||||
if r.ReplyTo == "" {
|
||||
return fmt.Errorf("reply_to is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal serializes the request to JSON.
|
||||
func (r *DeployRequest) Marshal() ([]byte, error) {
|
||||
return json.Marshal(r)
|
||||
}
|
||||
|
||||
// UnmarshalDeployRequest deserializes a request from JSON.
|
||||
func UnmarshalDeployRequest(data []byte) (*DeployRequest, error) {
|
||||
var r DeployRequest
|
||||
if err := json.Unmarshal(data, &r); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal deploy request: %w", err)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// DeployResponse is the message sent in response to a deployment request.
|
||||
type DeployResponse struct {
|
||||
Hostname string `json:"hostname"`
|
||||
Status Status `json:"status"`
|
||||
Error *ErrorCode `json:"error"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewDeployResponse creates a new response with the given hostname and status.
|
||||
func NewDeployResponse(hostname string, status Status, message string) *DeployResponse {
|
||||
return &DeployResponse{
|
||||
Hostname: hostname,
|
||||
Status: status,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// WithError adds an error code to the response.
|
||||
func (r *DeployResponse) WithError(code ErrorCode) *DeployResponse {
|
||||
r.Error = &code
|
||||
return r
|
||||
}
|
||||
|
||||
// Marshal serializes the response to JSON.
|
||||
func (r *DeployResponse) Marshal() ([]byte, error) {
|
||||
return json.Marshal(r)
|
||||
}
|
||||
|
||||
// UnmarshalDeployResponse deserializes a response from JSON.
|
||||
func UnmarshalDeployResponse(data []byte) (*DeployResponse, error) {
|
||||
var r DeployResponse
|
||||
if err := json.Unmarshal(data, &r); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal deploy response: %w", err)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// DiscoveryRequest is the JSON message a client sends to find available
// hosts.
type DiscoveryRequest struct {
	ReplyTo string `json:"reply_to"`
}

// Validate returns an error when reply_to is empty and nil otherwise.
func (r *DiscoveryRequest) Validate() error {
	if r.ReplyTo != "" {
		return nil
	}
	return fmt.Errorf("reply_to is required")
}

// Marshal encodes the request as JSON.
func (r *DiscoveryRequest) Marshal() ([]byte, error) {
	return json.Marshal(r)
}

// UnmarshalDiscoveryRequest decodes a DiscoveryRequest from JSON. The result
// is not validated; on a decode error it returns nil and a wrapped error.
func UnmarshalDiscoveryRequest(data []byte) (*DiscoveryRequest, error) {
	request := new(DiscoveryRequest)
	if err := json.Unmarshal(data, request); err != nil {
		return nil, fmt.Errorf("failed to unmarshal discovery request: %w", err)
	}
	return request, nil
}
|
||||
|
||||
// DiscoveryResponse is the JSON message a host sends in reply to a discovery
// request. Role is left out of the encoding when empty.
type DiscoveryResponse struct {
	Hostname       string   `json:"hostname"`
	Tier           string   `json:"tier"`
	Role           string   `json:"role,omitempty"`
	DeploySubjects []string `json:"deploy_subjects"`
}

// Marshal encodes the response as JSON.
func (r *DiscoveryResponse) Marshal() ([]byte, error) {
	return json.Marshal(r)
}

// UnmarshalDiscoveryResponse decodes a DiscoveryResponse from JSON; on a
// decode error it returns nil and a wrapped error.
func UnmarshalDiscoveryResponse(data []byte) (*DiscoveryResponse, error) {
	response := new(DiscoveryResponse)
	if err := json.Unmarshal(data, response); err != nil {
		return nil, fmt.Errorf("failed to unmarshal discovery response: %w", err)
	}
	return response, nil
}
|
||||
292
internal/messages/messages_test.go
Normal file
292
internal/messages/messages_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAction_Valid(t *testing.T) {
|
||||
tests := []struct {
|
||||
action Action
|
||||
valid bool
|
||||
}{
|
||||
{ActionSwitch, true},
|
||||
{ActionBoot, true},
|
||||
{ActionTest, true},
|
||||
{ActionDryActivate, true},
|
||||
{Action("invalid"), false},
|
||||
{Action(""), false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(string(tc.action), func(t *testing.T) {
|
||||
if got := tc.action.Valid(); got != tc.valid {
|
||||
t.Errorf("Action(%q).Valid() = %v, want %v", tc.action, got, tc.valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus_IsFinal(t *testing.T) {
|
||||
tests := []struct {
|
||||
status Status
|
||||
final bool
|
||||
}{
|
||||
{StatusAccepted, false},
|
||||
{StatusStarted, false},
|
||||
{StatusCompleted, true},
|
||||
{StatusFailed, true},
|
||||
{StatusRejected, true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(string(tc.status), func(t *testing.T) {
|
||||
if got := tc.status.IsFinal(); got != tc.final {
|
||||
t.Errorf("Status(%q).IsFinal() = %v, want %v", tc.status, got, tc.final)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployRequest_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req DeployRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid request with branch",
|
||||
req: DeployRequest{
|
||||
Action: ActionSwitch,
|
||||
Revision: "master",
|
||||
ReplyTo: "deploy.responses.abc123",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid request with commit hash",
|
||||
req: DeployRequest{
|
||||
Action: ActionBoot,
|
||||
Revision: "abc123def456",
|
||||
ReplyTo: "deploy.responses.xyz",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid request with feature branch",
|
||||
req: DeployRequest{
|
||||
Action: ActionTest,
|
||||
Revision: "feature/my-feature",
|
||||
ReplyTo: "deploy.responses.test",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid request with dotted branch",
|
||||
req: DeployRequest{
|
||||
Action: ActionDryActivate,
|
||||
Revision: "release-1.0.0",
|
||||
ReplyTo: "deploy.responses.test",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid action",
|
||||
req: DeployRequest{
|
||||
Action: Action("invalid"),
|
||||
Revision: "master",
|
||||
ReplyTo: "deploy.responses.abc",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty revision",
|
||||
req: DeployRequest{
|
||||
Action: ActionSwitch,
|
||||
Revision: "",
|
||||
ReplyTo: "deploy.responses.abc",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid revision with spaces",
|
||||
req: DeployRequest{
|
||||
Action: ActionSwitch,
|
||||
Revision: "my branch",
|
||||
ReplyTo: "deploy.responses.abc",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid revision with special chars",
|
||||
req: DeployRequest{
|
||||
Action: ActionSwitch,
|
||||
Revision: "branch;rm -rf /",
|
||||
ReplyTo: "deploy.responses.abc",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty reply_to",
|
||||
req: DeployRequest{
|
||||
Action: ActionSwitch,
|
||||
Revision: "master",
|
||||
ReplyTo: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.req.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployRequest_Marshal_Unmarshal(t *testing.T) {
|
||||
req := &DeployRequest{
|
||||
Action: ActionSwitch,
|
||||
Revision: "master",
|
||||
ReplyTo: "deploy.responses.abc123",
|
||||
}
|
||||
|
||||
data, err := req.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := UnmarshalDeployRequest(data)
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalDeployRequest() error = %v", err)
|
||||
}
|
||||
|
||||
if got.Action != req.Action || got.Revision != req.Revision || got.ReplyTo != req.ReplyTo {
|
||||
t.Errorf("roundtrip failed: got %+v, want %+v", got, req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployResponse_Marshal_Unmarshal(t *testing.T) {
|
||||
errCode := ErrorBuildFailed
|
||||
resp := &DeployResponse{
|
||||
Hostname: "host1",
|
||||
Status: StatusFailed,
|
||||
Error: &errCode,
|
||||
Message: "build failed with exit code 1",
|
||||
}
|
||||
|
||||
data, err := resp.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := UnmarshalDeployResponse(data)
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalDeployResponse() error = %v", err)
|
||||
}
|
||||
|
||||
if got.Hostname != resp.Hostname || got.Status != resp.Status || got.Message != resp.Message {
|
||||
t.Errorf("roundtrip failed: got %+v, want %+v", got, resp)
|
||||
}
|
||||
if got.Error == nil || *got.Error != *resp.Error {
|
||||
t.Errorf("error code mismatch: got %v, want %v", got.Error, resp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeployResponse_NullError(t *testing.T) {
|
||||
resp := NewDeployResponse("host1", StatusCompleted, "success")
|
||||
|
||||
data, err := resp.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify null error is serialized correctly
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if m["error"] != nil {
|
||||
t.Errorf("expected null error, got %v", m["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoveryRequest_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req DiscoveryRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
req: DiscoveryRequest{ReplyTo: "deploy.responses.discover-abc"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty reply_to",
|
||||
req: DiscoveryRequest{ReplyTo: ""},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.req.Validate()
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoveryResponse_Marshal_Unmarshal(t *testing.T) {
|
||||
resp := &DiscoveryResponse{
|
||||
Hostname: "ns1",
|
||||
Tier: "prod",
|
||||
Role: "dns",
|
||||
DeploySubjects: []string{"deploy.prod.ns1", "deploy.prod.all", "deploy.prod.role.dns"},
|
||||
}
|
||||
|
||||
data, err := resp.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
got, err := UnmarshalDiscoveryResponse(data)
|
||||
if err != nil {
|
||||
t.Fatalf("UnmarshalDiscoveryResponse() error = %v", err)
|
||||
}
|
||||
|
||||
if got.Hostname != resp.Hostname || got.Tier != resp.Tier || got.Role != resp.Role {
|
||||
t.Errorf("roundtrip failed: got %+v, want %+v", got, resp)
|
||||
}
|
||||
if len(got.DeploySubjects) != len(resp.DeploySubjects) {
|
||||
t.Errorf("subjects length mismatch: got %d, want %d", len(got.DeploySubjects), len(resp.DeploySubjects))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoveryResponse_OmitEmptyRole(t *testing.T) {
|
||||
resp := &DiscoveryResponse{
|
||||
Hostname: "host1",
|
||||
Tier: "test",
|
||||
Role: "",
|
||||
DeploySubjects: []string{"deploy.test.host1"},
|
||||
}
|
||||
|
||||
data, err := resp.Marshal()
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := m["role"]; exists {
|
||||
t.Error("expected role to be omitted when empty")
|
||||
}
|
||||
}
|
||||
142
internal/nats/client.go
Normal file
142
internal/nats/client.go
Normal file
@@ -0,0 +1,142 @@
|
||||
// Package nats provides a NATS client wrapper with NKey authentication.
|
||||
package nats
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
// Config carries the settings Connect needs to open a NATS connection.
type Config struct {
	// URL is the NATS server URL.
	URL string
	// NKeyFile is the path to the NKey seed file.
	NKeyFile string
	// Name is the client name used for identification.
	Name string
}
|
||||
|
||||
// Client wraps a NATS connection with NKey authentication.
|
||||
type Client struct {
|
||||
conn *nats.Conn
|
||||
}
|
||||
|
||||
// Connect establishes a connection to NATS using NKey authentication.
|
||||
func Connect(cfg Config) (*Client, error) {
|
||||
seed, err := os.ReadFile(cfg.NKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read nkey file: %w", err)
|
||||
}
|
||||
|
||||
// Trim any whitespace from the seed
|
||||
seedStr := strings.TrimSpace(string(seed))
|
||||
|
||||
kp, err := nkeys.FromSeed([]byte(seedStr))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse nkey seed: %w", err)
|
||||
}
|
||||
|
||||
pubKey, err := kp.PublicKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
opts := []nats.Option{
|
||||
nats.Name(cfg.Name),
|
||||
nats.Nkey(pubKey, func(nonce []byte) ([]byte, error) {
|
||||
return kp.Sign(nonce)
|
||||
}),
|
||||
nats.ReconnectWait(2 * time.Second),
|
||||
nats.MaxReconnects(-1), // Unlimited reconnects
|
||||
nats.ReconnectBufSize(8 * 1024 * 1024),
|
||||
}
|
||||
|
||||
nc, err := nats.Connect(cfg.URL, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to NATS: %w", err)
|
||||
}
|
||||
|
||||
return &Client{conn: nc}, nil
|
||||
}
|
||||
|
||||
// Subscription represents a NATS subscription.
|
||||
type Subscription struct {
|
||||
sub *nats.Subscription
|
||||
}
|
||||
|
||||
// MessageHandler is the callback invoked for every message delivered to a
// subscription. It receives the subject the message arrived on and the raw
// message payload.
type MessageHandler func(subject string, data []byte)
|
||||
|
||||
// Subscribe subscribes to a subject and calls the handler for each message.
|
||||
func (c *Client) Subscribe(subject string, handler MessageHandler) (*Subscription, error) {
|
||||
sub, err := c.conn.Subscribe(subject, func(msg *nats.Msg) {
|
||||
handler(msg.Subject, msg.Data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to subscribe to %s: %w", subject, err)
|
||||
}
|
||||
return &Subscription{sub: sub}, nil
|
||||
}
|
||||
|
||||
// QueueSubscribe subscribes to a subject with a queue group.
|
||||
func (c *Client) QueueSubscribe(subject, queue string, handler MessageHandler) (*Subscription, error) {
|
||||
sub, err := c.conn.QueueSubscribe(subject, queue, func(msg *nats.Msg) {
|
||||
handler(msg.Subject, msg.Data)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to queue subscribe to %s: %w", subject, err)
|
||||
}
|
||||
return &Subscription{sub: sub}, nil
|
||||
}
|
||||
|
||||
// Unsubscribe removes the subscription.
|
||||
func (s *Subscription) Unsubscribe() error {
|
||||
if s.sub != nil {
|
||||
return s.sub.Unsubscribe()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish sends a message to a subject.
|
||||
func (c *Client) Publish(subject string, data []byte) error {
|
||||
if err := c.conn.Publish(subject, data); err != nil {
|
||||
return fmt.Errorf("failed to publish to %s: %w", subject, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Request sends a request and waits for a response.
|
||||
func (c *Client) Request(subject string, data []byte, timeout time.Duration) ([]byte, error) {
|
||||
msg, err := c.conn.Request(subject, data, timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request to %s failed: %w", subject, err)
|
||||
}
|
||||
return msg.Data, nil
|
||||
}
|
||||
|
||||
// Flush flushes the connection, ensuring all published messages have been sent.
|
||||
func (c *Client) Flush() error {
|
||||
return c.conn.Flush()
|
||||
}
|
||||
|
||||
// Close closes the NATS connection.
|
||||
func (c *Client) Close() {
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected returns true if the client is connected.
|
||||
func (c *Client) IsConnected() bool {
|
||||
return c.conn != nil && c.conn.IsConnected()
|
||||
}
|
||||
|
||||
// Status returns the connection status.
|
||||
func (c *Client) Status() nats.Status {
|
||||
if c.conn == nil {
|
||||
return nats.DISCONNECTED
|
||||
}
|
||||
return c.conn.Status()
|
||||
}
|
||||
127
internal/nats/client_test.go
Normal file
127
internal/nats/client_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package nats
|
||||
|
||||
import (
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/nats-io/nkeys"
)
|
||||
|
||||
func TestConnect_InvalidNKeyFile(t *testing.T) {
|
||||
cfg := Config{
|
||||
URL: "nats://localhost:4222",
|
||||
NKeyFile: "/nonexistent/file",
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
_, err := Connect(cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent nkey file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect_InvalidNKeySeed(t *testing.T) {
|
||||
// Create a temp file with invalid content
|
||||
tmpDir := t.TempDir()
|
||||
keyFile := filepath.Join(tmpDir, "invalid.nkey")
|
||||
if err := os.WriteFile(keyFile, []byte("invalid-seed-content"), 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
URL: "nats://localhost:4222",
|
||||
NKeyFile: keyFile,
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
_, err := Connect(cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid nkey seed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect_ValidSeedParsing(t *testing.T) {
|
||||
// Generate a valid NKey seed
|
||||
kp, err := nkeys.CreateUser()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nkey: %v", err)
|
||||
}
|
||||
|
||||
seed, err := kp.Seed()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get seed: %v", err)
|
||||
}
|
||||
|
||||
// Write seed to temp file
|
||||
tmpDir := t.TempDir()
|
||||
keyFile := filepath.Join(tmpDir, "test.nkey")
|
||||
if err := os.WriteFile(keyFile, seed, 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
URL: "nats://localhost:4222", // Connection will fail, but parsing should work
|
||||
NKeyFile: keyFile,
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
// Connection will fail since no NATS server is running, but we're testing
|
||||
// that the seed parsing works correctly
|
||||
_, err = Connect(cfg)
|
||||
if err == nil {
|
||||
// If it somehow connects (unlikely), that's also fine
|
||||
return
|
||||
}
|
||||
|
||||
// Error should be about connection, not about nkey parsing
|
||||
if err != nil && !contains(err.Error(), "connect") && !contains(err.Error(), "connection") {
|
||||
t.Errorf("expected connection error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect_SeedWithWhitespace(t *testing.T) {
|
||||
// Generate a valid NKey seed
|
||||
kp, err := nkeys.CreateUser()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nkey: %v", err)
|
||||
}
|
||||
|
||||
seed, err := kp.Seed()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get seed: %v", err)
|
||||
}
|
||||
|
||||
// Write seed with trailing newline
|
||||
tmpDir := t.TempDir()
|
||||
keyFile := filepath.Join(tmpDir, "test.nkey")
|
||||
seedWithNewline := append(seed, '\n', ' ', '\t', '\n')
|
||||
if err := os.WriteFile(keyFile, seedWithNewline, 0600); err != nil {
|
||||
t.Fatalf("failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
URL: "nats://localhost:4222",
|
||||
NKeyFile: keyFile,
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
// Should parse the seed correctly despite whitespace
|
||||
_, err = Connect(cfg)
|
||||
if err != nil && contains(err.Error(), "parse") {
|
||||
t.Errorf("seed parsing should handle whitespace: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// contains reports whether substr occurs within s. It is a thin wrapper
// around strings.Contains, kept so existing call sites in the tests continue
// to work.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||||
Reference in New Issue
Block a user