]> go.fuhry.dev Git - runtime.git/commitdiff
add mtls/supervisor command
authorDan Fuhry <dan@fuhry.com>
Thu, 15 May 2025 21:45:46 +0000 (17:45 -0400)
committerDan Fuhry <dan@fuhry.com>
Thu, 15 May 2025 21:45:46 +0000 (17:45 -0400)
.gitignore
mtls/supervisor/Makefile [new file with mode: 0644]
mtls/supervisor/main.go [new file with mode: 0644]
mtls/supervisor/test_mode.go [new file with mode: 0644]

index 43c001f0c18239212b9ddc743394ed6c7d01965d..a3fd46ecc1616432aa2e75eaf9d1f13c72fee33f 100644 (file)
@@ -54,6 +54,7 @@ automation/bryston_ctl/cli/cli
 automation/bryston_ctl/client/client
 automation/bryston_ctl/server/server
 metrics/prometheus_http_discovery/prometheus_http_discovery
+mtls/supervisor/supervisor
 
 /vendor/
 
diff --git a/mtls/supervisor/Makefile b/mtls/supervisor/Makefile
new file mode 100644 (file)
index 0000000..1d7d424
--- /dev/null
@@ -0,0 +1,14 @@
+GOSRC = $(wildcard *.go)
+GOEXE = $(GOSRC:.go=)
+GOBUILDFLAGS := -buildmode=pie -trimpath
+
+all: $(GOEXE)
+
+clean:
+       rm -fv $(GOEXE)
+
+.PHONY: all clean
+
+$(GOEXE): %: %.go
+       go build $(GOBUILDFLAGS) -o $@ $<
+
diff --git a/mtls/supervisor/main.go b/mtls/supervisor/main.go
new file mode 100644 (file)
index 0000000..6e0c213
--- /dev/null
@@ -0,0 +1,236 @@
+package main
+
+import (
+       "bytes"
+       "context"
+       "crypto"
+       "fmt"
+       "io"
+       "os"
+       "os/exec"
+       "os/signal"
+       "path"
+       "path/filepath"
+       "strings"
+       "sync"
+       "syscall"
+       "time"
+
+       "go.fuhry.dev/runtime/mtls/fsnotify"
+       "go.fuhry.dev/runtime/utils/debounce"
+       "go.fuhry.dev/runtime/utils/hashset"
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+func main() {
+       var restartLock sync.Mutex
+       var mtlsFiles []string
+       exitChan := make(chan struct{})
+
+       if os.Getenv("MTLS_SUPERVISOR_DEBUG") == "1" {
+               log.SetVerbosity(2)
+               log.SetLevel(log.DEBUG)
+       }
+
+       ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+       logger := log.Default().WithPrefix("supervisor")
+       defer cancel()
+
+       if len(os.Args) < 2 {
+               logger.Panicf("usage: %s command-to-run", os.Args[0])
+       }
+       cmdline := os.Args[1:]
+       abspath, err := exec.LookPath(cmdline[0])
+       if err != nil {
+               logger.Panic(err)
+               return
+       }
+       command := exec.Command(abspath, cmdline[1:]...)
+       command.Stdin = os.Stdin
+       command.Stdout = os.Stdout
+       command.Stderr = os.Stderr
+
+       mtlsBasenames := hashset.FromSlice[string]([]string{
+               "cert.pem",
+               "chain.pem",
+               "fullchain.pem",
+               "privkey.pem",
+               "aio.pem",
+               "aio.p12",
+               "ca.pem",
+               "rootca.pem",
+       })
+
+       hashes := make(map[string][]byte)
+       var hashesLock sync.Mutex
+
+       var scanPath func(dir string, depth int)
+       scanPath = func(dir string, depth int) {
+               if depth == 0 {
+                       hashesLock.Lock()
+                       defer hashesLock.Unlock()
+               }
+
+               entries, err := os.ReadDir(dir)
+               if err != nil {
+                       logger.Warningf("failed to scan directory %q: %v", dir, err)
+                       return
+               }
+               for _, ent := range entries {
+                       abspath := path.Join(dir, ent.Name())
+                       if ent.IsDir() && depth == 0 && !strings.HasPrefix(ent.Name(), "..") {
+                               logger.Infof("descending to %s", abspath)
+                               scanPath(abspath, depth+1)
+                       } else {
+                               if mtlsBasenames.Contains(ent.Name()) {
+                                       hash, err := hashFile(abspath)
+                                       if err == nil {
+                                               logger.V(1).Debugf("file %q hash: %s", abspath, bin2hex(hash))
+                                               hashes[abspath] = hash
+                                               mtlsFiles = append(mtlsFiles, abspath)
+                                       }
+                               }
+                       }
+               }
+       }
+
+       stop := func() {
+               err := syscall.Kill(command.Process.Pid, syscall.SIGTERM)
+               if err != nil {
+                       logger.Warningf("failed to kill process %d: %v", command.Process.Pid, err)
+               } else {
+                       // we sent sigterm, now wait for it to die
+                       timeout := time.After(5 * time.Second)
+                       logger.Infof("waiting for pid %d to die", command.Process.Pid)
+                       for {
+                               select {
+                               case <-exitChan:
+                                       logger.Infof("subprocess has died, ok to restart")
+                                       command.Process, command.ProcessState = nil, nil
+                                       return
+                               case <-timeout:
+                                       if command.Process != nil && command.ProcessState == nil {
+                                               logger.Warningf("deadline exceeded, killing pid %d", command.Process.Pid)
+                                               syscall.Kill(command.Process.Pid, syscall.SIGKILL)
+                                               timeout = time.After(5 * time.Second)
+                                       }
+                               }
+                       }
+               }
+       }
+
+       restart := func() {
+               var err error
+
+               restartLock.Lock()
+               defer restartLock.Unlock()
+               if command.Process != nil {
+                       logger.Noticef("restarting subprocess: %+v", command)
+                       stop()
+               } else {
+                       logger.Infof("starting subprocess: %+v", command)
+               }
+
+               err = command.Start()
+               if err != nil {
+                       logger.Panicf("failed to start subprocess: %v", err)
+               }
+               go (func() {
+                       command.Wait()
+                       exitChan <- struct{}{}
+               })()
+       }
+
+       baseDir := "/etc/ssl/mtls"
+       if os.Getenv("MTLS_SUPERVISOR_DEBUG") == "1" {
+               baseDir = startTestTree(ctx)
+               defer os.RemoveAll(baseDir)
+       }
+       scanPath(baseDir, 0)
+
+       if len(mtlsFiles) < 1 {
+               logger.Panicf("no files were discovered under %s", baseDir)
+               return
+       }
+
+       restart()
+
+       debouncer := debounce.NewWithTimeout(restart, 500*time.Millisecond)
+       notify := func(path string, op fsnotify.Op) {
+               if !mtlsBasenames.Contains(filepath.Base(path)) {
+                       return
+               }
+
+               hashesLock.Lock()
+               defer hashesLock.Unlock()
+               if hash, ok := hashes[path]; ok {
+                       newHash, err := hashFile(path)
+                       if err != nil {
+                               logger.Warnf("failed hashing file %q: %v", path, err)
+                               return
+                       }
+                       if bytes.Equal(hash, newHash) {
+                               logger.Infof("%s: contents not changed, so not triggering reload", path)
+                       } else {
+                               hashes[path] = newHash
+                               logger.Noticef("file %q changed, subprocess will be restarted: %s != %s", path, bin2hex(hash), bin2hex(newHash))
+                               debouncer.Trigger()
+                       }
+               }
+       }
+
+       for _, f := range mtlsFiles {
+               logger.Infof("monitoring path: %s", f)
+               fsnotify.NotifyPath(f, notify)
+       }
+
+       signalChan := make(chan os.Signal)
+       signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGUSR1)
+       go (func() {
+               for sig := range signalChan {
+                       logger.Noticef("Received signal %s, restarting program", sig.String())
+                       go restart()
+               }
+       })()
+
+       fsnotify.Debug()
+       for {
+               select {
+               case <-exitChan:
+                       // subprocess exited, but was it controlled?
+                       if ok := restartLock.TryLock(); ok {
+                               restartLock.Unlock()
+                               if command.ProcessState != nil && command.Process != nil {
+                                       logger.Noticef("uncontrolled exit of subprocess with pid %d and status %d", command.Process.Pid, command.ProcessState.ExitCode())
+                                       os.Exit(command.ProcessState.ExitCode())
+                               }
+                       } else {
+                               // proxy notification to restart monitor
+                               exitChan <- struct{}{}
+                       }
+               case <-ctx.Done():
+                       logger.Noticef("signal received, stopping subprocess")
+                       stop()
+                       return
+               }
+       }
+}
+
+func hashFile(filePath string) ([]byte, error) {
+       hash := crypto.SHA256.New()
+       fp, err := os.OpenFile(filePath, os.O_RDONLY, os.FileMode(0))
+       if err != nil {
+               return nil, err
+       }
+       io.Copy(hash, fp)
+       fp.Close()
+       return hash.Sum(nil), nil
+}
+
+func bin2hex(bin []byte) string {
+       out := ""
+       for _, b := range bin {
+               out += fmt.Sprintf("%02x", b)
+       }
+       return out
+}
diff --git a/mtls/supervisor/test_mode.go b/mtls/supervisor/test_mode.go
new file mode 100644 (file)
index 0000000..4f273fc
--- /dev/null
@@ -0,0 +1,171 @@
+package main
+
+import (
+       "context"
+       "fmt"
+       "io/fs"
+       "math/rand"
+       "os"
+       "path"
+       "time"
+
+       "golang.org/x/sys/unix"
+
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+const (
+       dataLinkName = "..data"
+       dateFormat   = "..2006_01_02_15_04_05.9999999999"
+       fileMode     = os.FileMode(0600)
+       dirMode      = os.FileMode(0700)
+)
+
+var logger log.Logger
+
+func init() {
+       logger = log.Default().WithPrefix("SUP DBG")
+}
+
+func randhex(n int) string {
+       reader := rand.New(rand.NewSource(time.Now().UnixNano()))
+       buf := make([]byte, n)
+       for i := 0; i < n; {
+               nr, err := reader.Read(buf[i:])
+               if err != nil {
+                       panic(err)
+               }
+               i += nr
+       }
+
+       return bin2hex(buf)
+}
+
+func writeTestFile(filePath string) {
+       contents := randhex(64)
+       fp, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY, fileMode)
+       if err != nil {
+               panic(err)
+       }
+       defer fp.Close()
+       for i := 0; i < len(contents); {
+               nw, err := fp.Write([]byte(contents[i:]))
+               if err != nil {
+                       panic(err)
+               }
+               i += nw
+       }
+}
+
+func renameExchange(src, dest string) error {
+       srcDir := path.Dir(src)
+       destDir := path.Dir(dest)
+
+       srcfd, err := unix.Open(srcDir, 0, unix.O_DIRECTORY)
+       var destfd int
+       if err != nil {
+               return fmt.Errorf("while opening srcdir %s: %v", srcDir, err)
+       }
+       defer unix.Close(srcfd)
+       if srcDir != destDir {
+               destfd, err = unix.Open(destDir, 0, unix.O_DIRECTORY)
+               if err != nil {
+                       return fmt.Errorf("while open destdir %s: %v", destDir, err)
+               }
+               defer unix.Close(destfd)
+       } else {
+               destfd = srcfd
+       }
+
+       return unix.Renameat2(srcfd, src, destfd, dest, unix.RENAME_EXCHANGE)
+}
+
+func createTestTree(basedir string, filenames []string) string {
+       entryDir := time.Now().Format(dateFormat)
+       dir := path.Join(basedir, entryDir)
+       logger.V(2).Infof("createTestTree %s -> %+v", dir, filenames)
+       if err := os.MkdirAll(dir, dirMode); err != nil {
+               panic(err)
+       }
+       dataPath := path.Join(basedir, dataLinkName)
+       logger.V(2).Infof("symlink %s -> %s", dataPath, entryDir)
+       if isSymlink(dataPath) {
+               if err := os.Symlink(entryDir, dataPath+".new"); err != nil {
+                       panic(err)
+               }
+               if err := renameExchange(dataPath+".new", dataPath); err != nil {
+                       panic(err)
+               }
+               if err := os.Remove(dataPath + ".new"); err != nil {
+                       panic(err)
+               }
+       } else {
+               if err := os.Symlink(entryDir, dataPath); err != nil {
+                       panic(err)
+               }
+       }
+       for _, f := range filenames {
+               writeTestFile(path.Join(dir, f))
+
+               baseFile := path.Join(basedir, f)
+               if !isSymlink(baseFile) {
+                       logger.V(2).Infof("symlink %s -> %s", baseFile, path.Join(dataLinkName, f))
+                       if err := os.Symlink(path.Join(dataLinkName, f), baseFile); err != nil {
+                               panic(err)
+                       }
+               }
+       }
+
+       return entryDir
+}
+
+func maintainTestTree(ctx context.Context, basedir string, curDir string, filenames []string) {
+       ticker := time.NewTicker(5 * time.Second)
+       defer ticker.Stop()
+
+       defer func() {
+               // put this in an anon func to use current value of curDir
+               os.RemoveAll(path.Join(basedir, curDir))
+       }()
+       for {
+               select {
+               case <-ticker.C:
+                       newDir := createTestTree(basedir, filenames)
+                       os.RemoveAll(path.Join(basedir, curDir))
+                       curDir = newDir
+               case <-ctx.Done():
+                       return
+               }
+       }
+}
+
+func startTestTree(ctx context.Context) string {
+       rootPathFiles := []string{
+               "ca.pem",
+               // "rootca.pem",
+               // "step-ca.json",
+       }
+       // certDirFiles := []string{
+       //      "fullchain.pem",
+       //      "privkey.pem",
+       //      "ca.pem",
+       // }
+
+       baseDir := fmt.Sprintf("/tmp/supervisor.%d", os.Getpid())
+
+       rootDir := createTestTree(baseDir, rootPathFiles)
+       // certDir := createTestTree(path.Join(baseDir, "testcert"), certDirFiles)
+
+       go maintainTestTree(ctx, baseDir, rootDir, rootPathFiles)
+       // go maintainTestTree(ctx, path.Join(baseDir, "testcert"), certDir, certDirFiles)
+
+       return baseDir
+}
+
+func isSymlink(fullPath string) bool {
+       if stat, err := os.Lstat(fullPath); err == nil {
+               return stat.Mode()&fs.ModeSymlink == fs.ModeSymlink
+       }
+
+       return false
+}