From d539b24ad77d2235fd5ddae7972cb986eddf2bd5 Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Thu, 6 Nov 2025 07:03:20 -0500 Subject: [PATCH] more ephs --- ephs/client/main.go | 275 ++++++++++++++++++++++++++++++++++++++ ephs/server/Makefile | 14 ++ ephs/server/ephs_acl.yaml | 3 + ephs/server/main.go | 64 +++++++++ ephs/server/rules.yaml | 11 ++ grpc/client.go | 109 +++++++++++++-- grpc/conn_quic.go | 30 ++++- grpc/health_probe/main.go | 152 +++++++++++++++++++++ grpc/healthcheck.go | 85 ++++++++++++ grpc/server.go | 19 +++ mtls/fsnotify/fsnotify.go | 15 ++- mtls/fsnotify/util.go | 2 + mtls/identity.go | 12 +- mtls/verify_names.go | 23 +++- mtls/verify_roots.go | 10 +- sd/etcd_factory.go | 35 +++++ sd/watcher.go | 2 +- 17 files changed, 827 insertions(+), 34 deletions(-) create mode 100644 ephs/client/main.go create mode 100644 ephs/server/Makefile create mode 100644 ephs/server/ephs_acl.yaml create mode 100644 ephs/server/main.go create mode 100644 ephs/server/rules.yaml create mode 100644 grpc/health_probe/main.go create mode 100644 grpc/healthcheck.go diff --git a/ephs/client/main.go b/ephs/client/main.go new file mode 100644 index 0000000..f15ca24 --- /dev/null +++ b/ephs/client/main.go @@ -0,0 +1,275 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "io" + "os" + "os/signal" + "path" + "strings" + "syscall" + + "github.com/urfave/cli/v3" + + "go.fuhry.dev/runtime/constants" + "go.fuhry.dev/runtime/ephs" + "go.fuhry.dev/runtime/utils/log" +) + +func cmdStat(ctx context.Context, cmd *cli.Command) error { + client, err := ephs.DefaultClient() + if err != nil { + return err + } + + resp, err := client.StatContext(ctx, strings.TrimSuffix(cmd.StringArg("path"), "/")) + if err != nil { + return err + } + + fmt.Println(ephs.FormatFsEntry(resp)) + return nil +} + +func cmdCat(ctx context.Context, cmd *cli.Command) error { + client, err := ephs.DefaultClient() + if err != nil { + return err + } + + reader, err := client.GetContext(ctx, cmd.StringArg("path")) + if err != nil { + return err + } + + io.Copy(os.Stdout, reader) + + return nil +} + +func cmdCopy(ctx context.Context, cmd *cli.Command) error { + client, err := ephs.DefaultClient() + if err != nil { + return err + } + + src := cmd.StringArg("src") + dst := cmd.StringArg("dst") + + if ephs.IsEphsPath(src) && ephs.IsEphsPath(dst) { + return errors.New("copying from one ephs path to another is not presently supported") + } else if !ephs.IsEphsPath(src) && !ephs.IsEphsPath(dst) { + return errors.New("copying from one non-ephs path to another is not supported") + } + + if strings.HasSuffix(dst, "/") { + dst += path.Base(src) + } + + if ephs.IsEphsPath(src) { + // copy from ephs to local + fp, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE, os.FileMode(0600)) + if err != nil { + return err + } + defer fp.Close() + + reader, err := client.GetContext(ctx, src) + if err != nil { + return err + } + + nw, _ := io.Copy(fp, reader) + + log.Default().Infof("wrote %d bytes to %s", nw, dst) + return nil + } else if ephs.IsEphsPath(dst) { + // copy from local to ephs (put request) + st, err := os.Stat(src) + if err != nil { + return err + } + fp, err := os.OpenFile(src, os.O_RDONLY, os.FileMode(0)) + if err != nil { + return err + } + defer fp.Close() + + resp, err := client.PutContext(ctx, dst, uint64(st.Size()), fp) + if err != nil { + return err + } + + log.Default().Noticef("successfully uploaded: %s\n%s", dst, ephs.FormatFsEntry(resp)) + } + + return nil +} + +func cmdDelete(ctx context.Context, cmd *cli.Command) error { + client, err := ephs.DefaultClient() + if err != nil { + return err + } + + err = client.DeleteContext(ctx, cmd.StringArg("path"), cmd.IsSet("recursive")) + if err != nil { + return err + } + + log.Default().Infof("Deleted object: %s", cmd.StringArg("path")) + return nil +} + +func cmdWatch(ctx context.Context, cmd *cli.Command) error { + client, err := ephs.DefaultClient() + if err != nil { + return err + } + + watchCl, err := client.Watch(ctx, cmd.StringArg("path")) + if err != nil { + return err + } + + for msg := range watchCl { + log.Default().Infof("Got event:\nEvent type: %s\n%s", msg.Event.String(), ephs.FormatFsEntry(msg.GetEntry())) + } + return nil +} + +func cmdMkdir(ctx context.Context, cmd *cli.Command) error { + client, err := ephs.DefaultClient() + if err != nil { + return err + } + + err = client.MkDirContext(ctx, cmd.StringArg("path"), cmd.Bool("recursive")) + if err != nil { + return err + } + + log.Default().Infof("Created directory: %s", cmd.StringArg("path")) + return nil +} + +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + flag.Parse() + + cmd := &cli.Command{ + Name: "ephs", + Version: constants.Version, + Description: "interact with ephs", + + Commands: []*cli.Command{ + { + Name: "stat", + Aliases: []string{"ls", "dir", "list"}, + Description: "get info about a file or list contents of a directory", + Action: cmdStat, + Arguments: []cli.Argument{ + &cli.StringArg{ + Name: "path", + UsageText: "path to read", + }, + }, + }, + { + Name: "cat", + Description: "read a file", + Action: cmdCat, + Arguments: []cli.Argument{ + &cli.StringArg{ + Name: "path", + UsageText: "path to read", + }, + }, + }, + { + Name: "cp", + Description: "copy a file", + Action: cmdCopy, + Arguments: []cli.Argument{ + &cli.StringArg{ + Name: "src", + UsageText: "source", + }, + &cli.StringArg{ + Name: "dst", + UsageText: "destination", + }, + }, + }, + { + Name: "rm", + Description: "delete a file", + Action: cmdDelete, + Arguments: []cli.Argument{ + &cli.StringArg{ + Name: "path", + UsageText: "path to read", + }, + }, + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "recursive", + Usage: "allow removing directories and all of their contents", + }, + }, + }, + { + Name: "watch", + Description: "watch a path for changes", + Action: cmdWatch, + Arguments: []cli.Argument{ + &cli.StringArg{ + Name: "path", + UsageText: "path to read", + }, + }, + }, + { + Name: "mkdir", + Description: "create a directory", + Action: cmdMkdir, + Arguments: []cli.Argument{ + &cli.StringArg{ + Name: "path", + UsageText: "directory to create", + }, + }, + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "recursive", + Usage: "also create any parent directories that don't already exist", + }, + }, + }, + }, + + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "vv", + Usage: "verbosity level", + }, + &cli.StringFlag{ + Name: "v", + Usage: "log level", + }, + &cli.StringFlag{ + Name: "grpc.transport", + Usage: "grpc transport (tcp or quic)", + }, + }, + } + + if err := cmd.Run(ctx, os.Args); err != nil { + log.Default().Panic(err) + } +} diff --git a/ephs/server/Makefile b/ephs/server/Makefile new file mode 100644 index 0000000..bfab546 --- /dev/null +++ b/ephs/server/Makefile @@ -0,0 +1,14 @@ +GOSRC = $(wildcard *.go) +GOEXE = $(shell basename `pwd`) +GOBUILDFLAGS := -buildmode=pie -trimpath + +all: $(GOEXE) + +clean: + rm -fv $(GOEXE) + +.PHONY: all clean + +$(GOEXE): %: $(GOSRC) + go build $(GOBUILDFLAGS) -o $@ $< + diff --git a/ephs/server/ephs_acl.yaml b/ephs/server/ephs_acl.yaml new file mode 100644 index 0000000..c4b1b46 --- /dev/null +++ b/ephs/server/ephs_acl.yaml @@ -0,0 +1,3 @@ +DEFAULT: + - service: '*' + - user: '*' diff --git a/ephs/server/main.go b/ephs/server/main.go new file mode 100644 index 0000000..d0e24cf --- /dev/null +++ b/ephs/server/main.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + "flag" + "os/signal" + "syscall" + + "go.fuhry.dev/runtime/ephs/servicer" + "go.fuhry.dev/runtime/grpc" + "go.fuhry.dev/runtime/mtls" + ephs_pb "go.fuhry.dev/runtime/proto/service/ephs" + "go.fuhry.dev/runtime/utils/log" + + google_grpc "google.golang.org/grpc" +) + +func main() { + var err error + + acl := flag.String("ephs.acl", "", "YAML file containing ACLs for ephs") + awsFile := flag.String("ephs.s3-creds-file", "", "file to load AWS S3 credentials from") + awsEnv := flag.Bool("ephs.s3-creds-env", false, "set to true to load AWS credentials from environment") + + flag.Parse() + + serverIdentity := mtls.DefaultIdentity() + s, err := grpc.NewGrpcServer(serverIdentity) + if err != nil { + panic(err) + } + + var opts []servicer.Option + if *acl != "" { + opts = append(opts, servicer.WithAclFile(*acl)) + } + + if awsEnv != nil && *awsEnv { + opts = append(opts, servicer.WithAWSEnvCredentials()) + } else if awsFile != nil && *awsFile != "" { + opts = append(opts, servicer.WithAWSCredentialFile(*awsFile)) + } + + serv, err := servicer.NewEphsServicer(opts...) + if err != nil { + log.Fatal(err) + } + + ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + + err = s.PublishAndServe(ctx, func(s *google_grpc.Server) { + ephs_pb.RegisterEphsServer(s, serv) + }) + if err != nil { + panic(err) + } + defer s.Stop() + + <-ctx.Done() +} + +func init() { + mtls.SetDefaultIdentity("ephs") +} diff --git a/ephs/server/rules.yaml b/ephs/server/rules.yaml new file mode 100644 index 0000000..9ba49d6 --- /dev/null +++ b/ephs/server/rules.yaml @@ -0,0 +1,11 @@ +rules: +- principal: + prefix: spiffe://roc.xx0r.info/service/ + key: + prefix: services/{{principal}}/ +- principal: + or: + - exact: spiffe://roc.xx0r.info/user/dan + - exact: spiffe://roc.xx0r.info/service/conffs + key: + any: true diff --git a/grpc/client.go b/grpc/client.go index e904b30..93072f5 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -3,44 +3,127 @@ package grpc import ( "context" "fmt" + "net" "go.fuhry.dev/runtime/mtls" "go.fuhry.dev/runtime/sd" "google.golang.org/grpc" ) -type Client struct { +type ClientConn = grpc.ClientConn + +type Client interface { + Conn() (*grpc.ClientConn, error) +} + +type ClientOption interface { + apply(*client) error +} + +type clientOption struct { + f func(*client) error +} + +type AddressProvider interface { + GetAddrs(context.Context) ([]sd.ServiceAddress, error) +} + +func (o *clientOption) apply(c *client) error { + return o.f(c) +} + +type client struct { ctx context.Context serverId mtls.Identity clientId mtls.Identity - watcher *sd.SDWatcher + watcher AddressProvider connFac ConnectionFactory } -func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity) (*Client, error) { - etcdc, err := sd.NewDefaultEtcdClient() - if err != nil { - panic(err) +type staticAddressProvider struct { + addresses []sd.ServiceAddress +} + +func (s *staticAddressProvider) GetAddrs(_ context.Context) ([]sd.ServiceAddress, error) { + return s.addresses, nil +} + +func WithConnectionFactory(fac ConnectionFactory) ClientOption { + return &clientOption{ + f: func(c *client) error { + c.connFac = fac + return nil + }, + } +} + +func WithAddressProvider(ap AddressProvider) ClientOption { + return &clientOption{ + f: func(c *client) error { + c.watcher = ap + return nil + }, + } +} + +func WithStaticAddress(addresses ...*net.TCPAddr) ClientOption { + var addrs []sd.ServiceAddress + + for _, addr := range addresses { + var ip4, ip6 string + if len(addr.IP) == 4 { + ip4 = addr.IP.String() + } else { + ip6 = addr.IP.String() + } + addrs = append(addrs, sd.ServiceAddress{ + IP4: ip4, + IP6: ip6, + Port: uint16(addr.Port), + }) } - w := &sd.SDWatcher{ - Service: serverId.Name(), - EtcdClient: etcdc, - Protocol: sd.ProtocolGRPC, + return &clientOption{ + f: func(c *client) error { + c.watcher = &staticAddressProvider{ + addresses: addrs, + } + return nil + }, } +} - cl := &Client{ +func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity, opts ...ClientOption) (Client, error) { + cl := &client{ ctx: ctx, serverId: serverId, clientId: clientId, - watcher: w, connFac: NewDefaultConnectionFactory(), } + for _, opt := range opts { + if err := opt.apply(cl); err != nil { + return nil, err + } + } + + if cl.watcher == nil { + etcdc, err := sd.NewDefaultEtcdClient() + if err != nil { + return nil, err + } + + cl.watcher = &sd.SDWatcher{ + Service: serverId.Name(), + EtcdClient: etcdc, + Protocol: sd.ProtocolGRPC, + } + } + return cl, nil } -func (c *Client) Conn() (*grpc.ClientConn, error) { +func (c *client) Conn() (*grpc.ClientConn, error) { addrs, err := c.watcher.GetAddrs(c.ctx) if err != nil { return nil, err diff --git a/grpc/conn_quic.go b/grpc/conn_quic.go index 99d16df..c949686 100644 --- a/grpc/conn_quic.go +++ b/grpc/conn_quic.go @@ -1,9 +1,11 @@ package grpc import ( + "context" "crypto/tls" "fmt" "net" + "time" "google.golang.org/grpc/credentials" @@ -11,7 +13,14 @@ import ( grpc_quic "go.fuhry.dev/grpc-quic" ) -type QUICConnectionFactory struct{} +var defaultQuicConfig = &quic.Config{ + HandshakeIdleTimeout: 15 * time.Second, + MaxIdleTimeout: 5 * time.Minute, +} + +type QUICConnectionFactory struct { + QUICConfig *quic.Config +} func (cf *QUICConnectionFactory) NewCredentials(tlsConfig *tls.Config) credentials.TransportCredentials { tlsConfig.NextProtos = []string{"grpc-quic"} @@ -24,7 +33,11 @@ func (cf *QUICConnectionFactory) NewListener(port uint16, tlsConfig *tls.Config) return nil, err } - quicListener, err := quic.Listen(udpListener, tlsConfig, nil) + if cf.QUICConfig == nil { + cf.QUICConfig = defaultQuicConfig.Clone() + } + + quicListener, err := quic.Listen(udpListener, tlsConfig, cf.QUICConfig.Clone()) if err != nil { return nil, err } @@ -35,7 +48,18 @@ func (cf *QUICConnectionFactory) NewListener(port uint16, tlsConfig *tls.Config) } func (cf *QUICConnectionFactory) NewDialer(tlsConfig *tls.Config) ContextDialer { - return grpc_quic.NewQuicDialer(tlsConfig) + if cf.QUICConfig == nil { + cf.QUICConfig = defaultQuicConfig.Clone() + } + + return func(ctx context.Context, target string) (net.Conn, error) { + sess, err := quic.DialAddr(ctx, target, tlsConfig, cf.QUICConfig.Clone()) + if err != nil { + return nil, err + } + + return grpc_quic.NewConn(sess) + } } func init() { diff --git a/grpc/health_probe/main.go b/grpc/health_probe/main.go new file mode 100644 index 0000000..a32adc9 --- /dev/null +++ b/grpc/health_probe/main.go @@ -0,0 +1,152 @@ +package main + +import ( + "context" + "flag" + "net" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + "go.fuhry.dev/runtime/grpc" + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/utils/log" + + "google.golang.org/grpc/health/grpc_health_v1" +) + +func main() { + serverId := flag.String("server-id", "", "mtls ID of gRPC server") + serverAddr := flag.String("server-addr", "", "server address as ip:port only") + watch := flag.Bool("watch", false, "continually monitor server for health status changes") + wait := flag.Bool("wait", false, "continually retry connecting, then stream server status until RUNNING") + var waitTimeout = 60 * time.Second + flag.Func("wait.timeout", "maximum time to wait for a healthy server (default: 60s)", func(val string) (err error) { + waitTimeout, err = time.ParseDuration(val) + return + }) + + mtls.SetDefaultIdentity("anonymous") + + flag.Parse() + + var opts []grpc.ClientOption + + if *serverAddr != "" { + host, port, err := net.SplitHostPort(*serverAddr) + if err != nil { + log.Default().Panic(err) + } + + portInt, err := strconv.Atoi(port) + if err != nil { + portInt, err = net.LookupPort("tcp", port) + if err != nil { + log.Default().Panic(err) + } + } + + if ip := net.ParseIP(host); ip == nil { + log.Default().Panicf("%q: not a valid IPv4 or IPv6 address", host) + } else { + opts = append(opts, grpc.WithStaticAddress(&net.TCPAddr{ip, portInt, ""})) + } + } + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + var deadline time.Time + + if *wait { + ctx, cancel = context.WithTimeout(ctx, waitTimeout) + defer cancel() + deadline, _ = ctx.Deadline() + } + + if *watch && *wait { + log.Default().Fatal("watch and wait options are mutually exclusive") + os.Exit(1) + } + + var conn *grpc.ClientConn + for { + client, err := grpc.NewGrpcClient(ctx, mtls.NewServiceIdentity(*serverId), mtls.DefaultIdentity(), opts...) + if err != nil { + if *wait && time.Now().Before(deadline) { + log.Default().Warningf("error connecting (%v) retrying in 1s", err) + time.Sleep(1 * time.Second) + continue + } + log.Default().Critical(err) + os.Exit(1) + } + + conn, err = client.Conn() + if err != nil { + if *wait && time.Now().Before(deadline) { + log.Default().Warningf("error connecting (%v) retrying in 1s", err) + time.Sleep(1 * time.Second) + continue + } + log.Default().Critical(err) + os.Exit(1) + } + + defer conn.Close() + break + } + + for { + healthClient := grpc_health_v1.NewHealthClient(conn) + + if *watch { + stream, err := healthClient.Watch(ctx, &grpc_health_v1.HealthCheckRequest{}) + if err != nil { + log.Default().Alertf("healthcheck failed with error %T: %v", err, err) + os.Exit(2) + } + + for msg, err := stream.Recv(); err == nil; msg, err = stream.Recv() { + log.Default().Infof("server status: %s", msg.Status.String()) + time.Sleep(100 * time.Millisecond) + } + + os.Exit(0) + } else if *wait { + stream, err := healthClient.Watch(ctx, &grpc_health_v1.HealthCheckRequest{}) + if err != nil { + if time.Now().Before(deadline) { + log.Default().Warningf("error connecting (%v) retrying in 1s", err) + time.Sleep(1 * time.Second) + continue + } + log.Default().Critical(err) + os.Exit(1) + } + + for msg, err := stream.Recv(); err == nil; msg, err = stream.Recv() { + log.Default().Infof("server status: %s", msg.Status.String()) + if msg.Status == grpc_health_v1.HealthCheckResponse_SERVING { + return + } + } + + os.Exit(1) + } else { + resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{}) + if err != nil { + log.Default().Alertf("healthcheck failed with error %T: %v", err, err) + os.Exit(2) + } + + log.Default().Infof("server status: %s", resp.Status.String()) + if resp.Status == grpc_health_v1.HealthCheckResponse_SERVING { + os.Exit(0) + } + + os.Exit(1) + } + } +} diff --git a/grpc/healthcheck.go b/grpc/healthcheck.go new file mode 100644 index 0000000..ce8b14b --- /dev/null +++ b/grpc/healthcheck.go @@ -0,0 +1,85 @@ +package grpc + +import ( + "context" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/health/grpc_health_v1" + + "go.fuhry.dev/runtime/utils/log" +) + +type HealthCheckServicer interface { + grpc_health_v1.HealthServer + + SetStatus(grpc_health_v1.HealthCheckResponse_ServingStatus) +} + +type healthCheckServicer struct { + s grpc_health_v1.HealthCheckResponse_ServingStatus + logger log.Logger + + watchers map[chan grpc_health_v1.HealthCheckResponse_ServingStatus]struct{} + watchersMu sync.Mutex +} + +var _ grpc_health_v1.HealthServer = &healthCheckServicer{} + +func (h *healthCheckServicer) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) { + return &grpc_health_v1.HealthCheckResponse{ + Status: h.s, + }, nil +} + +func (h *healthCheckServicer) SetStatus(status grpc_health_v1.HealthCheckResponse_ServingStatus) { + h.s = status + + h.watchersMu.Lock() + defer h.watchersMu.Unlock() + + for wa := range h.watchers { + wa <- status + } +} + +func (h *healthCheckServicer) Watch(req *grpc_health_v1.HealthCheckRequest, stream grpc.ServerStreamingServer[grpc_health_v1.HealthCheckResponse]) error { + h.watchersMu.Lock() + ch := make(chan grpc_health_v1.HealthCheckResponse_ServingStatus) + h.watchers[ch] = struct{}{} + h.watchersMu.Unlock() + + defer (func() { + h.watchersMu.Lock() + delete(h.watchers, ch) + close(ch) + h.watchersMu.Unlock() + })() + + msg := &grpc_health_v1.HealthCheckResponse{ + Status: h.s, + } + if err := stream.SendMsg(msg); err != nil { + return err + } + + for { + select { + case st := <-ch: + h.logger.Infof("got status update: %s", st) + msg.Status = st + if err := stream.SendMsg(msg); err != nil { + return err + } + case <-stream.Context().Done(): + return stream.Context().Err() + } + } +} + +func NewHealthCheckServicer() HealthCheckServicer { + return &healthCheckServicer{ + logger: log.WithPrefix("grpc_health_v1"), + watchers: make(map[chan grpc_health_v1.HealthCheckResponse_ServingStatus]struct{}), + } +} diff --git a/grpc/server.go b/grpc/server.go index b095431..9e56c73 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) @@ -33,6 +34,7 @@ type Server struct { log log.Logger sessions *lru.Cache[string, *session] connFac ConnectionFactory + hc HealthCheckServicer } var defaultPort *uint @@ -85,6 +87,7 @@ func NewGrpcServerWithPort(id mtls.Identity, port uint16) (*Server, error) { log: log.WithPrefix(fmt.Sprintf("grpcServer:%s", id.Name())), sessions: sessionsLru, connFac: NewDefaultConnectionFactory(), + hc: NewHealthCheckServicer(), } return server, nil @@ -100,6 +103,11 @@ func (s *Server) PublishAndServe(ctx context.Context, callback func(*grpc.Server tc.MinVersion = tls.VersionTLS13 tc.MaxVersion = tls.VersionTLS13 + // Verify the client certificate if presented, but still allow the handshake to complete if no certificate + // is presented at all. This allows for anonymous gRPC health checks; see handleConnection and + // handleStreamConnection below. + tc.ClientAuth = tls.VerifyClientCertIfGiven + err = s.verifier.ConfigureServer(tc) if err != nil { return err @@ -123,9 +131,11 @@ func (s *Server) PublishAndServe(ctx context.Context, callback func(*grpc.Server grpcServer := grpc.NewServer(opts...) + grpc_health_v1.RegisterHealthServer(grpcServer, s.hc) callback(grpcServer) go grpcServer.Serve(listener) + s.hc.SetStatus(grpc_health_v1.HealthCheckResponse_SERVING) s.grpcServer = grpcServer err = s.publisher.Publish(ctx) @@ -140,6 +150,7 @@ func (s *Server) PublishAndServe(ctx context.Context, callback func(*grpc.Server } func (s *Server) Stop() { + s.hc.SetStatus(grpc_health_v1.HealthCheckResponse_NOT_SERVING) s.publisher.Unpublish() if s.grpcServer != nil { s.grpcServer.GracefulStop() @@ -148,6 +159,10 @@ func (s *Server) Stop() { } func (s *Server) handleConnection(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if info.FullMethod == "/grpc.health.v1.Health/Check" { + return handler(ctx, req) + } + peer, ok := peer.FromContext(ctx) if !ok { return nil, status.Errorf(codes.PermissionDenied, "client did not authenticate") @@ -167,6 +182,10 @@ func (s *Server) handleConnection(ctx context.Context, req interface{}, info *gr } func (s *Server) handleStreamConnection(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if info.FullMethod == "/grpc.health.v1.Health/Watch" { + return handler(srv, ss) + } + ctx := ss.Context() peer, ok := peer.FromContext(ctx) if !ok { diff --git a/mtls/fsnotify/fsnotify.go b/mtls/fsnotify/fsnotify.go index c42ed84..0fd985a 100644 --- a/mtls/fsnotify/fsnotify.go +++ b/mtls/fsnotify/fsnotify.go @@ -166,7 +166,7 @@ func unsubscribeInternal(filePath string) error { return nil } -func unsubscribe(filePath string) error { +func Unsubscribe(filePath string) error { err := unsubscribeInternal(filePath) unsubscribeInternal(path.Dir(filePath)) return err @@ -278,7 +278,7 @@ func handleEvent(event fsnotify.Event) { return } } - logger.V(2).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op) + logger.V(4).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op) } // handleEventForPath is the second stage of event handling which calls the actual handlers. @@ -294,6 +294,8 @@ func handleEventForPath(event fsnotify.Event, ruleName string) { return } + op := event.Op + logger.V(2).Debugf("inotify event on %s: %v", event.Name, event.Op) logger.V(2).Debugf("ruleName: %s", ruleName) logger.V(2).Debugf("upstreams: %+v", rule.upstreams.AsSlice()) @@ -314,10 +316,10 @@ func handleEventForPath(event fsnotify.Event, ruleName string) { handlerPath = ruleName for _, upstream := range rule.upstreams.AsSortedSlice() { - unsubscribe(upstream) + Unsubscribe(upstream) } untrackSymlinks(ruleName) - unsubscribe(ruleName) + Unsubscribe(ruleName) if len(watchHandlers) == 1 { logger.V(2).Noticef("all watchers should now be cleaned up:") Debug() @@ -341,13 +343,14 @@ func handleEventForPath(event fsnotify.Event, ruleName string) { if event.Has(Close) { if pendingWrites.Contains(event.Name) { pendingWrites.Del(event.Name) + op |= Write } else if !fileSwapped { return } } for _, handler := range rule.callbacks { - handler(handlerPath, event.Op) + handler(handlerPath, op) } } @@ -366,7 +369,7 @@ func untrackSymlinks(ruleName string) { for link, targets := range symlinkPropagationMap { if targets.Contains(ruleName) { logger.V(2).Debugf("cleaning up watcher on %s for fsnotify watch rule %s", link, ruleName) - unsubscribe(link) + Unsubscribe(link) targets.Del(ruleName) } if targets.Len() == 0 { diff --git a/mtls/fsnotify/util.go b/mtls/fsnotify/util.go index 5280cc4..caf07cc 100644 --- a/mtls/fsnotify/util.go +++ b/mtls/fsnotify/util.go @@ -82,6 +82,8 @@ func realpath(filePath string) string { return strings.Join(out, string(filepath.Separator)) } +var RealPath = realpath + // resolveSymlinkRecurse recursively resolves filePath until a real file is found. // // Both relative and absolute symlinks are handled and normalized. diff --git a/mtls/identity.go b/mtls/identity.go index ed70ad0..8bc08f8 100644 --- a/mtls/identity.go +++ b/mtls/identity.go @@ -16,6 +16,7 @@ type PrincipalClass int const ( InvalidPrincipal PrincipalClass = iota + AnonymousPrincipal ServicePrincipal UserPrincipal SSLCertificatePrincipal @@ -316,11 +317,15 @@ func DefaultIdentity() Identity { panic("cannot get default identity before flags are parsed") } + if defaultMtlsIdentity == "anonymous" { + return Anonymous() + } + if defaultMtlsIdentity == "" { userId, err := NewDefaultUserIdentity() if err == nil && userId.IsValid() { leafCert, _ := userId.LeafCertificate() - log.Default().Infof("found valid user certificate, using identity: %s", leafCert.Subject) + log.Default().V(1).Infof("found valid user certificate, using identity: %s", leafCert.Subject) return userId } else { log.Default().V(2).Debugf("couldn't load a user identity: err: %+v", err) @@ -332,6 +337,11 @@ func DefaultIdentity() Identity { return NewServiceIdentity(defaultMtlsIdentity) } +// Anonymous returns an identity that supplies no client certificate +func Anonymous() Identity { + return &anonymousIdentity{} +} + func IdentityFromTLSConnectionState(state *tls.ConnectionState) (Identity, error) { if state == nil { return nil, fmt.Errorf("connectionState is nil") diff --git a/mtls/verify_names.go b/mtls/verify_names.go index 672c62c..57796ce 100644 --- a/mtls/verify_names.go +++ b/mtls/verify_names.go @@ -107,11 +107,16 @@ func (cv *mtlsPeerVerifier) ConfigureServer(tlsConfig *tls.Config) error { if err != nil { return err } - tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo) + // If the tls config isn't configured to request a client certificate at all (the default), reconfigure it + // to require one, which is usually what you want. + if tlsConfig.ClientAuth == tls.NoClientCert { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tlsConfig.ClientAuth) tlsConfig.InsecureSkipVerify = true - tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo) + tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo, tlsConfig.ClientAuth) tlsConfig.ClientCAs = vo.Roots - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven return nil } @@ -121,17 +126,21 @@ func (cv *mtlsPeerVerifier) ConfigureClient(tlsConfig *tls.Config) error { if err != nil { return err } + clientAuthType := tls.RequireAndVerifyClientCert tlsConfig.InsecureSkipVerify = true - tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo) - tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo) + tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo, clientAuthType) + tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo, clientAuthType) return nil } -func (cv *mtlsPeerVerifier) verifyConnectionFunc(verifyOpts x509.VerifyOptions) func(tls.ConnectionState) error { +func (cv *mtlsPeerVerifier) verifyConnectionFunc(verifyOpts x509.VerifyOptions, clientAuth tls.ClientAuthType) func(tls.ConnectionState) error { return func(cs tls.ConnectionState) error { if len(cs.PeerCertificates) < 1 { - return fmt.Errorf("no peer certificate provided") + if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven { + return nil + } + return ErrNoCertificatePresented } peerCert := cs.PeerCertificates[0] diff --git a/mtls/verify_roots.go b/mtls/verify_roots.go index 1a95101..37f4e28 100644 --- a/mtls/verify_roots.go +++ b/mtls/verify_roots.go @@ -1,6 +1,7 @@ package mtls import ( + "crypto/tls" "crypto/x509" "errors" "sync" @@ -158,7 +159,7 @@ func verifyMTLSCertificateChain(leafCert *x509.Certificate, intermediates []*x50 for _, chain := range chains { lastCert := chain[0] - logger.Debugf("checking constraints on leaf certificate: %+v", lastCert.Subject.String()) + logger.V(2).Debugf("checking constraints on leaf certificate: %+v", lastCert.Subject.String()) if err := checkLeafCertificateConstraints(lastCert); err != nil { return err } @@ -173,10 +174,10 @@ func NewVerifyMTLSPeerCertificateFunc() (tlsVerifyPeerCertificatesFunc, error) { return nil, err } - return NewVerifyMTLSPeerCertificateFuncWithOpts(vo), nil + return NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tls.RequireAndVerifyClientCert), nil } -func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions) tlsVerifyPeerCertificatesFunc { +func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth tls.ClientAuthType) tlsVerifyPeerCertificatesFunc { return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { var leafCert *x509.Certificate @@ -199,6 +200,9 @@ func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions) tlsVerifyPe } if leafCert == nil { + if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven { + return nil + } return ErrNoCertificatePresented } diff --git a/sd/etcd_factory.go b/sd/etcd_factory.go index 5212ace..d9a9524 100644 --- a/sd/etcd_factory.go +++ b/sd/etcd_factory.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" "go.etcd.io/etcd/client/pkg/v3/srv" etcd_client "go.etcd.io/etcd/client/v3" "go.fuhry.dev/runtime/mtls" @@ -96,6 +97,40 @@ func NewEtcdClientWithDeadline(id mtls.Identity, domain string, deadline time.Ti return client, nil } +func ErrorIsRecoverable(err error) bool { + rpcErr, ok := err.(rpctypes.EtcdError) + if !ok { + // not an EtcdError + return false + } + + var unrecoverableCodes = []error{ + rpctypes.ErrAuthFailed, + rpctypes.ErrRoleNotGranted, + rpctypes.ErrAuthFailed, + rpctypes.ErrAuthNotEnabled, + rpctypes.ErrInvalidAuthToken, + rpctypes.ErrAuthOldRevision, + rpctypes.ErrInvalidAuthMgmt, + rpctypes.ErrClusterIdMismatch, + rpctypes.ErrNoLeader, + rpctypes.ErrNotCapable, + rpctypes.ErrStopped, + rpctypes.ErrTimeoutDueToConnectionLost, + rpctypes.ErrTimeoutDueToLeaderFail, + rpctypes.ErrUnhealthy, + rpctypes.ErrCorrupt, + } + + for _, c := range unrecoverableCodes { + if rpcErr.Code() == c.(rpctypes.EtcdError).Code() { + return false + } + } + + return true +} + func init() { flag.StringVar(&etcdMtlsId, "etcd.mtls.id", "etcd-client", "mTLS identity to use for connecting to etcd") flag.UintVar(&etcdStartupTimeoutMs, "etcd.startup-timeout", etcdStartupTimeoutMs, "max timeout (in ms) for etcd startup attempts before failing") diff --git a/sd/watcher.go b/sd/watcher.go index 5d95d17..cdda3b3 100644 --- a/sd/watcher.go +++ b/sd/watcher.go @@ -101,7 +101,7 @@ func (w *SDWatcher) GetAddrs(ctx context.Context) ([]ServiceAddress, error) { func (w *SDWatcher) watch(ctx context.Context) { kvs := make(map[string][]byte, 0) - w.logger.Infof("Watching for service publications under path %s", w.prefix()) + w.logger.V(1).Infof("Watching for service publications under path %s", w.prefix()) items, err := w.EtcdClient.Get(ctx, w.prefix(), etcd_client.WithPrefix()) if err == nil { -- 2.50.1