From 8901da1d363cef7187f9d8f25e35670310f6db45 Mon Sep 17 00:00:00 2001
From: Dan Fuhry
Date: Sat, 14 Mar 2026 20:10:25 -0400
Subject: [PATCH] [utils/rollout] get rollout configs from ephs

---
Review notes (text between "---" and the first "diff --git" is not part of the
commit message):
- Line structure was restored from a whitespace-collapsed copy of this patch.
  Blank context lines are not visible in that copy and were placed so that
  every hunk matches its original line counts; verify against the base file.
- Enabled: the random draw is now compared with "<" so pct is the share of
  calls that are enabled.
- New: the init log passes r.Pct() to %.1f instead of the *float64 field.
- refresh: the ephs.DefaultClient() error is checked, and waiters are released
  on every path, including "not found".

 utils/rollout/BUILD.bazel |   5 +-
 utils/rollout/rollout.go  | 282 ++++++++++++++++++++++----------------
 2 files changed, 163 insertions(+), 124 deletions(-)

diff --git a/utils/rollout/BUILD.bazel b/utils/rollout/BUILD.bazel
index 95c7444..157952b 100644
--- a/utils/rollout/BUILD.bazel
+++ b/utils/rollout/BUILD.bazel
@@ -6,9 +6,8 @@ go_library(
     importpath = "go.fuhry.dev/runtime/utils/rollout",
     visibility = ["//visibility:public"],
     deps = [
-        "//sd",
+        "//ephs",
+        "//utils/context",
         "//utils/log",
-        "@io_etcd_go_etcd_api_v3//mvccpb",
-        "@io_etcd_go_etcd_client_v3//:client",
     ],
 )
