From 9cdcff72908afa34a38a6d9dd9b917182c7e2e4d Mon Sep 17 00:00:00 2001
From: Dan Fuhry
Date: Thu, 15 May 2025 17:45:46 -0400
Subject: [PATCH] add mtls/supervisor command

---
 .gitignore                   |   1 +
 mtls/supervisor/Makefile     |  16 +++
 mtls/supervisor/main.go      | 263 +++++++++++++++++++++++++++++++++++
 mtls/supervisor/test_mode.go | 191 +++++++++++++++++++++++++
 4 files changed, 471 insertions(+)
 create mode 100644 mtls/supervisor/Makefile
 create mode 100644 mtls/supervisor/main.go
 create mode 100644 mtls/supervisor/test_mode.go

diff --git a/.gitignore b/.gitignore
index 43c001f..a3fd46e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -54,6 +54,7 @@ automation/bryston_ctl/cli/cli
 automation/bryston_ctl/client/client
 automation/bryston_ctl/server/server
 metrics/prometheus_http_discovery/prometheus_http_discovery
+mtls/supervisor/supervisor
 
 /vendor/
 
diff --git a/mtls/supervisor/Makefile b/mtls/supervisor/Makefile
new file mode 100644
index 0000000..1d7d424
--- /dev/null
+++ b/mtls/supervisor/Makefile
@@ -0,0 +1,16 @@
+GOSRC = $(wildcard *.go)
+GOEXE = supervisor
+GOBUILDFLAGS := -buildmode=pie -trimpath
+
+all: $(GOEXE)
+
+clean:
+	rm -fv $(GOEXE)
+
+.PHONY: all clean
+
+# Both source files belong to the same package main, so they must be compiled
+# together into a single binary rather than one binary per .go file.
+$(GOEXE): $(GOSRC)
+	go build $(GOBUILDFLAGS) -o $@ .
+
diff --git a/mtls/supervisor/main.go b/mtls/supervisor/main.go
new file mode 100644
index 0000000..6e0c213
--- /dev/null
+++ b/mtls/supervisor/main.go
@@ -0,0 +1,263 @@
+// Command supervisor runs a subprocess and restarts it whenever one of the
+// mTLS files it monitors changes, or when it receives SIGHUP or SIGUSR1.
+package main
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"io"
+	"os"
+	"os/exec"
+	"os/signal"
+	"path"
+	"path/filepath"
+	"strings"
+	"sync"
+	"syscall"
+	"time"
+
+	"go.fuhry.dev/runtime/mtls/fsnotify"
+	"go.fuhry.dev/runtime/utils/debounce"
+	"go.fuhry.dev/runtime/utils/hashset"
+	"go.fuhry.dev/runtime/utils/log"
+)
+
+func main() {
+	// restartLock serializes restarts; the main loop treats a subprocess exit
+	// as intentional while it is held.
+	var restartLock sync.Mutex
+	var mtlsFiles []string
+	// exitChan carries a notification each time the subprocess exits.
+	exitChan := make(chan struct{})
+
+	if os.Getenv("MTLS_SUPERVISOR_DEBUG") == "1" {
+		log.SetVerbosity(2)
+		log.SetLevel(log.DEBUG)
+	}
+
+	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+	logger := log.Default().WithPrefix("supervisor")
+	defer cancel()
+
+	if len(os.Args) < 2 {
+		logger.Panicf("usage: %s command-to-run", os.Args[0])
+	}
+	cmdline := os.Args[1:]
+	abspath, err := exec.LookPath(cmdline[0])
+	if err != nil {
+		logger.Panic(err)
+		return
+	}
+	command := exec.Command(abspath, cmdline[1:]...)
+	command.Stdin = os.Stdin
+	command.Stdout = os.Stdout
+	command.Stderr = os.Stderr
+
+	// Only files with one of these basenames are hashed and monitored.
+	mtlsBasenames := hashset.FromSlice[string]([]string{
+		"cert.pem",
+		"chain.pem",
+		"fullchain.pem",
+		"privkey.pem",
+		"aio.pem",
+		"aio.p12",
+		"ca.pem",
+		"rootca.pem",
+	})
+
+	// hashes maps each monitored path to the digest of its last seen contents.
+	hashes := make(map[string][]byte)
+	var hashesLock sync.Mutex
+
+	// scanPath records every matching file in dir, descending one level into
+	// subdirectories whose names do not start with "..".
+	var scanPath func(dir string, depth int)
+	scanPath = func(dir string, depth int) {
+		// Only the outermost call takes the lock; the recursive call runs
+		// while it is already held.
+		if depth == 0 {
+			hashesLock.Lock()
+			defer hashesLock.Unlock()
+		}
+
+		entries, err := os.ReadDir(dir)
+		if err != nil {
+			logger.Warningf("failed to scan directory %q: %v", dir, err)
+			return
+		}
+		for _, ent := range entries {
+			abspath := path.Join(dir, ent.Name())
+			if ent.IsDir() && depth == 0 && !strings.HasPrefix(ent.Name(), "..") {
+				logger.Infof("descending to %s", abspath)
+				scanPath(abspath, depth+1)
+				continue
+			}
+			if !mtlsBasenames.Contains(ent.Name()) {
+				continue
+			}
+			hash, err := hashFile(abspath)
+			if err != nil {
+				logger.Warningf("failed hashing file %q: %v", abspath, err)
+				continue
+			}
+			logger.V(1).Debugf("file %q hash: %s", abspath, bin2hex(hash))
+			hashes[abspath] = hash
+			mtlsFiles = append(mtlsFiles, abspath)
+		}
+	}
+
+	// stop sends SIGTERM to the subprocess and waits for its exit to be
+	// reported on exitChan, sending SIGKILL each time 5 seconds pass without one.
+	stop := func() {
+		err := syscall.Kill(command.Process.Pid, syscall.SIGTERM)
+		if err != nil {
+			logger.Warningf("failed to kill process %d: %v", command.Process.Pid, err)
+			return
+		}
+
+		// we sent sigterm, now wait for it to die
+		timeout := time.After(5 * time.Second)
+		logger.Infof("waiting for pid %d to die", command.Process.Pid)
+		for {
+			select {
+			case <-exitChan:
+				logger.Infof("subprocess has died, ok to restart")
+				// Clear the process state so that command can be started again.
+				command.Process, command.ProcessState = nil, nil
+				return
+			case <-timeout:
+				if command.Process != nil && command.ProcessState == nil {
+					logger.Warningf("deadline exceeded, killing pid %d", command.Process.Pid)
+					syscall.Kill(command.Process.Pid, syscall.SIGKILL)
+					timeout = time.After(5 * time.Second)
+				}
+			}
+		}
+	}
+
+	// restart stops the subprocess if one is running, then starts it and
+	// spawns a goroutine that reports its exit on exitChan.
+	restart := func() {
+		var err error
+
+		restartLock.Lock()
+		defer restartLock.Unlock()
+		if command.Process != nil {
+			logger.Noticef("restarting subprocess: %+v", command)
+			stop()
+		} else {
+			logger.Infof("starting subprocess: %+v", command)
+		}
+
+		err = command.Start()
+		if err != nil {
+			logger.Panicf("failed to start subprocess: %v", err)
+		}
+		go (func() {
+			command.Wait()
+			exitChan <- struct{}{}
+		})()
+	}
+
+	baseDir := "/etc/ssl/mtls"
+	if os.Getenv("MTLS_SUPERVISOR_DEBUG") == "1" {
+		baseDir = startTestTree(ctx)
+		defer os.RemoveAll(baseDir)
+	}
+	scanPath(baseDir, 0)
+
+	if len(mtlsFiles) < 1 {
+		logger.Panicf("no files were discovered under %s", baseDir)
+		return
+	}
+
+	restart()
+
+	debouncer := debounce.NewWithTimeout(restart, 500*time.Millisecond)
+	notify := func(path string, op fsnotify.Op) {
+		if !mtlsBasenames.Contains(filepath.Base(path)) {
+			return
+		}
+
+		hashesLock.Lock()
+		defer hashesLock.Unlock()
+		if hash, ok := hashes[path]; ok {
+			newHash, err := hashFile(path)
+			if err != nil {
+				logger.Warnf("failed hashing file %q: %v", path, err)
+				return
+			}
+			if bytes.Equal(hash, newHash) {
+				logger.Infof("%s: contents not changed, so not triggering reload", path)
+			} else {
+				hashes[path] = newHash
+				logger.Noticef("file %q changed, subprocess will be restarted: %s != %s", path, bin2hex(hash), bin2hex(newHash))
+				debouncer.Trigger()
+			}
+		}
+	}
+
+	for _, f := range mtlsFiles {
+		logger.Infof("monitoring path: %s", f)
+		fsnotify.NotifyPath(f, notify)
+	}
+
+	// The channel must be buffered: package signal does not block when
+	// sending, so a signal arriving while the receiver is busy would be lost.
+	signalChan := make(chan os.Signal, 1)
+	signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGUSR1)
+	go (func() {
+		for sig := range signalChan {
+			logger.Noticef("Received signal %s, restarting program", sig.String())
+			go restart()
+		}
+	})()
+
+	fsnotify.Debug()
+	for {
+		select {
+		case <-exitChan:
+			// subprocess exited, but was it controlled?
+			if ok := restartLock.TryLock(); ok {
+				restartLock.Unlock()
+				if command.ProcessState != nil && command.Process != nil {
+					logger.Noticef("uncontrolled exit of subprocess with pid %d and status %d", command.Process.Pid, command.ProcessState.ExitCode())
+					os.Exit(command.ProcessState.ExitCode())
+				}
+			} else {
+				// proxy notification to restart monitor
+				exitChan <- struct{}{}
+			}
+		case <-ctx.Done():
+			logger.Noticef("signal received, stopping subprocess")
+			// Wait for any in-flight restart to finish before touching command,
+			// and keep the lock until exit so that no new restart can begin.
+			restartLock.Lock()
+			stop()
+			return
+		}
+	}
+}
+
+// hashFile returns the SHA-256 digest of the contents of filePath. An error
+// is returned if the file cannot be opened or fully read.
+func hashFile(filePath string) ([]byte, error) {
+	fp, err := os.Open(filePath)
+	if err != nil {
+		return nil, err
+	}
+	defer fp.Close()
+
+	hash := sha256.New()
+	if _, err := io.Copy(hash, fp); err != nil {
+		return nil, err
+	}
+	return hash.Sum(nil), nil
+}
+
+// bin2hex returns the lowercase hexadecimal encoding of bin.
+func bin2hex(bin []byte) string {
+	return hex.EncodeToString(bin)
+}
diff --git a/mtls/supervisor/test_mode.go b/mtls/supervisor/test_mode.go
new file mode 100644
index 0000000..4f273fc
--- /dev/null
+++ b/mtls/supervisor/test_mode.go
@@ -0,0 +1,191 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"io/fs"
+	"math/rand"
+	"os"
+	"path"
+	"time"
+
+	"golang.org/x/sys/unix"
+
+	"go.fuhry.dev/runtime/utils/log"
+)
+
+// Names and permissions used for the test tree: dataLinkName is the symlink
+// that points at the current data directory, which is named using dateFormat.
+const (
+	dataLinkName = "..data"
+	dateFormat   = "..2006_01_02_15_04_05.9999999999"
+	fileMode     = os.FileMode(0600)
+	dirMode      = os.FileMode(0700)
+)
+
+var logger log.Logger
+
+func init() {
+	logger = log.Default().WithPrefix("SUP DBG")
+}
+
+// randhex returns the hexadecimal encoding of n pseudo-random bytes, i.e. a
+// string of 2*n characters.
+func randhex(n int) string {
+	reader := rand.New(rand.NewSource(time.Now().UnixNano()))
+	buf := make([]byte, n)
+	for i := 0; i < n; {
+		nr, err := reader.Read(buf[i:])
+		if err != nil {
+			panic(err)
+		}
+		i += nr
+	}
+
+	return bin2hex(buf)
+}
+
+// writeTestFile writes 128 random hexadecimal characters to filePath,
+// creating the file if needed, and panics on any error.
+func writeTestFile(filePath string) {
+	contents := randhex(64)
+	fp, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY, fileMode)
+	if err != nil {
+		panic(err)
+	}
+	defer fp.Close()
+	for i := 0; i < len(contents); {
+		nw, err := fp.Write([]byte(contents[i:]))
+		if err != nil {
+			panic(err)
+		}
+		i += nw
+	}
+}
+
+// renameExchange atomically swaps the directory entries src and dest using
+// renameat2(2) with RENAME_EXCHANGE.
+func renameExchange(src, dest string) error {
+	srcDir := path.Dir(src)
+	destDir := path.Dir(dest)
+
+	// unix.Open takes the open flags as its second argument and the creation
+	// mode as its third.
+	srcfd, err := unix.Open(srcDir, unix.O_RDONLY|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
+	if err != nil {
+		return fmt.Errorf("while opening srcdir %s: %w", srcDir, err)
+	}
+	defer unix.Close(srcfd)
+
+	destfd := srcfd
+	if srcDir != destDir {
+		destfd, err = unix.Open(destDir, unix.O_RDONLY|unix.O_DIRECTORY|unix.O_CLOEXEC, 0)
+		if err != nil {
+			return fmt.Errorf("while open destdir %s: %w", destDir, err)
+		}
+		defer unix.Close(destfd)
+	}
+
+	// The names are resolved relative to the directory descriptors, so pass
+	// only the final path elements.
+	return unix.Renameat2(srcfd, path.Base(src), destfd, path.Base(dest), unix.RENAME_EXCHANGE)
+}
+
+// createTestTree creates a timestamped directory under basedir with a random
+// file per name in filenames, repoints the dataLinkName symlink at it, ensures
+// basedir/<name> links through dataLinkName, and returns the directory's name.
+func createTestTree(basedir string, filenames []string) string {
+	entryDir := time.Now().Format(dateFormat)
+	dir := path.Join(basedir, entryDir)
+	logger.V(2).Infof("createTestTree %s -> %+v", dir, filenames)
+	if err := os.MkdirAll(dir, dirMode); err != nil {
+		panic(err)
+	}
+	dataPath := path.Join(basedir, dataLinkName)
+	logger.V(2).Infof("symlink %s -> %s", dataPath, entryDir)
+	if isSymlink(dataPath) {
+		// Swap the new link into place, then remove the old one left at .new.
+		if err := os.Symlink(entryDir, dataPath+".new"); err != nil {
+			panic(err)
+		}
+		if err := renameExchange(dataPath+".new", dataPath); err != nil {
+			panic(err)
+		}
+		if err := os.Remove(dataPath + ".new"); err != nil {
+			panic(err)
+		}
+	} else {
+		if err := os.Symlink(entryDir, dataPath); err != nil {
+			panic(err)
+		}
+	}
+	for _, f := range filenames {
+		writeTestFile(path.Join(dir, f))
+
+		baseFile := path.Join(basedir, f)
+		if !isSymlink(baseFile) {
+			logger.V(2).Infof("symlink %s -> %s", baseFile, path.Join(dataLinkName, f))
+			if err := os.Symlink(path.Join(dataLinkName, f), baseFile); err != nil {
+				panic(err)
+			}
+		}
+	}
+
+	return entryDir
+}
+
+// maintainTestTree replaces the test tree every 5 seconds, removing the
+// previous data directory each time, until ctx is done.
+func maintainTestTree(ctx context.Context, basedir string, curDir string, filenames []string) {
+	ticker := time.NewTicker(5 * time.Second)
+	defer ticker.Stop()
+
+	defer func() {
+		// put this in an anon func to use current value of curDir
+		os.RemoveAll(path.Join(basedir, curDir))
+	}()
+	for {
+		select {
+		case <-ticker.C:
+			newDir := createTestTree(basedir, filenames)
+			os.RemoveAll(path.Join(basedir, curDir))
+			curDir = newDir
+		case <-ctx.Done():
+			return
+		}
+	}
+}
+
+// startTestTree creates a test tree under /tmp/supervisor.<pid>, starts a
+// goroutine that keeps regenerating it, and returns the base directory.
+func startTestTree(ctx context.Context) string {
+	rootPathFiles := []string{
+		"ca.pem",
+		// "rootca.pem",
+		// "step-ca.json",
+	}
+	// certDirFiles := []string{
+	// 	"fullchain.pem",
+	// 	"privkey.pem",
+	// 	"ca.pem",
+	// }
+
+	baseDir := fmt.Sprintf("/tmp/supervisor.%d", os.Getpid())
+
+	rootDir := createTestTree(baseDir, rootPathFiles)
+	// certDir := createTestTree(path.Join(baseDir, "testcert"), certDirFiles)
+
+	go maintainTestTree(ctx, baseDir, rootDir, rootPathFiles)
+	// go maintainTestTree(ctx, path.Join(baseDir, "testcert"), certDir, certDirFiles)
+
+	return baseDir
+}
+
+// isSymlink reports whether fullPath exists and is a symbolic link.
+func isSymlink(fullPath string) bool {
+	if stat, err := os.Lstat(fullPath); err == nil {
+		return stat.Mode()&fs.ModeSymlink == fs.ModeSymlink
+	}
+
+	return false
+}
-- 
2.50.1