From 33aae9e0369c2bd1d24abb0c5829e83dc18dd19e Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Tue, 5 Aug 2025 23:01:18 -0400 Subject: [PATCH] [http] support ranges for route_action_s3 --- http/range.go | 120 ++++++++++++++++++++++++++++++++++++++++ http/route_action_s3.go | 76 ++++++++++++++++++++++++- http/server.go | 13 +++-- 3 files changed, 202 insertions(+), 7 deletions(-) create mode 100644 http/range.go diff --git a/http/range.go b/http/range.go new file mode 100644 index 0000000..580ab0f --- /dev/null +++ b/http/range.go @@ -0,0 +1,120 @@ +package http + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" +) + +type RangeUnit uint + +const ( + RangeUnitInvalid RangeUnit = iota + RangeUnitBytes +) + +func RangeUnitFromString(inp string) (RangeUnit, error) { + switch inp { + case "bytes": + return RangeUnitBytes, nil + } + + return RangeUnitInvalid, errors.New("invalid unit in Range header") +} + +func (ru RangeUnit) String() string { + switch ru { + case RangeUnitBytes: + return "bytes" + default: + panic("invalid range unit") + } +} + +type Range struct { + Start uint64 + End uint64 +} + +type RangeHeader struct { + Unit RangeUnit + Ranges []Range + SuffixLength uint64 +} + +var validRange = regexp.MustCompile("^[0-9]+-[0-9]*$") + +func ParseRange(val string) (*RangeHeader, error) { + unitStr, ranges, ok := strings.Cut(val, "=") + if !ok { + return nil, errors.New("invalid syntax of range header") + } + unit, err := RangeUnitFromString(unitStr) + if err != nil { + return nil, err + } + + h := &RangeHeader{ + Unit: unit, + } + + if len(ranges) < 1 { + return nil, errors.New("empty range header after unit") + } + + if ranges[0] == '-' { + i, err := strconv.ParseUint(ranges[1:], 10, 64) + if err != nil { + return nil, err + } + + h.SuffixLength = i + } else { + for _, rangeExpr := range strings.Split(ranges, ",") { + rangeExpr = strings.TrimSpace(rangeExpr) + if !validRange.MatchString(rangeExpr) { + return nil, errors.New("invalid syntax of range expression in range header") + } + + start, end, _ := strings.Cut(rangeExpr, "-") + r := Range{} + if i, err := strconv.ParseUint(start, 10, 64); err == nil { + r.Start = i + } else { + return nil, err + } + + if end != "" { + if i, err := strconv.ParseUint(end, 10, 64); err == nil { + r.End = i + } else { + return nil, err + } + + if r.Start >= r.End { + return nil, errors.New("invalid range: start is greater than end") + } + } + + h.Ranges = append(h.Ranges, r) + } + } + + return h, nil +} + +func (rh *RangeHeader) String() string { + s := rh.Unit.String() + "=" + if rh.SuffixLength > 0 { + s += fmt.Sprintf("-%d", rh.SuffixLength) + } else { + var ranges []string + for _, r := range rh.Ranges { + ranges = append(ranges, fmt.Sprintf("%d-%d", r.Start, r.End)) + } + s += strings.Join(ranges, ", ") + } + return s +} diff --git a/http/route_action_s3.go b/http/route_action_s3.go index 7cabce6..3ab3647 100644 --- a/http/route_action_s3.go +++ b/http/route_action_s3.go @@ -41,6 +41,20 @@ func (a *S3Action) Handle(w http.ResponseWriter, r *http.Request, next http.Hand }, "/"), "/") + rc := http.StatusOK + var rh *RangeHeader + + if ar := r.Header.Get("range"); ar != "" { + rh, err = ParseRange(ar) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("failed parsing range: " + err.Error())) + return + } + + LoggerFromContext(r.Context()).V(3).Debugf("parsed range header as: %s", rh) + } + object, err := mc.GetObject( r.Context(), a.BucketName, @@ -61,11 +75,67 @@ func (a *S3Action) Handle(w http.ResponseWriter, r *http.Request, next http.Hand return } + var seek int64 + var size = stat.Size + if rh != nil { + if rh.SuffixLength > 0 { + if rh.SuffixLength > uint64(stat.Size) { + http.Error(w, + fmt.Sprintf( + "requested suffix of %d bytes exceeds file size of %d bytes", + rh.SuffixLength, stat.Size), + http.StatusRequestedRangeNotSatisfiable) + + return + } + + w.Header().Set("content-range", + fmt.Sprintf("%s %d-%d/%d", rh.Unit.String(), uint64(stat.Size)-rh.SuffixLength, stat.Size-1, stat.Size)) + + rc = http.StatusPartialContent + size = int64(rh.SuffixLength) + seek = stat.Size - int64(rh.SuffixLength) + } else if len(rh.Ranges) == 1 { + rng := rh.Ranges[0] + if rng.End == 0 { + rng.End = uint64(stat.Size - 1) + } + if rng.Start >= uint64(stat.Size) || rng.End > uint64(stat.Size-1) { + http.Error(w, + fmt.Sprintf( + "requested range %d-%d exceeds file size of %d bytes", + rng.Start, rng.End, stat.Size), + http.StatusRequestedRangeNotSatisfiable) + return + } + + respRange := fmt.Sprintf("%s %d-%d/%d", rh.Unit.String(), rng.Start, rng.End, stat.Size) + w.Header().Set("content-range", respRange) + + seek = int64(rng.Start) + size = int64(rng.End-rng.Start) + 1 + rc = http.StatusPartialContent + + LoggerFromContext(r.Context()).V(3).Debugf( + "satisfiable range is %d bytes long, seeking to byte %d: %s", + size, seek, respRange) + } else { + http.Error(w, + "multiple ranges are not supported", + http.StatusRequestedRangeNotSatisfiable) + } + } + + if seek > 0 { + object.Seek(seek, io.SeekStart) + } else { + w.Header().Set("accept-ranges", "bytes") + } + w.Header().Set("content-length", fmt.Sprintf("%d", size)) w.Header().Set("content-type", stat.ContentType) - w.Header().Set("content-length", fmt.Sprintf("%d", stat.Size)) - w.WriteHeader(http.StatusOK) + w.WriteHeader(rc) - io.CopyN(w, object, stat.Size) + io.CopyN(w, object, size) } func (a *S3Action) minioClient() (mc *minio.Client, err error) { diff --git a/http/server.go b/http/server.go index eb52e72..e056d28 100644 --- a/http/server.go +++ b/http/server.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "os" + "regexp" "sync" "go.fuhry.dev/runtime/mtls" @@ -57,6 +58,7 @@ const ( kSamlDefaults ) +var portSpec = regexp.MustCompile(":[0-9]{1,5}$") var initHooks []initHook var routeParseFuncs []routeParseFunc @@ -298,10 +300,13 @@ func (l *Listener) handle(w http.ResponseWriter, r *http.Request) { // make sure this host is known vhost, ok := l.VirtualHosts[r.Host] if !ok { - http.Error(w, "Misdirected request: unknown virtual host", - http.StatusMisdirectedRequest) - - return + h := portSpec.ReplaceAllString(r.Host, "") + vhost, ok = l.VirtualHosts[h] + if !ok { + http.Error(w, "Misdirected request: unknown virtual host", + http.StatusMisdirectedRequest) + return + } } l.fulfill(w, r, vhost.Routes) -- 2.50.1