]> go.fuhry.dev Git - runtime.git/commitdiff
more ephs
authorDan Fuhry <dan@fuhry.com>
Thu, 6 Nov 2025 12:03:20 +0000 (07:03 -0500)
committerDan Fuhry <dan@fuhry.com>
Sun, 9 Nov 2025 12:24:10 +0000 (07:24 -0500)
17 files changed:
ephs/client/main.go [new file with mode: 0644]
ephs/server/Makefile [new file with mode: 0644]
ephs/server/ephs_acl.yaml [new file with mode: 0644]
ephs/server/main.go [new file with mode: 0644]
ephs/server/rules.yaml [new file with mode: 0644]
grpc/client.go
grpc/conn_quic.go
grpc/health_probe/main.go [new file with mode: 0644]
grpc/healthcheck.go [new file with mode: 0644]
grpc/server.go
mtls/fsnotify/fsnotify.go
mtls/fsnotify/util.go
mtls/identity.go
mtls/verify_names.go
mtls/verify_roots.go
sd/etcd_factory.go
sd/watcher.go

diff --git a/ephs/client/main.go b/ephs/client/main.go
new file mode 100644 (file)
index 0000000..f15ca24
--- /dev/null
@@ -0,0 +1,275 @@
+package main
+
+import (
+       "context"
+       "errors"
+       "flag"
+       "fmt"
+       "io"
+       "os"
+       "os/signal"
+       "path"
+       "strings"
+       "syscall"
+
+       "github.com/urfave/cli/v3"
+
+       "go.fuhry.dev/runtime/constants"
+       "go.fuhry.dev/runtime/ephs"
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+func cmdStat(ctx context.Context, cmd *cli.Command) error {
+       client, err := ephs.DefaultClient()
+       if err != nil {
+               return err
+       }
+
+       resp, err := client.StatContext(ctx, strings.TrimSuffix(cmd.StringArg("path"), "/"))
+       if err != nil {
+               return err
+       }
+
+       fmt.Println(ephs.FormatFsEntry(resp))
+       return nil
+}
+
+func cmdCat(ctx context.Context, cmd *cli.Command) error {
+       client, err := ephs.DefaultClient()
+       if err != nil {
+               return err
+       }
+
+       reader, err := client.GetContext(ctx, cmd.StringArg("path"))
+       if err != nil {
+               return err
+       }
+
+       io.Copy(os.Stdout, reader)
+
+       return nil
+}
+
+func cmdCopy(ctx context.Context, cmd *cli.Command) error {
+       client, err := ephs.DefaultClient()
+       if err != nil {
+               return err
+       }
+
+       src := cmd.StringArg("src")
+       dst := cmd.StringArg("dst")
+
+       if ephs.IsEphsPath(src) && ephs.IsEphsPath(dst) {
+               return errors.New("copying from one ephs path to another is not presently supported")
+       } else if !ephs.IsEphsPath(src) && !ephs.IsEphsPath(dst) {
+               return errors.New("copying from one non-ephs path to another is not supported")
+       }
+
+       if strings.HasSuffix(dst, "/") {
+               dst += path.Base(src)
+       }
+
+       if ephs.IsEphsPath(src) {
+               // copy from ephs to local
+               fp, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE, os.FileMode(0600))
+               if err != nil {
+                       return err
+               }
+               defer fp.Close()
+
+               reader, err := client.GetContext(ctx, src)
+               if err != nil {
+                       return err
+               }
+
+               nw, _ := io.Copy(fp, reader)
+
+               log.Default().Infof("wrote %d bytes to %s", nw, dst)
+               return nil
+       } else if ephs.IsEphsPath(dst) {
+               // copy from local to ephs (put request)
+               st, err := os.Stat(src)
+               if err != nil {
+                       return err
+               }
+               fp, err := os.OpenFile(src, os.O_RDONLY, os.FileMode(0))
+               if err != nil {
+                       return err
+               }
+               defer fp.Close()
+
+               resp, err := client.PutContext(ctx, dst, uint64(st.Size()), fp)
+               if err != nil {
+                       return err
+               }
+
+               log.Default().Noticef("successfully uploaded: %s\n%s", dst, ephs.FormatFsEntry(resp))
+       }
+
+       return nil
+}
+
+func cmdDelete(ctx context.Context, cmd *cli.Command) error {
+       client, err := ephs.DefaultClient()
+       if err != nil {
+               return err
+       }
+
+       err = client.DeleteContext(ctx, cmd.StringArg("path"), cmd.IsSet("recursive"))
+       if err != nil {
+               return err
+       }
+
+       log.Default().Infof("Deleted object: %s", cmd.StringArg("path"))
+       return nil
+}
+
+func cmdWatch(ctx context.Context, cmd *cli.Command) error {
+       client, err := ephs.DefaultClient()
+       if err != nil {
+               return err
+       }
+
+       watchCl, err := client.Watch(ctx, cmd.StringArg("path"))
+       if err != nil {
+               return err
+       }
+
+       for msg := range watchCl {
+               log.Default().Infof("Got event:\nEvent type:      %s\n%s", msg.Event.String(), ephs.FormatFsEntry(msg.GetEntry()))
+       }
+       return nil
+}
+
+func cmdMkdir(ctx context.Context, cmd *cli.Command) error {
+       client, err := ephs.DefaultClient()
+       if err != nil {
+               return err
+       }
+
+       err = client.MkDirContext(ctx, cmd.StringArg("path"), cmd.Bool("recursive"))
+       if err != nil {
+               return err
+       }
+
+       log.Default().Infof("Created directory: %s", cmd.StringArg("path"))
+       return nil
+}
+
+func main() {
+       ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+       defer cancel()
+
+       flag.Parse()
+
+       cmd := &cli.Command{
+               Name:        "ephs",
+               Version:     constants.Version,
+               Description: "interact with ephs",
+
+               Commands: []*cli.Command{
+                       {
+                               Name:        "stat",
+                               Aliases:     []string{"ls", "dir", "list"},
+                               Description: "get info about a file or list contents of a directory",
+                               Action:      cmdStat,
+                               Arguments: []cli.Argument{
+                                       &cli.StringArg{
+                                               Name:      "path",
+                                               UsageText: "path to read",
+                                       },
+                               },
+                       },
+                       {
+                               Name:        "cat",
+                               Description: "read a file",
+                               Action:      cmdCat,
+                               Arguments: []cli.Argument{
+                                       &cli.StringArg{
+                                               Name:      "path",
+                                               UsageText: "path to read",
+                                       },
+                               },
+                       },
+                       {
+                               Name:        "cp",
+                               Description: "copy a file",
+                               Action:      cmdCopy,
+                               Arguments: []cli.Argument{
+                                       &cli.StringArg{
+                                               Name:      "src",
+                                               UsageText: "source",
+                                       },
+                                       &cli.StringArg{
+                                               Name:      "dst",
+                                               UsageText: "destination",
+                                       },
+                               },
+                       },
+                       {
+                               Name:        "rm",
+                               Description: "delete a file",
+                               Action:      cmdDelete,
+                               Arguments: []cli.Argument{
+                                       &cli.StringArg{
+                                               Name:      "path",
+                                               UsageText: "path to read",
+                                       },
+                               },
+                               Flags: []cli.Flag{
+                                       &cli.BoolFlag{
+                                               Name:  "recursive",
+                                               Usage: "allow removing directories and all of their contents",
+                                       },
+                               },
+                       },
+                       {
+                               Name:        "watch",
+                               Description: "watch a path for changes",
+                               Action:      cmdWatch,
+                               Arguments: []cli.Argument{
+                                       &cli.StringArg{
+                                               Name:      "path",
+                                               UsageText: "path to read",
+                                       },
+                               },
+                       },
+                       {
+                               Name:        "mkdir",
+                               Description: "create a directory",
+                               Action:      cmdMkdir,
+                               Arguments: []cli.Argument{
+                                       &cli.StringArg{
+                                               Name:      "path",
+                                               UsageText: "directory to create",
+                                       },
+                               },
+                               Flags: []cli.Flag{
+                                       &cli.BoolFlag{
+                                               Name:  "recursive",
+                                               Usage: "also create any parent directories that don't already exist",
+                                       },
+                               },
+                       },
+               },
+
+               Flags: []cli.Flag{
+                       &cli.IntFlag{
+                               Name:  "vv",
+                               Usage: "verbosity level",
+                       },
+                       &cli.StringFlag{
+                               Name:  "v",
+                               Usage: "log level",
+                       },
+                       &cli.StringFlag{
+                               Name:  "grpc.transport",
+                               Usage: "grpc transport (tcp or quic)",
+                       },
+               },
+       }
+
+       if err := cmd.Run(ctx, os.Args); err != nil {
+               log.Default().Panic(err)
+       }
+}
diff --git a/ephs/server/Makefile b/ephs/server/Makefile
new file mode 100644 (file)
index 0000000..bfab546
--- /dev/null
@@ -0,0 +1,14 @@
+GOSRC = $(wildcard *.go)
+GOEXE = $(shell basename `pwd`)
+GOBUILDFLAGS := -buildmode=pie -trimpath
+
+all: $(GOEXE)
+
+clean:
+       rm -fv $(GOEXE)
+
+.PHONY: all clean
+
+$(GOEXE): %: $(GOSRC)
+       go build $(GOBUILDFLAGS) -o $@ $<
+
diff --git a/ephs/server/ephs_acl.yaml b/ephs/server/ephs_acl.yaml
new file mode 100644 (file)
index 0000000..c4b1b46
--- /dev/null
@@ -0,0 +1,3 @@
+DEFAULT:
+  - service: '*'
+  - user: '*'
diff --git a/ephs/server/main.go b/ephs/server/main.go
new file mode 100644 (file)
index 0000000..d0e24cf
--- /dev/null
@@ -0,0 +1,64 @@
+package main
+
+import (
+       "context"
+       "flag"
+       "os/signal"
+       "syscall"
+
+       "go.fuhry.dev/runtime/ephs/servicer"
+       "go.fuhry.dev/runtime/grpc"
+       "go.fuhry.dev/runtime/mtls"
+       ephs_pb "go.fuhry.dev/runtime/proto/service/ephs"
+       "go.fuhry.dev/runtime/utils/log"
+
+       google_grpc "google.golang.org/grpc"
+)
+
+func main() {
+       var err error
+
+       acl := flag.String("ephs.acl", "", "YAML file containing ACLs for ephs")
+       awsFile := flag.String("ephs.s3-creds-file", "", "file to load AWS S3 credentials from")
+       awsEnv := flag.Bool("ephs.s3-creds-env", false, "set to true to load AWS credentials from environment")
+
+       flag.Parse()
+
+       serverIdentity := mtls.DefaultIdentity()
+       s, err := grpc.NewGrpcServer(serverIdentity)
+       if err != nil {
+               panic(err)
+       }
+
+       var opts []servicer.Option
+       if *acl != "" {
+               opts = append(opts, servicer.WithAclFile(*acl))
+       }
+
+       if awsEnv != nil && *awsEnv {
+               opts = append(opts, servicer.WithAWSEnvCredentials())
+       } else if awsFile != nil && *awsFile != "" {
+               opts = append(opts, servicer.WithAWSCredentialFile(*awsFile))
+       }
+
+       serv, err := servicer.NewEphsServicer(opts...)
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+
+       err = s.PublishAndServe(ctx, func(s *google_grpc.Server) {
+               ephs_pb.RegisterEphsServer(s, serv)
+       })
+       if err != nil {
+               panic(err)
+       }
+       defer s.Stop()
+
+       <-ctx.Done()
+}
+
+func init() {
+       mtls.SetDefaultIdentity("ephs")
+}
diff --git a/ephs/server/rules.yaml b/ephs/server/rules.yaml
new file mode 100644 (file)
index 0000000..9ba49d6
--- /dev/null
@@ -0,0 +1,11 @@
+rules:
+- principal:
+    prefix: spiffe://roc.xx0r.info/service/
+  key:
+    prefix: services/{{principal}}/
+- principal:
+    or:
+      - exact: spiffe://roc.xx0r.info/user/dan
+      - exact: spiffe://roc.xx0r.info/service/conffs
+  key:
+    any: true
index e904b30c532b9dbab93ac76f34c649bb6c5e4e80..93072f5447ec7b0dae392c51b1db0798ceac6d04 100644 (file)
@@ -3,44 +3,127 @@ package grpc
 import (
        "context"
        "fmt"
+       "net"
 
        "go.fuhry.dev/runtime/mtls"
        "go.fuhry.dev/runtime/sd"
        "google.golang.org/grpc"
 )
 