diff --git a/utils/rollout/rollout.go b/utils/rollout/rollout.go
index 7974ce5..f8ab7a5 100644
--- a/utils/rollout/rollout.go
+++ b/utils/rollout/rollout.go
@@ -8,9 +8,10 @@ package rollout
 // When a rollout is
 
 import (
-	"context"
+	"errors"
 	"flag"
 	"fmt"
+	"io"
 	"maps"
 	"math/rand"
 	"regexp"
@@ -20,10 +21,8 @@ import (
 	"sync"
 	"time"
 
-	"go.etcd.io/etcd/api/v3/mvccpb"
-	etcd_client "go.etcd.io/etcd/client/v3"
-
-	"go.fuhry.dev/runtime/sd"
+	"go.fuhry.dev/runtime/ephs"
+	"go.fuhry.dev/runtime/utils/context"
 	"go.fuhry.dev/runtime/utils/log"
 )
 
@@ -34,16 +33,27 @@ type Rollout interface {
 }
 
 type Opts struct {
-	Pct      float64
-	OnChange func(Rollout)
+	// Pct is the hard-coded roll-out percentage used when ephs does not contain the rollout
+	// switch, or if ephs is unreachable.
+	//
+	// The value set by the default option set is 0.
+	Pct float64
+
+	// EarlyInit starts the watcher as soon as the switch is declared, rather than waiting
+	// until the first time the switch is checked.
+	//
+	// The value set by the default option set is true.
+	EarlyInit bool
+	OnChange  func(Rollout)
 }
 
 type rolloutImpl struct {
-	name     string
-	opts     Opts
-	pct      float64
-	initMu   sync.Mutex
-	initOnce sync.Once
+	name         string
+	opts         Opts
+	pct          *float64
+	initWg       sync.WaitGroup
+	initOnce     sync.Once
+	initDoneOnce sync.Once
 }
 
 const (
@@ -54,7 +64,8 @@ const (
 var registry map[string]Rollout
 
 var defaultOpts = Opts{
-	Pct: 0.0,
+	Pct:       0.0,
+	EarlyInit: true,
 }
 
 var validName = regexp.MustCompile("^" + nameFragment + "$")
@@ -78,9 +89,8 @@ func New(name string, opts *Opts) Rollout {
 	r := &rolloutImpl{
 		name: name,
 		opts: *opts,
-		pct:  opts.Pct,
 	}
-	r.initMu.Lock()
+	r.initWg.Add(1)
 
 	if registry == nil {
 		registry = make(map[string]Rollout)
@@ -91,7 +101,11 @@ func New(name string, opts *Opts) Rollout {
 		"rollout %q initialized at %.1f%%",
 		name,
-		r.pct)
-	go r.watch()
+		r.Pct())
+	if opts.EarlyInit {
+		r.initOnce.Do(func() {
+			go r.watch()
+		})
+	}
 
 	return r
 }
@@ -103,162 +117,187 @@ func (r *rolloutImpl) Name() string {
 func (r *rolloutImpl) Enabled() bool {
 	r.waitInit()
 	// hardcode >100 and <0 case
-	if r.opts.Pct >= 100.0 {
+	pct := r.Pct()
+	if pct >= 100.0 {
 		return true
-	} else if r.opts.Pct <= 0.0 {
+	} else if pct <= 0.0 {
 		return false
 	}
 
-	return (100 * rand.Float64()) >= r.pct
+	// rand.Float64 is uniform over [0, 1), so the scaled draw falls below pct
+	// with probability pct/100. Comparing with ">=" would enable the switch for
+	// the complementary share of calls, contradicting the 0 and 100 cases above.
+	return (100 * rand.Float64()) < pct
 }
 
+// Pct returns the effective rollout percentage clamped to [0, 100]: the
+// override held in r.pct when one is set, otherwise the configured opts.Pct.
 func (r *rolloutImpl) Pct() float64 {
+	pct := r.opts.Pct
+	if r.pct != nil {
+		pct = *r.pct
+	}
+
 	// clamp to 0 <= v <= 100
-	if r.opts.Pct >= 100.0 {
+	if pct >= 100.0 {
 		return 100.0
-	} else if r.opts.Pct <= 0.0 {
+	} else if pct <= 0.0 {
 		return 0.0
 	}
 
-	return r.pct
+	return pct
 }
 
+// waitInit starts the watcher goroutine if it has not been started yet, then
+// blocks until initDone has been called.
 func (r *rolloutImpl) waitInit() {
-	r.initMu.Lock()
-	r.initMu.Unlock()
+	r.initOnce.Do(func() {
+		go r.watch()
+	})
+	r.initWg.Wait()
+}
+
+// initDone releases callers blocked in waitInit. It may be called any number
+// of times; only the first call has an effect.
+func (r *rolloutImpl) initDone() {
+	r.initDoneOnce.Do(func() {
+		r.initWg.Done()
+	})
 }
 
+// watch waits for flag parsing, obtains the default ephs client and then
+// watches ephsPath, calling refresh for every update. Failed attempts are
+// retried after retryInterval; each failure also calls initDone so callers of
+// waitInit are not held up while ephs is unavailable.
 func (r *rolloutImpl) watch() {
+	defer r.initDone()
+
 	const retryInterval = 5 * time.Second
 
-	var etcd *etcd_client.Client
+	var ephsClient ephs.Client
 	var err error
-	var watcher etcd_client.WatchChan
-	ctx := context.Background()
-	var wCtx context.Context
-	var wCancel context.CancelFunc
+	ctx, _ := context.Interruptible()
+
+	// wait until flags are parsed to init the watcher
+	for !flag.Parsed() {
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	wCtx, wCancel := context.WithCancel(ctx)
+	defer wCancel()
 
 	for {
-		if !flag.Parsed() {
-			// wait until flags are parsed to init the watcher
-			time.Sleep(50 * time.Millisecond)
-			continue
+		ephsClient, err = ephs.DefaultClient()
+		if errors.Is(err, context.Canceled) {
+			return
+		}
+		if err == nil {
+			break
 		}
 
-		if etcd == nil {
-			etcd, err = sd.NewDefaultEtcdClient()
-			if err != nil {
-				logger.V(1).Warningf(
-					"failed to init etcd client, retrying in %s: %v",
-					retryInterval,
-					err)
-
-				time.Sleep(5 * time.Second)
-				continue
-			}
+		logger.Warningf("failed to init ephs client, retrying in %s: %v",
+			retryInterval, err)
 
-			r.refresh(etcd)
-			r.initOnce.Do(func() { r.initMu.Unlock() })
-		}
+		r.initDone()
 
-		if watcher == nil {
-			wCtx, wCancel = context.WithCancel(etcd_client.WithRequireLeader(ctx))
-			watcher = etcd.Watch(
-				wCtx,
-				r.etcdKey())
-		}
+		time.Sleep(retryInterval)
+	}
 
-		for {
-			select {
-			case event := <-watcher:
-				if err := event.Err(); err != nil {
-					wCancel()
-					watcher = nil
-
-					if !sd.ErrorIsRecoverable(err) {
-						etcd.Close()
-						etcd = nil
-					}
-					break
-				}
-
-				for _, ev := range event.Events {
-					if string(ev.Kv.Key) != r.etcdKey() {
-						continue
-					}
-					switch ev.Type.String() {
-					case "PUT":
-						r.load(ev.Kv)
-					case "DELETE":
-						logger.V(1).Noticef(
-							"etcd override for rollout %q was deleted, falling back to default "+
-								"rollout of %.1f%%",
-							r.Name(),
-							r.opts.Pct)
-						r.pct = r.opts.Pct
-						if r.opts.OnChange != nil {
-							r.opts.OnChange(r)
-						}
-					}
-				}
-			case <-ctx.Done():
-				wCancel()
-				etcd.Close()
+	for {
+		watchCh, err := ephsClient.Watch(wCtx, r.ephsPath())
+		if err != nil {
+			if errors.Is(err, context.Canceled) {
 				return
 			}
+			logger.Warningf("failed to watch path %q, retrying in %s: %v",
+				r.ephsPath(), retryInterval, err)
+			r.initDone()
+			time.Sleep(retryInterval)
+			continue
+		}
+
+		// NOTE(review): nothing is read here before the first update arrives;
+		// confirm that ephs Watch delivers the current value as its first event.
+		for update := range watchCh {
+			logger.V(2).Infof("got update of type %v for rollout %q", update.GetEvent().String(), r.name)
+			r.refresh()
 		}
 	}
 }
 
-func (r *rolloutImpl) etcdKey() string {
-	return fmt.Sprintf("/rollout/%s", r.name)
+// ephsPath returns the ephs path that holds this switch's override value.
+func (r *rolloutImpl) ephsPath() string {
+	return fmt.Sprintf("/ephs/local/rollout/%s", r.name)
 }
 
-func (r *rolloutImpl) refresh(client *etcd_client.Client) error {
-	ctx, cancel := context.WithTimeout(context.Background(), opTimeout)
+// refresh reads the override stored at ephsPath and applies it through load.
+// A missing path clears the override so Pct falls back to opts.Pct. Every
+// outcome releases callers blocked in waitInit.
+func (r *rolloutImpl) refresh() error {
+	// load only runs when a value was read; release waiters on the not-found
+	// and error paths as well so Enabled does not keep blocking on them.
+	defer r.initDone()
+
+	client, err := ephs.DefaultClient()
+	if err != nil {
+		return fmt.Errorf(
+			"while refreshing rollout switch %q: failed to get ephs client: %w",
+			r.Name(),
+			err)
+	}
+
+	iCtx, _ := context.Interruptible()
+	ctx, cancel := context.WithTimeout(iCtx, opTimeout)
 	defer cancel()
 
-	result, err := client.Get(ctx, r.etcdKey())
+	reader, err := client.GetContext(ctx, r.ephsPath())
 	if err != nil {
+		if ephs.IsNotFound(err) {
+			logger.V(1).Infof("override for rollout %q removed, falling back to default pct: %.1f", r.name, r.opts.Pct)
+			r.pct = nil
+			if r.opts.OnChange != nil {
+				r.opts.OnChange(r)
+			}
+			return nil
+		}
 		err = fmt.Errorf(
-			"while refreshing rollout switch %q: failed to read etcd key %q: %v",
+			"while refreshing rollout switch %q: failed to read ephs path %q: %v",
 			r.Name(),
-			r.etcdKey(),
+			r.ephsPath(),
 			err)
 		logger.V(1).Warning(err)
 
 		return err
 	}
 
-	for _, kv := range result.Kvs {
-		if string(kv.Key) != r.etcdKey() {
-			continue
-		}
-
-		err := r.load(kv)
-		if err != nil {
-			logger.V(1).Warning(err)
-			logger.V(1).Noticef("falling back to default rollout percentage: %.1f%%", r.opts.Pct)
-			r.pct = r.opts.Pct
-			if r.opts.OnChange != nil {
-				r.opts.OnChange(r)
-			}
-			return err
-		}
+	defer reader.Close()
+	rawValue, err := io.ReadAll(reader)
+	if err != nil {
+		err = fmt.Errorf(
+			"while refreshing rollout switch %q: failed to read data value of ephs path %q: %v",
+			r.Name(),
+			r.ephsPath(),
+			err)
+		logger.V(1).Warning(err)
 
 		return err
 	}
-	return nil
+
+	return r.load(rawValue)
 }
 
-func (r *rolloutImpl) load(kv *mvccpb.KeyValue) error {
-	v, err := strconv.ParseFloat(string(kv.Value), 64)
+// load parses rawValue as a float64 after trimming surrounding whitespace and
+// releases callers blocked in waitInit whether or not parsing succeeds.
+func (r *rolloutImpl) load(rawValue []byte) error {
+	defer r.initDone()
+
+	v, err := strconv.ParseFloat(strings.TrimSpace(string(rawValue)), 64)
 	if err != nil {
 		return fmt.Errorf(
 			"while refreshing rollout switch %q: failed to parse current chance from "+
-				"etcd key %q: %v",
+				"ephs path %q: %v",
 			r.Name(),
-			r.etcdKey(),
+			r.ephsPath(),
 			err)
 	}
 
@@ -276,14 +315,14 @@ func (r *rolloutImpl) load(kv *mvccpb.KeyValue) error {
 	}
 
 	logger.Noticef("rollout of feature gate %q updated to %.1f%%", r.Name(), v)
-	r.pct = v
+	r.pct = &v
 	if r.opts.OnChange != nil {
 		r.opts.OnChange(r)
 	}
 	return nil
 }
 
-func parseRollout(v string) error {
+func parseRolloutFlag(v string) error {
 	if match := validFlag.FindStringSubmatch(v); len(match) >= 3 {
 		key, pctStr := match[1], match[2]
 		if _, ok := registry[key]; !ok {
@@ -303,7 +342,8 @@ func parseRollout(v string) error {
 			key,
 			pct)
 		r := registry[key].(*rolloutImpl)
-		r.opts.Pct, r.pct = pct, pct
+		r.opts.Pct, r.pct = pct, &pct
+		r.initDone()
 		return nil
 	}
 	return fmt.Errorf("not a valid feature flag override expression: %q", v)
@@ -311,5 +351,5 @@ func parseRollout(v string) error {
 
 func init() {
 	ff := slices.Sorted(maps.Keys(registry))
-	flag.Func("rollout", "syntax: flag=n (valid feature flags: "+strings.Join(ff, ", ")+")", parseRollout)
+	flag.Func("rollout", "syntax: flag=n (valid feature flags: "+strings.Join(ff, ", ")+")", parseRolloutFlag)
 }
-- 
2.52.0