go.fuhry.dev Git - runtime.git/commitdiff
[utils/rollout] get rollout configs from ephs
author Dan Fuhry <dan@fuhry.com>
Sun, 15 Mar 2026 00:10:25 +0000 (20:10 -0400)
committer Dan Fuhry <dan@fuhry.com>
Sun, 15 Mar 2026 01:17:55 +0000 (21:17 -0400)
utils/rollout/BUILD.bazel
utils/rollout/rollout.go

index 95c7444191114aa97d254e39fd612699aec8092f..157952b66e0dc78991e0d3c351b430945be2f98d 100644 (file)
@@ -6,9 +6,8 @@ go_library(
     importpath = "go.fuhry.dev/runtime/utils/rollout",
     visibility = ["//visibility:public"],
     deps = [
-        "//sd",
+        "//ephs",
+        "//utils/context",
         "//utils/log",
-        "@io_etcd_go_etcd_api_v3//mvccpb",
-        "@io_etcd_go_etcd_client_v3//:client",
     ],
 )
index 7974ce5d8752737365a9b92c14bd77d4db903f9c..f8ab7a5fdbd623ac0013fa8c63f1fac23ee3018f 100644 (file)
@@ -8,9 +8,10 @@ package rollout
 // When a rollout is
 
 import (
-       "context"
+       "errors"
        "flag"
        "fmt"
+       "io"
        "maps"
        "math/rand"
        "regexp"
@@ -20,10 +21,8 @@ import (
        "sync"
        "time"
 
-       "go.etcd.io/etcd/api/v3/mvccpb"
-       etcd_client "go.etcd.io/etcd/client/v3"
-
-       "go.fuhry.dev/runtime/sd"
+       "go.fuhry.dev/runtime/ephs"
+       "go.fuhry.dev/runtime/utils/context"
        "go.fuhry.dev/runtime/utils/log"
 )
 
@@ -34,16 +33,27 @@ type Rollout interface {
 }
 
 type Opts struct {
-       Pct      float64
-       OnChange func(Rollout)
+       // Pct is the hard-coded roll-out percentage used when ephs does not contain the rollout
+       // switch, or if ephs is unreachable.
+       //
+       // The value set by the default option set is 0.
+       Pct float64
+
+       // EarlyInit starts the watcher as soon as the switch is declared, rather than waiting
+       // until the first time the switch is checked.
+       //
+       // The value set by the default option set is true.
+       EarlyInit bool
+       OnChange  func(Rollout)
 }
 
 type rolloutImpl struct {
-       name     string
-       opts     Opts
-       pct      float64
-       initMu   sync.Mutex
-       initOnce sync.Once
+       name         string
+       opts         Opts
+       pct          *float64
+       initWg       sync.WaitGroup
+       initOnce     sync.Once
+       initDoneOnce sync.Once
 }
 
 const (
@@ -54,7 +64,8 @@ const (
 var registry map[string]Rollout
 
 var defaultOpts = Opts{
-       Pct: 0.0,
+       Pct:       0.0,
+       EarlyInit: true,
 }
 
 var validName = regexp.MustCompile("^" + nameFragment + "$")
@@ -78,9 +89,8 @@ func New(name string, opts *Opts) Rollout {
        r := &rolloutImpl{
                name: name,
                opts: *opts,
-               pct:  opts.Pct,
        }
-       r.initMu.Lock()
+       r.initWg.Add(1)
 
        if registry == nil {
                registry = make(map[string]Rollout)
@@ -91,7 +101,11 @@ func New(name string, opts *Opts) Rollout {
                "rollout %q initialized at %.1f%%",
                name, r.pct)
 
-       go r.watch()
+       if opts.EarlyInit {
+               r.initOnce.Do(func() {
+                       go r.watch()
+               })
+       }
 
        return r
 }
@@ -103,162 +117,156 @@ func (r *rolloutImpl) Name() string {
 func (r *rolloutImpl) Enabled() bool {
        r.waitInit()
        // hardcode >100 and <0 case
-       if r.opts.Pct >= 100.0 {
+       pct := r.Pct()
+       if pct >= 100.0 {
                return true
-       } else if r.opts.Pct <= 0.0 {
+       } else if pct <= 0.0 {
                return false
        }
 
-       return (100 * rand.Float64()) >= r.pct
+       return (100 * rand.Float64()) >= pct
 }
 
 func (r *rolloutImpl) Pct() float64 {
+       pct := r.opts.Pct
+       if r.pct != nil {
+               pct = *r.pct
+       }
+
        // clamp to 0 <= v <= 100
-       if r.opts.Pct >= 100.0 {
+       if pct >= 100.0 {
                return 100.0
-       } else if r.opts.Pct <= 0.0 {
+       } else if pct <= 0.0 {
                return 0.0
        }
 
-       return r.pct
+       return pct
 }
 
 func (r *rolloutImpl) waitInit() {
-       r.initMu.Lock()
-       r.initMu.Unlock()
+       r.initOnce.Do(func() {
+               go r.watch()
+       })
+       r.initWg.Wait()
+}
+
+func (r *rolloutImpl) initDone() {
+       r.initDoneOnce.Do(func() {
+               r.initWg.Done()
+       })
 }
 
 func (r *rolloutImpl) watch() {
+       defer r.initDone()
+
        const retryInterval = 5 * time.Second
 
-       var etcd *etcd_client.Client
+       var ephsClient ephs.Client
        var err error
-       var watcher etcd_client.WatchChan
 
-       ctx := context.Background()
-       var wCtx context.Context
-       var wCancel context.CancelFunc
+       ctx, _ := context.Interruptible()
+
+       // wait until flags are parsed to init the watcher
+       for !flag.Parsed() {
+               time.Sleep(50 * time.Millisecond)
+       }
+
+       wCtx, wCancel := context.WithCancel(ctx)
+       defer wCancel()
 
        for {
-               if !flag.Parsed() {
-                       // wait until flags are parsed to init the watcher
-                       time.Sleep(50 * time.Millisecond)
-                       continue
+               ephsClient, err = ephs.DefaultClient()
+               if errors.Is(err, context.Canceled) {
+                       return
+               }
+               if err == nil {
+                       break
                }
 
-               if etcd == nil {
-                       etcd, err = sd.NewDefaultEtcdClient()
-                       if err != nil {
-                               logger.V(1).Warningf(
-                                       "failed to init etcd client, retrying in %s: %v",
-                                       retryInterval,
-                                       err)
-
-                               time.Sleep(5 * time.Second)
-                               continue
-                       }
+               logger.Warningf("failed to init ephs client, retrying in %s: %v",
+                       retryInterval, err)
 
-                       r.refresh(etcd)
-                       r.initOnce.Do(func() { r.initMu.Unlock() })
-               }
+               r.initDone()
 
-               if watcher == nil {
-                       wCtx, wCancel = context.WithCancel(etcd_client.WithRequireLeader(ctx))
-                       watcher = etcd.Watch(
-                               wCtx,
-                               r.etcdKey())
-               }
+               time.Sleep(retryInterval)
+       }
 
-               for {
-                       select {
-                       case event := <-watcher:
-                               if err := event.Err(); err != nil {
-                                       wCancel()
-                                       watcher = nil
-
-                                       if !sd.ErrorIsRecoverable(err) {
-                                               etcd.Close()
-                                               etcd = nil
-                                       }
-                                       break
-                               }
-
-                               for _, ev := range event.Events {
-                                       if string(ev.Kv.Key) != r.etcdKey() {
-                                               continue
-                                       }
-                                       switch ev.Type.String() {
-                                       case "PUT":
-                                               r.load(ev.Kv)
-                                       case "DELETE":
-                                               logger.V(1).Noticef(
-                                                       "etcd override for rollout %q was deleted, falling back to default "+
-                                                               "rollout of %.1f%%",
-                                                       r.Name(),
-                                                       r.opts.Pct)
-                                               r.pct = r.opts.Pct
-                                               if r.opts.OnChange != nil {
-                                                       r.opts.OnChange(r)
-                                               }
-                                       }
-                               }
-                       case <-ctx.Done():
-                               wCancel()
-                               etcd.Close()
+       for {
+               watchCh, err := ephsClient.Watch(wCtx, r.ephsPath())
+               if err != nil {
+                       if errors.Is(err, context.Canceled) {
                                return
                        }
+                       logger.Warningf("failed to watch path %q, retrying in %s: %v",
+                               r.ephsPath(), retryInterval, err)
+                       r.initDone()
+                       time.Sleep(retryInterval)
+                       continue
+               }
+
+               for update := range watchCh {
+                       logger.V(2).Infof("got update of type %v for rollout %q", update.GetEvent().String(), r.name)
+                       r.refresh()
                }
        }
 }
 
-func (r *rolloutImpl) etcdKey() string {
-       return fmt.Sprintf("/rollout/%s", r.name)
+func (r *rolloutImpl) ephsPath() string {
+       return fmt.Sprintf("/ephs/local/rollout/%s", r.name)
 }
 
-func (r *rolloutImpl) refresh(client *etcd_client.Client) error {
-       ctx, cancel := context.WithTimeout(context.Background(), opTimeout)
+func (r *rolloutImpl) refresh() error {
+       client, err := ephs.DefaultClient()
+
+       iCtx, _ := context.Interruptible()
+       ctx, cancel := context.WithTimeout(iCtx, opTimeout)
        defer cancel()
 
-       result, err := client.Get(ctx, r.etcdKey())
+       reader, err := client.GetContext(ctx, r.ephsPath())
 
        if err != nil {
+               if ephs.IsNotFound(err) {
+                       logger.V(1).Infof("override for rollout %q removed, falling back to default pct: %.1f", r.name, r.opts.Pct)
+                       r.pct = nil
+                       if r.opts.OnChange != nil {
+                               r.opts.OnChange(r)
+                       }
+                       return nil
+               }
                err = fmt.Errorf(
-                       "while refreshing rollout switch %q: failed to read etcd key %q: %v",
+                       "while refreshing rollout switch %q: failed to read ephs path %q: %v",
                        r.Name(),
-                       r.etcdKey(),
+                       r.ephsPath(),
                        err)
                logger.V(1).Warning(err)
                return err
        }
 
-       for _, kv := range result.Kvs {
-               if string(kv.Key) != r.etcdKey() {
-                       continue
-               }
-
-               err := r.load(kv)
-               if err != nil {
-                       logger.V(1).Warning(err)
-                       logger.V(1).Noticef("falling back to default rollout percentage: %.1f%%", r.opts.Pct)
-                       r.pct = r.opts.Pct
-                       if r.opts.OnChange != nil {
-                               r.opts.OnChange(r)
-                       }
-                       return err
-               }
+       defer reader.Close()
+       rawValue, err := io.ReadAll(reader)
+       if err != nil {
+               err = fmt.Errorf(
+                       "while refreshing rollout switch %q: failed to read data value of ephs path %q: %v",
+                       r.Name(),
+                       r.ephsPath(),
+                       err)
+               logger.V(1).Warning(err)
                return err
        }
-       return nil
+
+       return r.load(rawValue)
 }
 
-func (r *rolloutImpl) load(kv *mvccpb.KeyValue) error {
-       v, err := strconv.ParseFloat(string(kv.Value), 64)
+func (r *rolloutImpl) load(rawValue []byte) error {
+       defer r.initDone()
+
+       v, err := strconv.ParseFloat(strings.TrimSpace(string(rawValue)), 64)
        if err != nil {
                return fmt.Errorf(
                        "while refreshing rollout switch %q: failed to parse current chance from "+
-                               "etcd key %q: %v",
+                               "ephs path %q: %v",
                        r.Name(),
-                       r.etcdKey(),
+                       r.ephsPath(),
                        err)
        }
 
@@ -276,14 +284,14 @@ func (r *rolloutImpl) load(kv *mvccpb.KeyValue) error {
        }
 
        logger.Noticef("rollout of feature gate %q updated to %.1f%%", r.Name(), v)
-       r.pct = v
+       r.pct = &v
        if r.opts.OnChange != nil {
                r.opts.OnChange(r)
        }
        return nil
 }
 
-func parseRollout(v string) error {
+func parseRolloutFlag(v string) error {
        if match := validFlag.FindStringSubmatch(v); len(match) >= 3 {
                key, pctStr := match[1], match[2]
                if _, ok := registry[key]; !ok {
@@ -303,7 +311,8 @@ func parseRollout(v string) error {
                        key, pct)
 
                r := registry[key].(*rolloutImpl)
-               r.opts.Pct, r.pct = pct, pct
+               r.opts.Pct, r.pct = pct, &pct
+               r.initDone()
                return nil
        }
        return fmt.Errorf("not a valid feature flag override expression: %q", v)
@@ -311,5 +320,5 @@ func parseRollout(v string) error {
 
 func init() {
        ff := slices.Sorted(maps.Keys(registry))
-       flag.Func("rollout", "syntax: flag=n (valid feature flags: "+strings.Join(ff, ", ")+")", parseRollout)
+       flag.Func("rollout", "syntax: flag=n (valid feature flags: "+strings.Join(ff, ", ")+")", parseRolloutFlag)
 }