-type Client struct {
+type ClientConn = grpc.ClientConn
+
+type Client interface {
+       Conn() (*grpc.ClientConn, error)
+}
+
+type ClientOption interface {
+       apply(*client) error
+}
+
+type clientOption struct {
+       f func(*client) error
+}
+
+type AddressProvider interface {
+       GetAddrs(context.Context) ([]sd.ServiceAddress, error)
+}
+
+func (o *clientOption) apply(c *client) error {
+       return o.f(c)
+}
+
+type client struct {
        ctx      context.Context
        serverId mtls.Identity
        clientId mtls.Identity
-       watcher  *sd.SDWatcher
+       watcher  AddressProvider
        connFac  ConnectionFactory
 }
 
-func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity) (*Client, error) {
-       etcdc, err := sd.NewDefaultEtcdClient()
-       if err != nil {
-               panic(err)
+type staticAddressProvider struct {
+       addresses []sd.ServiceAddress
+}
+
+func (s *staticAddressProvider) GetAddrs(_ context.Context) ([]sd.ServiceAddress, error) {
+       return s.addresses, nil
+}
+
+func WithConnectionFactory(fac ConnectionFactory) ClientOption {
+       return &clientOption{
+               f: func(c *client) error {
+                       c.connFac = fac
+                       return nil
+               },
+       }
+}
+
+func WithAddressProvider(ap AddressProvider) ClientOption {
+       return &clientOption{
+               f: func(c *client) error {
+                       c.watcher = ap
+                       return nil
+               },
+       }
+}
+
+func WithStaticAddress(addresses ...*net.TCPAddr) ClientOption {
+       var addrs []sd.ServiceAddress
+
+       for _, addr := range addresses {
+               var ip4, ip6 string
+               if len(addr.IP) == 4 {
+                       ip4 = addr.IP.String()
+               } else {
+                       ip6 = addr.IP.String()
+               }
+               addrs = append(addrs, sd.ServiceAddress{
+                       IP4:  ip4,
+                       IP6:  ip6,
+                       Port: uint16(addr.Port),
+               })
        }
 
-       w := &sd.SDWatcher{
-               Service:    serverId.Name(),
-               EtcdClient: etcdc,
-               Protocol:   sd.ProtocolGRPC,
+       return &clientOption{
+               f: func(c *client) error {
+                       c.watcher = &staticAddressProvider{
+                               addresses: addrs,
+                       }
+                       return nil
+               },
        }
+}
 
-       cl := &Client{
+func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity, opts ...ClientOption) (Client, error) {
+       cl := &client{
                ctx:      ctx,
                serverId: serverId,
                clientId: clientId,
-               watcher:  w,
                connFac:  NewDefaultConnectionFactory(),
        }
 
+       for _, opt := range opts {
+               if err := opt.apply(cl); err != nil {
+                       return nil, err
+               }
+       }
+
+       if cl.watcher == nil {
+               etcdc, err := sd.NewDefaultEtcdClient()
+               if err != nil {
+                       return nil, err
+               }
+
+               cl.watcher = &sd.SDWatcher{
+                       Service:    serverId.Name(),
+                       EtcdClient: etcdc,
+                       Protocol:   sd.ProtocolGRPC,
+               }
+       }
+
        return cl, nil
 }
 
-func (c *Client) Conn() (*grpc.ClientConn, error) {
+func (c *client) Conn() (*grpc.ClientConn, error) {
        addrs, err := c.watcher.GetAddrs(c.ctx)
        if err != nil {
                return nil, err
index 99d16dffd8d6f430c594d42574842b135bdc890b..c949686808d6cabffbc45b7351a5d2612c4c3ac3 100644 (file)
@@ -1,9 +1,11 @@
 package grpc
 
 import (
+       "context"
        "crypto/tls"
        "fmt"
        "net"
+       "time"
 
        "google.golang.org/grpc/credentials"
 
@@ -11,7 +13,14 @@ import (
        grpc_quic "go.fuhry.dev/grpc-quic"
 )
 
-type QUICConnectionFactory struct{}
+var defaultQuicConfig = &quic.Config{
+       HandshakeIdleTimeout: 15 * time.Second,
+       MaxIdleTimeout:       5 * time.Minute,
+}
+
+type QUICConnectionFactory struct {
+       QUICConfig *quic.Config
+}
 
 func (cf *QUICConnectionFactory) NewCredentials(tlsConfig *tls.Config) credentials.TransportCredentials {
        tlsConfig.NextProtos = []string{"grpc-quic"}
@@ -24,7 +33,11 @@ func (cf *QUICConnectionFactory) NewListener(port uint16, tlsConfig *tls.Config)
                return nil, err
        }
 
-       quicListener, err := quic.Listen(udpListener, tlsConfig, nil)
+       if cf.QUICConfig == nil {
+               cf.QUICConfig = defaultQuicConfig.Clone()
+       }
+
+       quicListener, err := quic.Listen(udpListener, tlsConfig, cf.QUICConfig.Clone())
        if err != nil {
                return nil, err
        }
@@ -35,7 +48,18 @@ func (cf *QUICConnectionFactory) NewListener(port uint16, tlsConfig *tls.Config)
 }
 
 func (cf *QUICConnectionFactory) NewDialer(tlsConfig *tls.Config) ContextDialer {
-       return grpc_quic.NewQuicDialer(tlsConfig)
+       if cf.QUICConfig == nil {
+               cf.QUICConfig = defaultQuicConfig.Clone()
+       }
+
+       return func(ctx context.Context, target string) (net.Conn, error) {
+               sess, err := quic.DialAddr(ctx, target, tlsConfig, cf.QUICConfig.Clone())
+               if err != nil {
+                       return nil, err
+               }
+
+               return grpc_quic.NewConn(sess)
+       }
 }
 
 func init() {
diff --git a/grpc/health_probe/main.go b/grpc/health_probe/main.go
new file mode 100644 (file)
index 0000000..a32adc9
--- /dev/null
@@ -0,0 +1,152 @@
+package main
+
+import (
+       "context"
+       "flag"
+       "net"
+       "os"
+       "os/signal"
+       "strconv"
+       "syscall"
+       "time"
+
+       "go.fuhry.dev/runtime/grpc"
+       "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/utils/log"
+
+       "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+func main() {
+       serverId := flag.String("server-id", "", "mtls ID of gRPC server")
+       serverAddr := flag.String("server-addr", "", "server address as ip:port only")
+       watch := flag.Bool("watch", false, "continually monitor server for health status changes")
+       wait := flag.Bool("wait", false, "continually retry connecting, then stream server status until RUNNING")
+       var waitTimeout = 60 * time.Second
+       flag.Func("wait.timeout", "maximum time to wait for a healthy server (default: 60s)", func(val string) (err error) {
+               waitTimeout, err = time.ParseDuration(val)
+               return
+       })
+
+       mtls.SetDefaultIdentity("anonymous")
+
+       flag.Parse()
+
+       var opts []grpc.ClientOption
+
+       if *serverAddr != "" {
+               host, port, err := net.SplitHostPort(*serverAddr)
+               if err != nil {
+                       log.Default().Panic(err)
+               }
+
+               portInt, err := strconv.Atoi(port)
+               if err != nil {
+                       portInt, err = net.LookupPort("tcp", port)
+                       if err != nil {
+                               log.Default().Panic(err)
+                       }
+               }
+
+               if ip := net.ParseIP(host); ip == nil {
+                       log.Default().Panicf("%q: not a valid IPv4 or IPv6 address", host)
+               } else {
+                       opts = append(opts, grpc.WithStaticAddress(&net.TCPAddr{ip, portInt, ""}))
+               }
+       }
+
+       ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+       defer cancel()
+       var deadline time.Time
+
+       if *wait {
+               ctx, cancel = context.WithTimeout(ctx, waitTimeout)
+               defer cancel()
+               deadline, _ = ctx.Deadline()
+       }
+
+       if *watch && *wait {
+               log.Default().Fatal("watch and wait options are mutually exclusive")
+               os.Exit(1)
+       }
+
+       var conn *grpc.ClientConn
+       for {
+               client, err := grpc.NewGrpcClient(ctx, mtls.NewServiceIdentity(*serverId), mtls.DefaultIdentity(), opts...)
+               if err != nil {
+                       if *wait && time.Now().Before(deadline) {
+                               log.Default().Warningf("error connecting (%v) retrying in 1s", err)
+                               time.Sleep(1 * time.Second)
+                               continue
+                       }
+                       log.Default().Critical(err)
+                       os.Exit(1)
+               }
+
+               conn, err = client.Conn()
+               if err != nil {
+                       if *wait && time.Now().Before(deadline) {
+                               log.Default().Warningf("error connecting (%v) retrying in 1s", err)
+                               time.Sleep(1 * time.Second)
+                               continue
+                       }
+                       log.Default().Critical(err)
+                       os.Exit(1)
+               }
+
+               defer conn.Close()
+               break
+       }
+
+       for {
+               healthClient := grpc_health_v1.NewHealthClient(conn)
+
+               if *watch {
+                       stream, err := healthClient.Watch(ctx, &grpc_health_v1.HealthCheckRequest{})
+                       if err != nil {
+                               log.Default().Alertf("healthcheck failed with error %T: %v", err, err)
+                               os.Exit(2)
+                       }
+
+                       for msg, err := stream.Recv(); err == nil; msg, err = stream.Recv() {
+                               log.Default().Infof("server status: %s", msg.Status.String())
+                               time.Sleep(100 * time.Millisecond)
+                       }
+
+                       os.Exit(0)
+               } else if *wait {
+                       stream, err := healthClient.Watch(ctx, &grpc_health_v1.HealthCheckRequest{})
+                       if err != nil {
+                               if time.Now().Before(deadline) {
+                                       log.Default().Warningf("error connecting (%v) retrying in 1s", err)
+                                       time.Sleep(1 * time.Second)
+                                       continue
+                               }
+                               log.Default().Critical(err)
+                               os.Exit(1)
+                       }
+
+                       for msg, err := stream.Recv(); err == nil; msg, err = stream.Recv() {
+                               log.Default().Infof("server status: %s", msg.Status.String())
+                               if msg.Status == grpc_health_v1.HealthCheckResponse_SERVING {
+                                       return
+                               }
+                       }
+
+                       os.Exit(1)
+               } else {
+                       resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
+                       if err != nil {
+                               log.Default().Alertf("healthcheck failed with error %T: %v", err, err)
+                               os.Exit(2)
+                       }
+
+                       log.Default().Infof("server status: %s", resp.Status.String())
+                       if resp.Status == grpc_health_v1.HealthCheckResponse_SERVING {
+                               os.Exit(0)
+                       }
+
+                       os.Exit(1)
+               }
+       }
+}
diff --git a/grpc/healthcheck.go b/grpc/healthcheck.go
new file mode 100644 (file)
index 0000000..ce8b14b
--- /dev/null
@@ -0,0 +1,85 @@
+package grpc
+
+import (
+       "context"
+       "sync"
+
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/health/grpc_health_v1"
+
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+type HealthCheckServicer interface {
+       grpc_health_v1.HealthServer
+
+       SetStatus(grpc_health_v1.HealthCheckResponse_ServingStatus)
+}
+
+type healthCheckServicer struct {
+       s      grpc_health_v1.HealthCheckResponse_ServingStatus
+       logger log.Logger
+
+       watchers   map[chan grpc_health_v1.HealthCheckResponse_ServingStatus]struct{}
+       watchersMu sync.Mutex
+}
+
+var _ grpc_health_v1.HealthServer = &healthCheckServicer{}
+
+func (h *healthCheckServicer) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
+       return &grpc_health_v1.HealthCheckResponse{
+               Status: h.s,
+       }, nil
+}
+
+func (h *healthCheckServicer) SetStatus(status grpc_health_v1.HealthCheckResponse_ServingStatus) {
+       h.s = status
+
+       h.watchersMu.Lock()
+       defer h.watchersMu.Unlock()
+
+       for wa := range h.watchers {
+               wa <- status
+       }
+}
+
+func (h *healthCheckServicer) Watch(req *grpc_health_v1.HealthCheckRequest, stream grpc.ServerStreamingServer[grpc_health_v1.HealthCheckResponse]) error {
+       h.watchersMu.Lock()
+       ch := make(chan grpc_health_v1.HealthCheckResponse_ServingStatus)
+       h.watchers[ch] = struct{}{}
+       h.watchersMu.Unlock()
+
+       defer (func() {
+               h.watchersMu.Lock()
+               delete(h.watchers, ch)
+               close(ch)
+               h.watchersMu.Unlock()
+       })()
+
+       msg := &grpc_health_v1.HealthCheckResponse{
+               Status: h.s,
+       }
+       if err := stream.SendMsg(msg); err != nil {
+               return err
+       }
+
+       for {
+               select {
+               case st := <-ch:
+                       h.logger.Infof("got status update: %s", st)
+                       msg.Status = st
+                       if err := stream.SendMsg(msg); err != nil {
+                               return err
+                       }
+               case <-stream.Context().Done():
+                       return stream.Context().Err()
+               }
+       }
+}
+
+func NewHealthCheckServicer() HealthCheckServicer {
+       return &healthCheckServicer{
+               logger:   log.WithPrefix("grpc_health_v1"),
+               watchers: make(map[chan grpc_health_v1.HealthCheckResponse_ServingStatus]struct{}),
+       }
+}
index b0954311aca708c34cb243fa06f32e45bc3c9bc2..9e56c73c80fe893bfda4fd6f74704039135924d8 100644 (file)
@@ -19,6 +19,7 @@ import (
        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/credentials"
+       "google.golang.org/grpc/health/grpc_health_v1"
        "google.golang.org/grpc/peer"
        "google.golang.org/grpc/status"
 )
@@ -33,6 +34,7 @@ type Server struct {
        log        log.Logger
        sessions   *lru.Cache[string, *session]
        connFac    ConnectionFactory
+       hc         HealthCheckServicer
 }
 
 var defaultPort *uint
@@ -85,6 +87,7 @@ func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) {
                log:       log.WithPrefix(fmt.Sprintf("grpcServer:%s", id.Name())),
                sessions:  sessionsLru,
                connFac:   NewDefaultConnectionFactory(),
+               hc:        NewHealthCheckServicer(),
        }
 
        return server, nil
@@ -100,6 +103,11 @@ func (s *Server) PublishAndServe(ctx context.Context, callback func(*grpc.Server
        tc.MinVersion = tls.VersionTLS13
        tc.MaxVersion = tls.VersionTLS13
 
+       // Verify the client certificate if presented, but still allow the handshake to complete if no certificate
+       // is presented at all. This allows for anonymous gRPC health checks; see handleConnection and
+       // handleStreamConnection below.
+       tc.ClientAuth = tls.VerifyClientCertIfGiven
+
        err = s.verifier.ConfigureServer(tc)
        if err != nil {
                return err
@@ -123,9 +131,11 @@ func (s *Server) PublishAndServe(ctx context.Context, callback func(*grpc.Server
 
        grpcServer := grpc.NewServer(opts...)
 
+       grpc_health_v1.RegisterHealthServer(grpcServer, s.hc)
        callback(grpcServer)
 
        go grpcServer.Serve(listener)
+       s.hc.SetStatus(grpc_health_v1.HealthCheckResponse_SERVING)
        s.grpcServer = grpcServer
 
        err = s.publisher.Publish(ctx)
@@ -140,6 +150,7 @@ func (s *Server) PublishAndServe(ctx context.Context, callback func(*grpc.Server
 }
 
 func (s *Server) Stop() {
+       s.hc.SetStatus(grpc_health_v1.HealthCheckResponse_NOT_SERVING)
        s.publisher.Unpublish()
        if s.grpcServer != nil {
                s.grpcServer.GracefulStop()
@@ -148,6 +159,10 @@ func (s *Server) Stop() {
 }
 
 func (s *Server) handleConnection(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+       if info.FullMethod == "/grpc.health.v1.Health/Check" {
+               return handler(ctx, req)
+       }
+
        peer, ok := peer.FromContext(ctx)
        if !ok {
                return nil, status.Errorf(codes.PermissionDenied, "client did not authenticate")
@@ -167,6 +182,10 @@ func (s *Server) handleConnection(ctx context.Context, req interface{}, info *gr
 }
 
 func (s *Server) handleStreamConnection(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+       if info.FullMethod == "/grpc.health.v1.Health/Watch" {
+               return handler(srv, ss)
+       }
+
        ctx := ss.Context()
        peer, ok := peer.FromContext(ctx)
        if !ok {
index c42ed84b780b71c424f785147b38607189c62a2b..0fd985a94b4d1688c2631fa6cc31e316846c71ed 100644 (file)
@@ -166,7 +166,7 @@ func unsubscribeInternal(filePath string) error {
        return nil
 }
 
-func unsubscribe(filePath string) error {
+func Unsubscribe(filePath string) error {
        err := unsubscribeInternal(filePath)
        unsubscribeInternal(path.Dir(filePath))
        return err
@@ -278,7 +278,7 @@ func handleEvent(event fsnotify.Event) {
                        return
                }
        }
-       logger.V(2).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op)
+       logger.V(4).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op)
 }
 
 // handleEventForPath is the second stage of event handling which calls the actual handlers.
@@ -294,6 +294,8 @@ func handleEventForPath(event fsnotify.Event, ruleName string) {
                return
        }
 
+       op := event.Op
+
        logger.V(2).Debugf("inotify event on %s: %v", event.Name, event.Op)
        logger.V(2).Debugf("ruleName: %s", ruleName)
        logger.V(2).Debugf("upstreams: %+v", rule.upstreams.AsSlice())
@@ -314,10 +316,10 @@ func handleEventForPath(event fsnotify.Event, ruleName string) {
                handlerPath = ruleName
 
                for _, upstream := range rule.upstreams.AsSortedSlice() {
-                       unsubscribe(upstream)
+                       Unsubscribe(upstream)
                }
                untrackSymlinks(ruleName)
-               unsubscribe(ruleName)
+               Unsubscribe(ruleName)
                if len(watchHandlers) == 1 {
                        logger.V(2).Noticef("all watchers should now be cleaned up:")
                        Debug()
@@ -341,13 +343,14 @@ func handleEventForPath(event fsnotify.Event, ruleName string) {
        if event.Has(Close) {
                if pendingWrites.Contains(event.Name) {
                        pendingWrites.Del(event.Name)
+                       op |= Write
                } else if !fileSwapped {
                        return
                }
        }
 
        for _, handler := range rule.callbacks {
-               handler(handlerPath, event.Op)
+               handler(handlerPath, op)
        }
 }
 
@@ -366,7 +369,7 @@ func untrackSymlinks(ruleName string) {
        for link, targets := range symlinkPropagationMap {
                if targets.Contains(ruleName) {
                        logger.V(2).Debugf("cleaning up watcher on %s for fsnotify watch rule %s", link, ruleName)
-                       unsubscribe(link)
+                       Unsubscribe(link)
                        targets.Del(ruleName)
                }
                if targets.Len() == 0 {
index 5280cc49c7e2ee17145f02cd0210ca549bc9e76c..caf07cce44812aa4bf3030aecb92cdd215393ea4 100644 (file)
@@ -82,6 +82,8 @@ func realpath(filePath string) string {
        return strings.Join(out, string(filepath.Separator))
 }
 
+var RealPath = realpath
+
 // resolveSymlinkRecurse recursively resolves filePath until a real file is found.
 //
 // Both relative and absolute symlinks are handled and normalized.
index ed70ad0d0b61aee9afb430d979092aa58b86bc3e..8bc08f8cf8ab8a011ceb6cc62d87faa26889b550 100644 (file)
@@ -16,6 +16,7 @@ type PrincipalClass int
 
 const (
        InvalidPrincipal PrincipalClass = iota
+       AnonymousPrincipal
        ServicePrincipal
        UserPrincipal
        SSLCertificatePrincipal
@@ -316,11 +317,15 @@ func DefaultIdentity() Identity {
                panic("cannot get default identity before flags are parsed")
        }
 
+       if defaultMtlsIdentity == "anonymous" {
+               return Anonymous()
+       }
+
        if defaultMtlsIdentity == "" {
                userId, err := NewDefaultUserIdentity()
                if err == nil && userId.IsValid() {
                        leafCert, _ := userId.LeafCertificate()
-                       log.Default().Infof("found valid user certificate, using identity: %s", leafCert.Subject)
+                       log.Default().V(1).Infof("found valid user certificate, using identity: %s", leafCert.Subject)
                        return userId
                } else {
                        log.Default().V(2).Debugf("couldn't load a user identity: err: %+v", err)
@@ -332,6 +337,11 @@ func DefaultIdentity() Identity {
        return NewServiceIdentity(defaultMtlsIdentity)
 }
 
+// Anonymous returns an identity that supplies no client certificate
+func Anonymous() Identity {
+       return &anonymousIdentity{}
+}
+
 func IdentityFromTLSConnectionState(state *tls.ConnectionState) (Identity, error) {
        if state == nil {
                return nil, fmt.Errorf("connectionState is nil")
index 672c62c72e727b6eed05a7f0e575f3dd7e27b122..57796ced6daffbc68e76c77f17400d9f5509c3a4 100644 (file)
@@ -107,11 +107,16 @@ func (cv *mtlsPeerVerifier) ConfigureServer(tlsConfig *tls.Config) error {
        if err != nil {
                return err
        }
-       tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo)
+       // If the tls config isn't configured to request a client certificate at all (the default), reconfigure it
+       // to require one, which is usually what you want.
+       if tlsConfig.ClientAuth == tls.NoClientCert {
+               tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+       }
+       tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tlsConfig.ClientAuth)
        tlsConfig.InsecureSkipVerify = true
-       tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo)
+       tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo, tlsConfig.ClientAuth)
        tlsConfig.ClientCAs = vo.Roots
-       tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+       tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
 
        return nil
 }
@@ -121,17 +126,21 @@ func (cv *mtlsPeerVerifier) ConfigureClient(tlsConfig *tls.Config) error {
        if err != nil {
                return err
        }
+       clientAuthType := tls.RequireAndVerifyClientCert
        tlsConfig.InsecureSkipVerify = true
-       tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo)
-       tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo)
+       tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo, clientAuthType)
+       tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo, clientAuthType)
 
        return nil
 }
 
-func (cv *mtlsPeerVerifier) verifyConnectionFunc(verifyOpts x509.VerifyOptions) func(tls.ConnectionState) error {
+func (cv *mtlsPeerVerifier) verifyConnectionFunc(verifyOpts x509.VerifyOptions, clientAuth tls.ClientAuthType) func(tls.ConnectionState) error {
        return func(cs tls.ConnectionState) error {
                if len(cs.PeerCertificates) < 1 {
-                       return fmt.Errorf("no peer certificate provided")
+                       if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven {
+                               return nil
+                       }
+                       return ErrNoCertificatePresented
                }
 
                peerCert := cs.PeerCertificates[0]
index 1a95101d98e17e33d625b1ba6cd988c5ac9b4323..37f4e2837b29740fe91789aa23daf8cade6cddcb 100644 (file)
@@ -1,6 +1,7 @@
 package mtls
 
 import (
+       "crypto/tls"
        "crypto/x509"
        "errors"
        "sync"
@@ -158,7 +159,7 @@ func verifyMTLSCertificateChain(leafCert *x509.Certificate, intermediates []*x50
 
        for _, chain := range chains {
                lastCert := chain[0]
-               logger.Debugf("checking constraints on leaf certificate: %+v", lastCert.Subject.String())
+               logger.V(2).Debugf("checking constraints on leaf certificate: %+v", lastCert.Subject.String())
                if err := checkLeafCertificateConstraints(lastCert); err != nil {
                        return err
                }
@@ -173,10 +174,10 @@ func NewVerifyMTLSPeerCertificateFunc() (tlsVerifyPeerCertificatesFunc, error) {
                return nil, err
        }
 
-       return NewVerifyMTLSPeerCertificateFuncWithOpts(vo), nil
+       return NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tls.RequireAndVerifyClientCert), nil
 }
 
-func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions) tlsVerifyPeerCertificatesFunc {
+func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth tls.ClientAuthType) tlsVerifyPeerCertificatesFunc {
        return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
                var leafCert *x509.Certificate
 
@@ -199,6 +200,9 @@ func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions) tlsVerifyPe
                }
 
                if leafCert == nil {
+                       if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven {
+                               return nil
+                       }
                        return ErrNoCertificatePresented
                }
 
index 5212aceaa41416aa783d1b191bf0c604fb5bce89..d9a9524d1bac0735369c9f5edc3947bc84a4ce6b 100644 (file)
@@ -7,6 +7,7 @@ import (
        "sync"
        "time"
 
+       "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
        "go.etcd.io/etcd/client/pkg/v3/srv"
        etcd_client "go.etcd.io/etcd/client/v3"
        "go.fuhry.dev/runtime/mtls"
@@ -96,6 +97,40 @@ func NewEtcdClientWithDeadline(id mtls.Identity, domain string, deadline time.Ti
        return client, nil
 }
 
+func ErrorIsRecoverable(err error) bool {
+       rpcErr, ok := err.(rpctypes.EtcdError)
+       if !ok {
+               // not an EtcdError
+               return false
+       }
+
+       var unrecoverableCodes = []error{
+               rpctypes.ErrAuthFailed,
+               rpctypes.ErrRoleNotGranted,
+               rpctypes.ErrAuthFailed,
+               rpctypes.ErrAuthNotEnabled,
+               rpctypes.ErrInvalidAuthToken,
+               rpctypes.ErrAuthOldRevision,
+               rpctypes.ErrInvalidAuthMgmt,
+               rpctypes.ErrClusterIdMismatch,
+               rpctypes.ErrNoLeader,
+               rpctypes.ErrNotCapable,
+               rpctypes.ErrStopped,
+               rpctypes.ErrTimeoutDueToConnectionLost,
+               rpctypes.ErrTimeoutDueToLeaderFail,
+               rpctypes.ErrUnhealthy,
+               rpctypes.ErrCorrupt,
+       }
+
+       for _, c := range unrecoverableCodes {
+               if rpcErr.Code() == c.(rpctypes.EtcdError).Code() {
+                       return false
+               }
+       }
+
+       return true
+}
+
 func init() {
        flag.StringVar(&etcdMtlsId, "etcd.mtls.id", "etcd-client", "mTLS identity to use for connecting to etcd")
        flag.UintVar(&etcdStartupTimeoutMs, "etcd.startup-timeout", etcdStartupTimeoutMs, "max timeout (in ms) for etcd startup attempts before failing")
index 5d95d1771c6f8322c2265a76bad27ac62b805210..cdda3b3fb83f12fb4416edc9424831e3f5b84266 100644 (file)
@@ -101,7 +101,7 @@ func (w *SDWatcher) GetAddrs(ctx context.Context) ([]ServiceAddress, error) {
 func (w *SDWatcher) watch(ctx context.Context) {
        kvs := make(map[string][]byte, 0)
 
-       w.logger.Infof("Watching for service publications under path %s", w.prefix())
+       w.logger.V(1).Infof("Watching for service publications under path %s", w.prefix())
 
        items, err := w.EtcdClient.Get(ctx, w.prefix(), etcd_client.WithPrefix())
        if err == nil {