--- /dev/null
+package main
+
+import (
+ "context"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "os/signal"
+ "path"
+ "strings"
+ "syscall"
+
+ "github.com/urfave/cli/v3"
+
+ "go.fuhry.dev/runtime/constants"
+ "go.fuhry.dev/runtime/ephs"
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+func cmdStat(ctx context.Context, cmd *cli.Command) error {
+ client, err := ephs.DefaultClient()
+ if err != nil {
+ return err
+ }
+
+ resp, err := client.StatContext(ctx, strings.TrimSuffix(cmd.StringArg("path"), "/"))
+ if err != nil {
+ return err
+ }
+
+ fmt.Println(ephs.FormatFsEntry(resp))
+ return nil
+}
+
+func cmdCat(ctx context.Context, cmd *cli.Command) error {
+ client, err := ephs.DefaultClient()
+ if err != nil {
+ return err
+ }
+
+ reader, err := client.GetContext(ctx, cmd.StringArg("path"))
+ if err != nil {
+ return err
+ }
+
+ io.Copy(os.Stdout, reader)
+
+ return nil
+}
+
+func cmdCopy(ctx context.Context, cmd *cli.Command) error {
+ client, err := ephs.DefaultClient()
+ if err != nil {
+ return err
+ }
+
+ src := cmd.StringArg("src")
+ dst := cmd.StringArg("dst")
+
+ if ephs.IsEphsPath(src) && ephs.IsEphsPath(dst) {
+ return errors.New("copying from one ephs path to another is not presently supported")
+ } else if !ephs.IsEphsPath(src) && !ephs.IsEphsPath(dst) {
+ return errors.New("copying from one non-ephs path to another is not supported")
+ }
+
+ if strings.HasSuffix(dst, "/") {
+ dst += path.Base(src)
+ }
+
+ if ephs.IsEphsPath(src) {
+ // copy from ephs to local
+ fp, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE, os.FileMode(0600))
+ if err != nil {
+ return err
+ }
+ defer fp.Close()
+
+ reader, err := client.GetContext(ctx, src)
+ if err != nil {
+ return err
+ }
+
+ nw, _ := io.Copy(fp, reader)
+
+ log.Default().Infof("wrote %d bytes to %s", nw, dst)
+ return nil
+ } else if ephs.IsEphsPath(dst) {
+ // copy from local to ephs (put request)
+ st, err := os.Stat(src)
+ if err != nil {
+ return err
+ }
+ fp, err := os.OpenFile(src, os.O_RDONLY, os.FileMode(0))
+ if err != nil {
+ return err
+ }
+ defer fp.Close()
+
+ resp, err := client.PutContext(ctx, dst, uint64(st.Size()), fp)
+ if err != nil {
+ return err
+ }
+
+ log.Default().Noticef("successfully uploaded: %s\n%s", dst, ephs.FormatFsEntry(resp))
+ }
+
+ return nil
+}
+
+func cmdDelete(ctx context.Context, cmd *cli.Command) error {
+ client, err := ephs.DefaultClient()
+ if err != nil {
+ return err
+ }
+
+ err = client.DeleteContext(ctx, cmd.StringArg("path"), cmd.IsSet("recursive"))
+ if err != nil {
+ return err
+ }
+
+ log.Default().Infof("Deleted object: %s", cmd.StringArg("path"))
+ return nil
+}
+
+func cmdWatch(ctx context.Context, cmd *cli.Command) error {
+ client, err := ephs.DefaultClient()
+ if err != nil {
+ return err
+ }
+
+ watchCl, err := client.Watch(ctx, cmd.StringArg("path"))
+ if err != nil {
+ return err
+ }
+
+ for msg := range watchCl {
+ log.Default().Infof("Got event:\nEvent type: %s\n%s", msg.Event.String(), ephs.FormatFsEntry(msg.GetEntry()))
+ }
+ return nil
+}
+
+func cmdMkdir(ctx context.Context, cmd *cli.Command) error {
+ client, err := ephs.DefaultClient()
+ if err != nil {
+ return err
+ }
+
+ err = client.MkDirContext(ctx, cmd.StringArg("path"), cmd.Bool("recursive"))
+ if err != nil {
+ return err
+ }
+
+ log.Default().Infof("Created directory: %s", cmd.StringArg("path"))
+ return nil
+}
+
+func main() {
+ ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer cancel()
+
+ flag.Parse()
+
+ cmd := &cli.Command{
+ Name: "ephs",
+ Version: constants.Version,
+ Description: "interact with ephs",
+
+ Commands: []*cli.Command{
+ {
+ Name: "stat",
+ Aliases: []string{"ls", "dir", "list"},
+ Description: "get info about a file or list contents of a directory",
+ Action: cmdStat,
+ Arguments: []cli.Argument{
+ &cli.StringArg{
+ Name: "path",
+ UsageText: "path to read",
+ },
+ },
+ },
+ {
+ Name: "cat",
+ Description: "read a file",
+ Action: cmdCat,
+ Arguments: []cli.Argument{
+ &cli.StringArg{
+ Name: "path",
+ UsageText: "path to read",
+ },
+ },
+ },
+ {
+ Name: "cp",
+ Description: "copy a file",
+ Action: cmdCopy,
+ Arguments: []cli.Argument{
+ &cli.StringArg{
+ Name: "src",
+ UsageText: "source",
+ },
+ &cli.StringArg{
+ Name: "dst",
+ UsageText: "destination",
+ },
+ },
+ },
+ {
+ Name: "rm",
+ Description: "delete a file",
+ Action: cmdDelete,
+ Arguments: []cli.Argument{
+ &cli.StringArg{
+ Name: "path",
+ UsageText: "path to read",
+ },
+ },
+ Flags: []cli.Flag{
+ &cli.BoolFlag{
+ Name: "recursive",
+ Usage: "allow removing directories and all of their contents",
+ },
+ },
+ },
+ {
+ Name: "watch",
+ Description: "watch a path for changes",
+ Action: cmdWatch,
+ Arguments: []cli.Argument{
+ &cli.StringArg{
+ Name: "path",
+ UsageText: "path to read",
+ },
+ },
+ },
+ {
+ Name: "mkdir",
+ Description: "create a directory",
+ Action: cmdMkdir,
+ Arguments: []cli.Argument{
+ &cli.StringArg{
+ Name: "path",
+ UsageText: "directory to create",
+ },
+ },
+ Flags: []cli.Flag{
+ &cli.BoolFlag{
+ Name: "recursive",
+ Usage: "also create any parent directories that don't already exist",
+ },
+ },
+ },
+ },
+
+ Flags: []cli.Flag{
+ &cli.IntFlag{
+ Name: "vv",
+ Usage: "verbosity level",
+ },
+ &cli.StringFlag{
+ Name: "v",
+ Usage: "log level",
+ },
+ &cli.StringFlag{
+ Name: "grpc.transport",
+ Usage: "grpc transport (tcp or quic)",
+ },
+ },
+ }
+
+ if err := cmd.Run(ctx, os.Args); err != nil {
+ log.Default().Panic(err)
+ }
+}
--- /dev/null
+GOSRC = $(wildcard *.go)
+GOEXE = $(shell basename `pwd`)
+GOBUILDFLAGS := -buildmode=pie -trimpath
+
+all: $(GOEXE)
+
+clean:
+ rm -fv $(GOEXE)
+
+.PHONY: all clean
+
+$(GOEXE): %: $(GOSRC)
+ go build $(GOBUILDFLAGS) -o $@ $<
+
--- /dev/null
+DEFAULT:
+ - service: '*'
+ - user: '*'
--- /dev/null
+package main
+
+import (
+ "context"
+ "flag"
+ "os/signal"
+ "syscall"
+
+ "go.fuhry.dev/runtime/ephs/servicer"
+ "go.fuhry.dev/runtime/grpc"
+ "go.fuhry.dev/runtime/mtls"
+ ephs_pb "go.fuhry.dev/runtime/proto/service/ephs"
+ "go.fuhry.dev/runtime/utils/log"
+
+ google_grpc "google.golang.org/grpc"
+)
+
+func main() {
+ var err error
+
+ acl := flag.String("ephs.acl", "", "YAML file containing ACLs for ephs")
+ awsFile := flag.String("ephs.s3-creds-file", "", "file to load AWS S3 credentials from")
+ awsEnv := flag.Bool("ephs.s3-creds-env", false, "set to true to load AWS credentials from environment")
+
+ flag.Parse()
+
+ serverIdentity := mtls.DefaultIdentity()
+ s, err := grpc.NewGrpcServer(serverIdentity)
+ if err != nil {
+ panic(err)
+ }
+
+ var opts []servicer.Option
+ if *acl != "" {
+ opts = append(opts, servicer.WithAclFile(*acl))
+ }
+
+ if awsEnv != nil && *awsEnv {
+ opts = append(opts, servicer.WithAWSEnvCredentials())
+ } else if awsFile != nil && *awsFile != "" {
+ opts = append(opts, servicer.WithAWSCredentialFile(*awsFile))
+ }
+
+ serv, err := servicer.NewEphsServicer(opts...)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+
+ err = s.PublishAndServe(ctx, func(s *google_grpc.Server) {
+ ephs_pb.RegisterEphsServer(s, serv)
+ })
+ if err != nil {
+ panic(err)
+ }
+ defer s.Stop()
+
+ <-ctx.Done()
+}
+
+func init() {
+ mtls.SetDefaultIdentity("ephs")
+}
--- /dev/null
+rules:
+- principal:
+ prefix: spiffe://roc.xx0r.info/service/
+ key:
+ prefix: services/{{principal}}/
+- principal:
+ or:
+ - exact: spiffe://roc.xx0r.info/user/dan
+ - exact: spiffe://roc.xx0r.info/service/conffs
+ key:
+ any: true
import (
"context"
"fmt"
+ "net"
"go.fuhry.dev/runtime/mtls"
"go.fuhry.dev/runtime/sd"
"google.golang.org/grpc"
)
-type Client struct {
+type ClientConn = grpc.ClientConn
+
+type Client interface {
+ Conn() (*grpc.ClientConn, error)
+}
+
+type ClientOption interface {
+ apply(*client) error
+}
+
+type clientOption struct {
+ f func(*client) error
+}
+
+type AddressProvider interface {
+ GetAddrs(context.Context) ([]sd.ServiceAddress, error)
+}
+
+func (o *clientOption) apply(c *client) error {
+ return o.f(c)
+}
+
+type client struct {
ctx context.Context
serverId mtls.Identity
clientId mtls.Identity
- watcher *sd.SDWatcher
+ watcher AddressProvider
connFac ConnectionFactory
}
-func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity) (*Client, error) {
- etcdc, err := sd.NewDefaultEtcdClient()
- if err != nil {
- panic(err)
+type staticAddressProvider struct {
+ addresses []sd.ServiceAddress
+}
+
+func (s *staticAddressProvider) GetAddrs(_ context.Context) ([]sd.ServiceAddress, error) {
+ return s.addresses, nil
+}
+
+func WithConnectionFactory(fac ConnectionFactory) ClientOption {
+ return &clientOption{
+ f: func(c *client) error {
+ c.connFac = fac
+ return nil
+ },
+ }
+}
+
+func WithAddressProvider(ap AddressProvider) ClientOption {
+ return &clientOption{
+ f: func(c *client) error {
+ c.watcher = ap
+ return nil
+ },
+ }
+}
+
+func WithStaticAddress(addresses ...*net.TCPAddr) ClientOption {
+ var addrs []sd.ServiceAddress
+
+ for _, addr := range addresses {
+ var ip4, ip6 string
+ if len(addr.IP) == 4 {
+ ip4 = addr.IP.String()
+ } else {
+ ip6 = addr.IP.String()
+ }
+ addrs = append(addrs, sd.ServiceAddress{
+ IP4: ip4,
+ IP6: ip6,
+ Port: uint16(addr.Port),
+ })
}
- w := &sd.SDWatcher{
- Service: serverId.Name(),
- EtcdClient: etcdc,
- Protocol: sd.ProtocolGRPC,
+ return &clientOption{
+ f: func(c *client) error {
+ c.watcher = &staticAddressProvider{
+ addresses: addrs,
+ }
+ return nil
+ },
}
+}
- cl := &Client{
+func NewGrpcClient(ctx context.Context, serverId, clientId mtls.Identity, opts ...ClientOption) (Client, error) {
+ cl := &client{
ctx: ctx,
serverId: serverId,
clientId: clientId,
- watcher: w,
connFac: NewDefaultConnectionFactory(),
}
+ for _, opt := range opts {
+ if err := opt.apply(cl); err != nil {
+ return nil, err
+ }
+ }
+
+ if cl.watcher == nil {
+ etcdc, err := sd.NewDefaultEtcdClient()
+ if err != nil {
+ return nil, err
+ }
+
+ cl.watcher = &sd.SDWatcher{
+ Service: serverId.Name(),
+ EtcdClient: etcdc,
+ Protocol: sd.ProtocolGRPC,
+ }
+ }
+
return cl, nil
}
-func (c *Client) Conn() (*grpc.ClientConn, error) {
+func (c *client) Conn() (*grpc.ClientConn, error) {
addrs, err := c.watcher.GetAddrs(c.ctx)
if err != nil {
return nil, err
package grpc
import (
+ "context"
"crypto/tls"
"fmt"
"net"
+ "time"
"google.golang.org/grpc/credentials"
grpc_quic "go.fuhry.dev/grpc-quic"
)
-type QUICConnectionFactory struct{}
+var defaultQuicConfig = &quic.Config{
+ HandshakeIdleTimeout: 15 * time.Second,
+ MaxIdleTimeout: 5 * time.Minute,
+}
+
+type QUICConnectionFactory struct {
+ QUICConfig *quic.Config
+}
func (cf *QUICConnectionFactory) NewCredentials(tlsConfig *tls.Config) credentials.TransportCredentials {
tlsConfig.NextProtos = []string{"grpc-quic"}
return nil, err
}
- quicListener, err := quic.Listen(udpListener, tlsConfig, nil)
+ if cf.QUICConfig == nil {
+ cf.QUICConfig = defaultQuicConfig.Clone()
+ }
+
+ quicListener, err := quic.Listen(udpListener, tlsConfig, cf.QUICConfig.Clone())
if err != nil {
return nil, err
}
}
func (cf *QUICConnectionFactory) NewDialer(tlsConfig *tls.Config) ContextDialer {
- return grpc_quic.NewQuicDialer(tlsConfig)
+ if cf.QUICConfig == nil {
+ cf.QUICConfig = defaultQuicConfig.Clone()
+ }
+
+ return func(ctx context.Context, target string) (net.Conn, error) {
+ sess, err := quic.DialAddr(ctx, target, tlsConfig, cf.QUICConfig.Clone())
+ if err != nil {
+ return nil, err
+ }
+
+ return grpc_quic.NewConn(sess)
+ }
}
func init() {
--- /dev/null
+package main
+
+import (
+ "context"
+ "flag"
+ "net"
+ "os"
+ "os/signal"
+ "strconv"
+ "syscall"
+ "time"
+
+ "go.fuhry.dev/runtime/grpc"
+ "go.fuhry.dev/runtime/mtls"
+ "go.fuhry.dev/runtime/utils/log"
+
+ "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+func main() {
+ serverId := flag.String("server-id", "", "mtls ID of gRPC server")
+ serverAddr := flag.String("server-addr", "", "server address as ip:port only")
+ watch := flag.Bool("watch", false, "continually monitor server for health status changes")
+ wait := flag.Bool("wait", false, "continually retry connecting, then stream server status until RUNNING")
+ var waitTimeout = 60 * time.Second
+ flag.Func("wait.timeout", "maximum time to wait for a healthy server (default: 60s)", func(val string) (err error) {
+ waitTimeout, err = time.ParseDuration(val)
+ return
+ })
+
+ mtls.SetDefaultIdentity("anonymous")
+
+ flag.Parse()
+
+ var opts []grpc.ClientOption
+
+ if *serverAddr != "" {
+ host, port, err := net.SplitHostPort(*serverAddr)
+ if err != nil {
+ log.Default().Panic(err)
+ }
+
+ portInt, err := strconv.Atoi(port)
+ if err != nil {
+ portInt, err = net.LookupPort("tcp", port)
+ if err != nil {
+ log.Default().Panic(err)
+ }
+ }
+
+ if ip := net.ParseIP(host); ip == nil {
+ log.Default().Panicf("%q: not a valid IPv4 or IPv6 address", host)
+ } else {
+ opts = append(opts, grpc.WithStaticAddress(&net.TCPAddr{ip, portInt, ""}))
+ }
+ }
+
+ ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer cancel()
+ var deadline time.Time
+
+ if *wait {
+ ctx, cancel = context.WithTimeout(ctx, waitTimeout)
+ defer cancel()
+ deadline, _ = ctx.Deadline()
+ }
+
+ if *watch && *wait {
+ log.Default().Fatal("watch and wait options are mutually exclusive")
+ os.Exit(1)
+ }
+
+ var conn *grpc.ClientConn
+ for {
+ client, err := grpc.NewGrpcClient(ctx, mtls.NewServiceIdentity(*serverId), mtls.DefaultIdentity(), opts...)
+ if err != nil {
+ if *wait && time.Now().Before(deadline) {
+ log.Default().Warningf("error connecting (%v) retrying in 1s", err)
+ time.Sleep(1 * time.Second)
+ continue
+ }
+ log.Default().Critical(err)
+ os.Exit(1)
+ }
+
+ conn, err = client.Conn()
+ if err != nil {
+ if *wait && time.Now().Before(deadline) {
+ log.Default().Warningf("error connecting (%v) retrying in 1s", err)
+ time.Sleep(1 * time.Second)
+ continue
+ }
+ log.Default().Critical(err)
+ os.Exit(1)
+ }
+
+ defer conn.Close()
+ break
+ }
+
+ for {
+ healthClient := grpc_health_v1.NewHealthClient(conn)
+
+ if *watch {
+ stream, err := healthClient.Watch(ctx, &grpc_health_v1.HealthCheckRequest{})
+ if err != nil {
+ log.Default().Alertf("healthcheck failed with error %T: %v", err, err)
+ os.Exit(2)
+ }
+
+ for msg, err := stream.Recv(); err == nil; msg, err = stream.Recv() {
+ log.Default().Infof("server status: %s", msg.Status.String())
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ os.Exit(0)
+ } else if *wait {
+ stream, err := healthClient.Watch(ctx, &grpc_health_v1.HealthCheckRequest{})
+ if err != nil {
+ if time.Now().Before(deadline) {
+ log.Default().Warningf("error connecting (%v) retrying in 1s", err)
+ time.Sleep(1 * time.Second)
+ continue
+ }
+ log.Default().Critical(err)
+ os.Exit(1)
+ }
+
+ for msg, err := stream.Recv(); err == nil; msg, err = stream.Recv() {
+ log.Default().Infof("server status: %s", msg.Status.String())
+ if msg.Status == grpc_health_v1.HealthCheckResponse_SERVING {
+ return
+ }
+ }
+
+ os.Exit(1)
+ } else {
+ resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
+ if err != nil {
+ log.Default().Alertf("healthcheck failed with error %T: %v", err, err)
+ os.Exit(2)
+ }
+
+ log.Default().Infof("server status: %s", resp.Status.String())
+ if resp.Status == grpc_health_v1.HealthCheckResponse_SERVING {
+ os.Exit(0)
+ }
+
+ os.Exit(1)
+ }
+ }
+}
--- /dev/null
+package grpc
+
+import (
+ "context"
+ "sync"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/health/grpc_health_v1"
+
+ "go.fuhry.dev/runtime/utils/log"
+)
+
+type HealthCheckServicer interface {
+ grpc_health_v1.HealthServer
+
+ SetStatus(grpc_health_v1.HealthCheckResponse_ServingStatus)
+}
+
+type healthCheckServicer struct {
+ s grpc_health_v1.HealthCheckResponse_ServingStatus
+ logger log.Logger
+
+ watchers map[chan grpc_health_v1.HealthCheckResponse_ServingStatus]struct{}
+ watchersMu sync.Mutex
+}
+
+var _ grpc_health_v1.HealthServer = &healthCheckServicer{}
+
+func (h *healthCheckServicer) Check(ctx context.Context, req *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) {
+ return &grpc_health_v1.HealthCheckResponse{
+ Status: h.s,
+ }, nil
+}
+
+func (h *healthCheckServicer) SetStatus(status grpc_health_v1.HealthCheckResponse_ServingStatus) {
+ h.s = status
+
+ h.watchersMu.Lock()
+ defer h.watchersMu.Unlock()
+
+ for wa := range h.watchers {
+ wa <- status
+ }
+}
+
+func (h *healthCheckServicer) Watch(req *grpc_health_v1.HealthCheckRequest, stream grpc.ServerStreamingServer[grpc_health_v1.HealthCheckResponse]) error {
+ h.watchersMu.Lock()
+ ch := make(chan grpc_health_v1.HealthCheckResponse_ServingStatus)
+ h.watchers[ch] = struct{}{}
+ h.watchersMu.Unlock()
+
+ defer (func() {
+ h.watchersMu.Lock()
+ delete(h.watchers, ch)
+ close(ch)
+ h.watchersMu.Unlock()
+ })()
+
+ msg := &grpc_health_v1.HealthCheckResponse{
+ Status: h.s,
+ }
+ if err := stream.SendMsg(msg); err != nil {
+ return err
+ }
+
+ for {
+ select {
+ case st := <-ch:
+ h.logger.Infof("got status update: %s", st)
+ msg.Status = st
+ if err := stream.SendMsg(msg); err != nil {
+ return err
+ }
+ case <-stream.Context().Done():
+ return stream.Context().Err()
+ }
+ }
+}
+
+func NewHealthCheckServicer() HealthCheckServicer {
+ return &healthCheckServicer{
+ logger: log.WithPrefix("grpc_health_v1"),
+ watchers: make(map[chan grpc_health_v1.HealthCheckResponse_ServingStatus]struct{}),
+ }
+}
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
log log.Logger
sessions *lru.Cache[string, *session]
connFac ConnectionFactory
+ hc HealthCheckServicer
}
var defaultPort *uint
log: log.WithPrefix(fmt.Sprintf("grpcServer:%s", id.Name())),
sessions: sessionsLru,
connFac: NewDefaultConnectionFactory(),
+ hc: NewHealthCheckServicer(),
}
return server, nil
tc.MinVersion = tls.VersionTLS13
tc.MaxVersion = tls.VersionTLS13
+ // Verify the client certificate if presented, but still allow the handshake to complete if no certificate
+ // is presented at all. This allows for anonymous gRPC health checks; see handleConnection and
+ // handleStreamConnection below.
+ tc.ClientAuth = tls.VerifyClientCertIfGiven
+
err = s.verifier.ConfigureServer(tc)
if err != nil {
return err
grpcServer := grpc.NewServer(opts...)
+ grpc_health_v1.RegisterHealthServer(grpcServer, s.hc)
callback(grpcServer)
go grpcServer.Serve(listener)
+ s.hc.SetStatus(grpc_health_v1.HealthCheckResponse_SERVING)
s.grpcServer = grpcServer
err = s.publisher.Publish(ctx)
}
func (s *Server) Stop() {
+ s.hc.SetStatus(grpc_health_v1.HealthCheckResponse_NOT_SERVING)
s.publisher.Unpublish()
if s.grpcServer != nil {
s.grpcServer.GracefulStop()
}
func (s *Server) handleConnection(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ if info.FullMethod == "/grpc.health.v1.Health/Check" {
+ return handler(ctx, req)
+ }
+
peer, ok := peer.FromContext(ctx)
if !ok {
return nil, status.Errorf(codes.PermissionDenied, "client did not authenticate")
}
func (s *Server) handleStreamConnection(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ if info.FullMethod == "/grpc.health.v1.Health/Watch" {
+ return handler(srv, ss)
+ }
+
ctx := ss.Context()
peer, ok := peer.FromContext(ctx)
if !ok {
return nil
}
-func unsubscribe(filePath string) error {
+func Unsubscribe(filePath string) error {
err := unsubscribeInternal(filePath)
unsubscribeInternal(path.Dir(filePath))
return err
return
}
}
- logger.V(2).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op)
+ logger.V(4).Warningf("dangling watcher on path %s: not known within normal rules or symlink map: %v", event.Name, event.Op)
}
// handleEventForPath is the second stage of event handling which calls the actual handlers.
return
}
+ op := event.Op
+
logger.V(2).Debugf("inotify event on %s: %v", event.Name, event.Op)
logger.V(2).Debugf("ruleName: %s", ruleName)
logger.V(2).Debugf("upstreams: %+v", rule.upstreams.AsSlice())
handlerPath = ruleName
for _, upstream := range rule.upstreams.AsSortedSlice() {
- unsubscribe(upstream)
+ Unsubscribe(upstream)
}
untrackSymlinks(ruleName)
- unsubscribe(ruleName)
+ Unsubscribe(ruleName)
if len(watchHandlers) == 1 {
logger.V(2).Noticef("all watchers should now be cleaned up:")
Debug()
if event.Has(Close) {
if pendingWrites.Contains(event.Name) {
pendingWrites.Del(event.Name)
+ op |= Write
} else if !fileSwapped {
return
}
}
for _, handler := range rule.callbacks {
- handler(handlerPath, event.Op)
+ handler(handlerPath, op)
}
}
for link, targets := range symlinkPropagationMap {
if targets.Contains(ruleName) {
logger.V(2).Debugf("cleaning up watcher on %s for fsnotify watch rule %s", link, ruleName)
- unsubscribe(link)
+ Unsubscribe(link)
targets.Del(ruleName)
}
if targets.Len() == 0 {
return strings.Join(out, string(filepath.Separator))
}
+var RealPath = realpath
+
// resolveSymlinkRecurse recursively resolves filePath until a real file is found.
//
// Both relative and absolute symlinks are handled and normalized.
const (
InvalidPrincipal PrincipalClass = iota
+ AnonymousPrincipal
ServicePrincipal
UserPrincipal
SSLCertificatePrincipal
panic("cannot get default identity before flags are parsed")
}
+ if defaultMtlsIdentity == "anonymous" {
+ return Anonymous()
+ }
+
if defaultMtlsIdentity == "" {
userId, err := NewDefaultUserIdentity()
if err == nil && userId.IsValid() {
leafCert, _ := userId.LeafCertificate()
- log.Default().Infof("found valid user certificate, using identity: %s", leafCert.Subject)
+ log.Default().V(1).Infof("found valid user certificate, using identity: %s", leafCert.Subject)
return userId
} else {
log.Default().V(2).Debugf("couldn't load a user identity: err: %+v", err)
return NewServiceIdentity(defaultMtlsIdentity)
}
+// Anonymous returns an identity that supplies no client certificate
+func Anonymous() Identity {
+ return &anonymousIdentity{}
+}
+
func IdentityFromTLSConnectionState(state *tls.ConnectionState) (Identity, error) {
if state == nil {
return nil, fmt.Errorf("connectionState is nil")
if err != nil {
return err
}
- tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo)
+ // If the tls config isn't configured to request a client certificate at all (the default), reconfigure it
+ // to require one, which is usually what you want.
+ if tlsConfig.ClientAuth == tls.NoClientCert {
+ tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+ }
+ tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tlsConfig.ClientAuth)
tlsConfig.InsecureSkipVerify = true
- tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo)
+ tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo, tlsConfig.ClientAuth)
tlsConfig.ClientCAs = vo.Roots
- tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
+ tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
return nil
}
if err != nil {
return err
}
+ clientAuthType := tls.RequireAndVerifyClientCert
tlsConfig.InsecureSkipVerify = true
- tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo)
- tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo)
+ tlsConfig.VerifyPeerCertificate = NewVerifyMTLSPeerCertificateFuncWithOpts(vo, clientAuthType)
+ tlsConfig.VerifyConnection = cv.verifyConnectionFunc(vo, clientAuthType)
return nil
}
-func (cv *mtlsPeerVerifier) verifyConnectionFunc(verifyOpts x509.VerifyOptions) func(tls.ConnectionState) error {
+func (cv *mtlsPeerVerifier) verifyConnectionFunc(verifyOpts x509.VerifyOptions, clientAuth tls.ClientAuthType) func(tls.ConnectionState) error {
return func(cs tls.ConnectionState) error {
if len(cs.PeerCertificates) < 1 {
- return fmt.Errorf("no peer certificate provided")
+ if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven {
+ return nil
+ }
+ return ErrNoCertificatePresented
}
peerCert := cs.PeerCertificates[0]
package mtls
import (
+ "crypto/tls"
"crypto/x509"
"errors"
"sync"
for _, chain := range chains {
lastCert := chain[0]
- logger.Debugf("checking constraints on leaf certificate: %+v", lastCert.Subject.String())
+ logger.V(2).Debugf("checking constraints on leaf certificate: %+v", lastCert.Subject.String())
if err := checkLeafCertificateConstraints(lastCert); err != nil {
return err
}
return nil, err
}
- return NewVerifyMTLSPeerCertificateFuncWithOpts(vo), nil
+ return NewVerifyMTLSPeerCertificateFuncWithOpts(vo, tls.RequireAndVerifyClientCert), nil
}
-func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions) tlsVerifyPeerCertificatesFunc {
+func NewVerifyMTLSPeerCertificateFuncWithOpts(vo x509.VerifyOptions, clientAuth tls.ClientAuthType) tlsVerifyPeerCertificatesFunc {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
var leafCert *x509.Certificate
}
if leafCert == nil {
+ if clientAuth == tls.NoClientCert || clientAuth == tls.RequestClientCert || clientAuth == tls.VerifyClientCertIfGiven {
+ return nil
+ }
return ErrNoCertificatePresented
}
"sync"
"time"
+ "go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
"go.etcd.io/etcd/client/pkg/v3/srv"
etcd_client "go.etcd.io/etcd/client/v3"
"go.fuhry.dev/runtime/mtls"
return client, nil
}
+func ErrorIsRecoverable(err error) bool {
+ rpcErr, ok := err.(rpctypes.EtcdError)
+ if !ok {
+ // not an EtcdError
+ return false
+ }
+
+ var unrecoverableCodes = []error{
+ rpctypes.ErrAuthFailed,
+ rpctypes.ErrRoleNotGranted,
+ rpctypes.ErrAuthFailed,
+ rpctypes.ErrAuthNotEnabled,
+ rpctypes.ErrInvalidAuthToken,
+ rpctypes.ErrAuthOldRevision,
+ rpctypes.ErrInvalidAuthMgmt,
+ rpctypes.ErrClusterIdMismatch,
+ rpctypes.ErrNoLeader,
+ rpctypes.ErrNotCapable,
+ rpctypes.ErrStopped,
+ rpctypes.ErrTimeoutDueToConnectionLost,
+ rpctypes.ErrTimeoutDueToLeaderFail,
+ rpctypes.ErrUnhealthy,
+ rpctypes.ErrCorrupt,
+ }
+
+ for _, c := range unrecoverableCodes {
+ if rpcErr.Code() == c.(rpctypes.EtcdError).Code() {
+ return false
+ }
+ }
+
+ return true
+}
+
func init() {
flag.StringVar(&etcdMtlsId, "etcd.mtls.id", "etcd-client", "mTLS identity to use for connecting to etcd")
flag.UintVar(&etcdStartupTimeoutMs, "etcd.startup-timeout", etcdStartupTimeoutMs, "max timeout (in ms) for etcd startup attempts before failing")
func (w *SDWatcher) watch(ctx context.Context) {
kvs := make(map[string][]byte, 0)
- w.logger.Infof("Watching for service publications under path %s", w.prefix())
+ w.logger.V(1).Infof("Watching for service publications under path %s", w.prefix())
items, err := w.EtcdClient.Get(ctx, w.prefix(), etcd_client.WithPrefix())
if err == nil {