--- /dev/null
+package main
+
+import (
+ "bytes"
+ "context"
+ "crypto"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "go.fuhry.dev/runtime/mtls/fsnotify"
+ "go.fuhry.dev/runtime/utils/debounce"
+ "go.fuhry.dev/runtime/utils/hashset"
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+func main() {
+ var restartLock sync.Mutex
+ var mtlsFiles []string
+ exitChan := make(chan struct{})
+
+ if os.Getenv("MTLS_SUPERVISOR_DEBUG") == "1" {
+ log.SetVerbosity(2)
+ log.SetLevel(log.DEBUG)
+ }
+
+ ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ logger := log.Default().WithPrefix("supervisor")
+ defer cancel()
+
+ if len(os.Args) < 2 {
+ logger.Panicf("usage: %s command-to-run", os.Args[0])
+ }
+ cmdline := os.Args[1:]
+ abspath, err := exec.LookPath(cmdline[0])
+ if err != nil {
+ logger.Panic(err)
+ return
+ }
+ command := exec.Command(abspath, cmdline[1:]...)
+ command.Stdin = os.Stdin
+ command.Stdout = os.Stdout
+ command.Stderr = os.Stderr
+
+ mtlsBasenames := hashset.FromSlice[string]([]string{
+ "cert.pem",
+ "chain.pem",
+ "fullchain.pem",
+ "privkey.pem",
+ "aio.pem",
+ "aio.p12",
+ "ca.pem",
+ "rootca.pem",
+ })
+
+ hashes := make(map[string][]byte)
+ var hashesLock sync.Mutex
+
+ var scanPath func(dir string, depth int)
+ scanPath = func(dir string, depth int) {
+ if depth == 0 {
+ hashesLock.Lock()
+ defer hashesLock.Unlock()
+ }
+
+ entries, err := os.ReadDir(dir)
+ if err != nil {
+ logger.Warningf("failed to scan directory %q: %v", dir, err)
+ return
+ }
+ for _, ent := range entries {
+ abspath := path.Join(dir, ent.Name())
+ if ent.IsDir() && depth == 0 && !strings.HasPrefix(ent.Name(), "..") {
+ logger.Infof("descending to %s", abspath)
+ scanPath(abspath, depth+1)
+ } else {
+ if mtlsBasenames.Contains(ent.Name()) {
+ hash, err := hashFile(abspath)
+ if err == nil {
+ logger.V(1).Debugf("file %q hash: %s", abspath, bin2hex(hash))
+ hashes[abspath] = hash
+ mtlsFiles = append(mtlsFiles, abspath)
+ }
+ }
+ }
+ }
+ }
+
+ stop := func() {
+ err := syscall.Kill(command.Process.Pid, syscall.SIGTERM)
+ if err != nil {
+ logger.Warningf("failed to kill process %d: %v", command.Process.Pid, err)
+ } else {
+ // we sent sigterm, now wait for it to die
+ timeout := time.After(5 * time.Second)
+ logger.Infof("waiting for pid %d to die", command.Process.Pid)
+ for {
+ select {
+ case <-exitChan:
+ logger.Infof("subprocess has died, ok to restart")
+ command.Process, command.ProcessState = nil, nil
+ return
+ case <-timeout:
+ if command.Process != nil && command.ProcessState == nil {
+ logger.Warningf("deadline exceeded, killing pid %d", command.Process.Pid)
+ syscall.Kill(command.Process.Pid, syscall.SIGKILL)
+ timeout = time.After(5 * time.Second)
+ }
+ }
+ }
+ }
+ }
+
+ restart := func() {
+ var err error
+
+ restartLock.Lock()
+ defer restartLock.Unlock()
+ if command.Process != nil {
+ logger.Noticef("restarting subprocess: %+v", command)
+ stop()
+ } else {
+ logger.Infof("starting subprocess: %+v", command)
+ }
+
+ err = command.Start()
+ if err != nil {
+ logger.Panicf("failed to start subprocess: %v", err)
+ }
+ go (func() {
+ command.Wait()
+ exitChan <- struct{}{}
+ })()
+ }
+
+ baseDir := "/etc/ssl/mtls"
+ if os.Getenv("MTLS_SUPERVISOR_DEBUG") == "1" {
+ baseDir = startTestTree(ctx)
+ defer os.RemoveAll(baseDir)
+ }
+ scanPath(baseDir, 0)
+
+ if len(mtlsFiles) < 1 {
+ logger.Panicf("no files were discovered under %s", baseDir)
+ return
+ }
+
+ restart()
+
+ debouncer := debounce.NewWithTimeout(restart, 500*time.Millisecond)
+ notify := func(path string, op fsnotify.Op) {
+ if !mtlsBasenames.Contains(filepath.Base(path)) {
+ return
+ }
+
+ hashesLock.Lock()
+ defer hashesLock.Unlock()
+ if hash, ok := hashes[path]; ok {
+ newHash, err := hashFile(path)
+ if err != nil {
+ logger.Warnf("failed hashing file %q: %v", path, err)
+ return
+ }
+ if bytes.Equal(hash, newHash) {
+ logger.Infof("%s: contents not changed, so not triggering reload", path)
+ } else {
+ hashes[path] = newHash
+ logger.Noticef("file %q changed, subprocess will be restarted: %s != %s", path, bin2hex(hash), bin2hex(newHash))
+ debouncer.Trigger()
+ }
+ }
+ }
+
+ for _, f := range mtlsFiles {
+ logger.Infof("monitoring path: %s", f)
+ fsnotify.NotifyPath(f, notify)
+ }
+
+ signalChan := make(chan os.Signal)
+ signal.Notify(signalChan, syscall.SIGHUP, syscall.SIGUSR1)
+ go (func() {
+ for sig := range signalChan {
+ logger.Noticef("Received signal %s, restarting program", sig.String())
+ go restart()
+ }
+ })()
+
+ fsnotify.Debug()
+ for {
+ select {
+ case <-exitChan:
+ // subprocess exited, but was it controlled?
+ if ok := restartLock.TryLock(); ok {
+ restartLock.Unlock()
+ if command.ProcessState != nil && command.Process != nil {
+ logger.Noticef("uncontrolled exit of subprocess with pid %d and status %d", command.Process.Pid, command.ProcessState.ExitCode())
+ os.Exit(command.ProcessState.ExitCode())
+ }
+ } else {
+ // proxy notification to restart monitor
+ exitChan <- struct{}{}
+ }
+ case <-ctx.Done():
+ logger.Noticef("signal received, stopping subprocess")
+ stop()
+ return
+ }
+ }
+}
+
+func hashFile(filePath string) ([]byte, error) {
+ hash := crypto.SHA256.New()
+ fp, err := os.OpenFile(filePath, os.O_RDONLY, os.FileMode(0))
+ if err != nil {
+ return nil, err
+ }
+ io.Copy(hash, fp)
+ fp.Close()
+ return hash.Sum(nil), nil
+}
+
+func bin2hex(bin []byte) string {
+ out := ""
+ for _, b := range bin {
+ out += fmt.Sprintf("%02x", b)
+ }
+ return out
+}
--- /dev/null
+package main
+
+import (
+ "context"
+ "fmt"
+ "io/fs"
+ "math/rand"
+ "os"
+ "path"
+ "time"
+
+ "golang.org/x/sys/unix"
+
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+const (
+ dataLinkName = "..data"
+ dateFormat = "..2006_01_02_15_04_05.9999999999"
+ fileMode = os.FileMode(0600)
+ dirMode = os.FileMode(0700)
+)
+
+var logger log.Logger
+
+func init() {
+ logger = log.Default().WithPrefix("SUP DBG")
+}
+
+func randhex(n int) string {
+ reader := rand.New(rand.NewSource(time.Now().UnixNano()))
+ buf := make([]byte, n)
+ for i := 0; i < n; {
+ nr, err := reader.Read(buf[i:])
+ if err != nil {
+ panic(err)
+ }
+ i += nr
+ }
+
+ return bin2hex(buf)
+}
+
+func writeTestFile(filePath string) {
+ contents := randhex(64)
+ fp, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY, fileMode)
+ if err != nil {
+ panic(err)
+ }
+ defer fp.Close()
+ for i := 0; i < len(contents); {
+ nw, err := fp.Write([]byte(contents[i:]))
+ if err != nil {
+ panic(err)
+ }
+ i += nw
+ }
+}
+
+func renameExchange(src, dest string) error {
+ srcDir := path.Dir(src)
+ destDir := path.Dir(dest)
+
+ srcfd, err := unix.Open(srcDir, 0, unix.O_DIRECTORY)
+ var destfd int
+ if err != nil {
+ return fmt.Errorf("while opening srcdir %s: %v", srcDir, err)
+ }
+ defer unix.Close(srcfd)
+ if srcDir != destDir {
+ destfd, err = unix.Open(destDir, 0, unix.O_DIRECTORY)
+ if err != nil {
+ return fmt.Errorf("while open destdir %s: %v", destDir, err)
+ }
+ defer unix.Close(destfd)
+ } else {
+ destfd = srcfd
+ }
+
+ return unix.Renameat2(srcfd, src, destfd, dest, unix.RENAME_EXCHANGE)
+}
+
+func createTestTree(basedir string, filenames []string) string {
+ entryDir := time.Now().Format(dateFormat)
+ dir := path.Join(basedir, entryDir)
+ logger.V(2).Infof("createTestTree %s -> %+v", dir, filenames)
+ if err := os.MkdirAll(dir, dirMode); err != nil {
+ panic(err)
+ }
+ dataPath := path.Join(basedir, dataLinkName)
+ logger.V(2).Infof("symlink %s -> %s", dataPath, entryDir)
+ if isSymlink(dataPath) {
+ if err := os.Symlink(entryDir, dataPath+".new"); err != nil {
+ panic(err)
+ }
+ if err := renameExchange(dataPath+".new", dataPath); err != nil {
+ panic(err)
+ }
+ if err := os.Remove(dataPath + ".new"); err != nil {
+ panic(err)
+ }
+ } else {
+ if err := os.Symlink(entryDir, dataPath); err != nil {
+ panic(err)
+ }
+ }
+ for _, f := range filenames {
+ writeTestFile(path.Join(dir, f))
+
+ baseFile := path.Join(basedir, f)
+ if !isSymlink(baseFile) {
+ logger.V(2).Infof("symlink %s -> %s", baseFile, path.Join(dataLinkName, f))
+ if err := os.Symlink(path.Join(dataLinkName, f), baseFile); err != nil {
+ panic(err)
+ }
+ }
+ }
+
+ return entryDir
+}
+
+func maintainTestTree(ctx context.Context, basedir string, curDir string, filenames []string) {
+ ticker := time.NewTicker(5 * time.Second)
+ defer ticker.Stop()
+
+ defer func() {
+ // put this in an anon func to use current value of curDir
+ os.RemoveAll(path.Join(basedir, curDir))
+ }()
+ for {
+ select {
+ case <-ticker.C:
+ newDir := createTestTree(basedir, filenames)
+ os.RemoveAll(path.Join(basedir, curDir))
+ curDir = newDir
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func startTestTree(ctx context.Context) string {
+ rootPathFiles := []string{
+ "ca.pem",
+ // "rootca.pem",
+ // "step-ca.json",
+ }
+ // certDirFiles := []string{
+ // "fullchain.pem",
+ // "privkey.pem",
+ // "ca.pem",
+ // }
+
+ baseDir := fmt.Sprintf("/tmp/supervisor.%d", os.Getpid())
+
+ rootDir := createTestTree(baseDir, rootPathFiles)
+ // certDir := createTestTree(path.Join(baseDir, "testcert"), certDirFiles)
+
+ go maintainTestTree(ctx, baseDir, rootDir, rootPathFiles)
+ // go maintainTestTree(ctx, path.Join(baseDir, "testcert"), certDir, certDirFiles)
+
+ return baseDir
+}
+
+func isSymlink(fullPath string) bool {
+ if stat, err := os.Lstat(fullPath); err == nil {
+ return stat.Mode()&fs.ModeSymlink == fs.ModeSymlink
+ }
+
+ return false
+}