]> go.fuhry.dev Git - runtime.git/commitdiff
Fine let's just make it a full HTTP proxy
authorDan Fuhry <dan@fuhry.com>
Sat, 29 Mar 2025 03:38:07 +0000 (23:38 -0400)
committerDan Fuhry <dan@fuhry.com>
Sat, 29 Mar 2025 03:38:07 +0000 (23:38 -0400)
Refactor samlproxy into a general purpose proxy with pluggable actions. Add S3 bucket serving backend. Route actions can fulfill the request or modify it and call next(), basically the same idea as coredns but for http.

Backwards incompatible with existing configs.

13 files changed:
.gitignore
go.mod
go.sum
http/proxy/Makefile [moved from http/samlproxy/Makefile with 100% similarity]
http/proxy/main.go [new file with mode: 0644]
http/proxy/systemd/http-proxy@.service [moved from http/samlproxy/systemd/saml-proxy@.service with 100% similarity]
http/route_action_proxy.go [new file with mode: 0644]
http/route_action_redirect.go [new file with mode: 0644]
http/route_action_s3.go [new file with mode: 0644]
http/route_action_saml.go [new file with mode: 0644]
http/samlproxy.go [deleted file]
http/samlproxy/main.go [deleted file]
http/server.go [new file with mode: 0644]

index 50a6a0ba82e9eacd5068d995779ea578d780803c..43c001f0c18239212b9ddc743394ed6c7d01965d 100644 (file)
@@ -49,7 +49,7 @@ mtls/verify_tool/verify_tool
 ldap/health_exporter/health_exporter
 envoy/xds/envoy_xds/envoy_xds
 mtls/mtls_exporter/mtls_exporter
-http/samlproxy/samlproxy
+http/proxy/proxy
 automation/bryston_ctl/cli/cli
 automation/bryston_ctl/client/client
 automation/bryston_ctl/server/server
diff --git a/go.mod b/go.mod
index 8abdf6b599dbc5a9aba89fe5ec42c2d1de3d7bc8..8b799950bbd515b81b3a50a801b711d46310cee0 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -1,14 +1,16 @@
 module go.fuhry.dev/runtime
 
-go 1.23
+go 1.23.0
+
+toolchain go1.24.1
 
 require (
        github.com/google/certificate-transparency-go v1.1.4
        github.com/google/go-attestation v0.4.3
        github.com/google/go-tpm v0.3.3 // indirect
        github.com/google/go-tspi v0.2.1-0.20190423175329-115dea689aad // indirect
-       golang.org/x/crypto v0.25.0 // indirect
-       golang.org/x/sys v0.22.0
+       golang.org/x/crypto v0.36.0 // indirect
+       golang.org/x/sys v0.31.0
 )
 
 require (
@@ -39,8 +41,8 @@ require (
        go.fuhry.dev/fsnotify v1.7.2
        go.fuhry.dev/grpc-quic v0.1.2
        golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d
-       golang.org/x/sync v0.7.0
-       golang.org/x/term v0.22.0
+       golang.org/x/sync v0.12.0
+       golang.org/x/term v0.30.0
        gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
        gopkg.in/ini.v1 v1.67.0
        gopkg.in/yaml.v3 v3.0.1
@@ -66,11 +68,14 @@ require (
        github.com/felixge/httpsnoop v1.0.1 // indirect
        github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
        github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
+       github.com/go-ini/ini v1.67.0 // indirect
        github.com/go-logfmt/logfmt v0.5.1 // indirect
        github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
+       github.com/goccy/go-json v0.10.5 // indirect
        github.com/golang-jwt/jwt/v4 v4.5.0 // indirect
        github.com/gomodule/redigo v1.8.2 // indirect
        github.com/google/pprof v0.0.0-20230509042627-b1315fad0c5a // indirect
+       github.com/google/uuid v1.6.0 // indirect
        github.com/goph/emperror v0.17.2 // indirect
        github.com/gorilla/handlers v1.5.1 // indirect
        github.com/gorilla/mux v1.8.0 // indirect
@@ -78,11 +83,17 @@ require (
        github.com/jonboulle/clockwork v0.3.0 // indirect
        github.com/jpillora/backoff v1.0.0 // indirect
        github.com/json-iterator/go v1.1.12 // indirect
-       github.com/klauspost/compress v1.16.5 // indirect
+       github.com/klauspost/compress v1.18.0 // indirect
+       github.com/klauspost/cpuid/v2 v2.2.10 // indirect
        github.com/kr/pretty v0.3.1 // indirect
        github.com/kr/text v0.2.0 // indirect
        github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
        github.com/miekg/pkcs11 v1.1.1 // indirect
+       github.com/minio/crc64nvme v1.0.1 // indirect
+       github.com/minio/md5-simd v1.1.2 // indirect
+       github.com/minio/minio-go v6.0.14+incompatible // indirect
+       github.com/minio/minio-go/v7 v7.0.89 // indirect
+       github.com/mitchellh/go-homedir v1.1.0 // indirect
        github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
        github.com/modern-go/reflect2 v1.0.2 // indirect
        github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect
@@ -98,6 +109,7 @@ require (
        github.com/prometheus/procfs v0.12.0 // indirect
        github.com/quic-go/qtls-go1-20 v0.3.4 // indirect
        github.com/rogpeppe/go-internal v1.10.0 // indirect
+       github.com/rs/xid v1.6.0 // indirect
        github.com/russellhaering/goxmldsig v1.3.0 // indirect
        github.com/russross/blackfriday/v2 v2.1.0 // indirect
        github.com/thales-e-security/pool v0.0.2 // indirect
@@ -128,8 +140,8 @@ require (
        go.uber.org/atomic v1.11.0
        go.uber.org/multierr v1.8.0 // indirect
        go.uber.org/zap v1.21.0 // indirect
-       golang.org/x/net v0.27.0 // indirect
-       golang.org/x/text v0.16.0
+       golang.org/x/net v0.37.0 // indirect
+       golang.org/x/text v0.23.0
        google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d // indirect
        google.golang.org/grpc v1.59.0
        google.golang.org/protobuf v1.34.1
diff --git a/go.sum b/go.sum
index bedfe3b7f9d36b467226fb5e725803a6139b164d..b7cfe352a07a2f02f31f54e15c7413fbaa0c1f2f 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -171,6 +171,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK
 github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A=
+github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU=
@@ -188,6 +190,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
 github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
+github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
+github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
 github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
@@ -279,6 +283,8 @@ github.com/google/trillian v1.3.11/go.mod h1:0tPraVHrSDkA3BO6vKX67zgLXs6SsOAbHEi
 github.com/google/uuid v0.0.0-20161128191214-064e2069ce9c/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
 github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
 github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18=
@@ -362,6 +368,11 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI=
 github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
+github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
+github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
+github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
+github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE=
+github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -408,9 +419,18 @@ github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX7
 github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU=
 github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
+github.com/minio/crc64nvme v1.0.1 h1:DHQPrYPdqK7jQG/Ls5CTBZWeex/2FMS3G5XGkycuFrY=
+github.com/minio/crc64nvme v1.0.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg=
+github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
+github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
+github.com/minio/minio-go v6.0.14+incompatible h1:fnV+GD28LeqdN6vT2XdGKW8Qe/IfjJDswNVuni6km9o=
+github.com/minio/minio-go v6.0.14+incompatible/go.mod h1:7guKYtitv8dktvNUGrhzmNlA5wrAABTQXCoesZdFQO8=
+github.com/minio/minio-go/v7 v7.0.89 h1:hx4xV5wwTUfyv8LarhJAwNecnXpoTsj9v3f3q/ZkiJU=
+github.com/minio/minio-go/v7 v7.0.89/go.mod h1:2rFnGAp02p7Dddo1Fq4S2wYOfpF0MUTSeLTRC90I204=
 github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
 github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
 github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
+github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
 github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
 github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc=
 github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg=
@@ -518,6 +538,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
 github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
 github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ=
 github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
+github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
+github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
 github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM=
 github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw=
 github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
@@ -651,6 +673,8 @@ golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm
 golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
 golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
 golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -727,6 +751,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
 golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
 golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
+golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
+golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -747,6 +773,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
 golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -799,9 +827,13 @@ golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
 golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
 golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
+golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
 golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -810,6 +842,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
 golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
+golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
 golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
similarity index 100%
rename from http/samlproxy/Makefile
rename to http/proxy/Makefile
diff --git a/http/proxy/main.go b/http/proxy/main.go
new file mode 100644 (file)
index 0000000..c9c727b
--- /dev/null
@@ -0,0 +1,64 @@
+package main
+
+import (
+       "context"
+       "flag"
+       "os"
+       "os/signal"
+       "syscall"
+       "time"
+
+       "github.com/coreos/go-systemd/daemon"
+       "gopkg.in/yaml.v3"
+
+       "go.fuhry.dev/runtime/http"
+       "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/utils/log"
+)
+
+func main() {
+       mtls.SetDefaultIdentity("authproxy")
+
+       ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+       defer cancel()
+
+       server := http.NewServerWithContext(ctx)
+
+       loadConfig := func(arg string) error {
+               contents, err := os.ReadFile(arg)
+               if err != nil {
+                       return err
+               }
+
+               err = yaml.Unmarshal(contents, server)
+               return err
+       }
+
+       flag.Func("config", "YAML file to load configuration from", loadConfig)
+       flag.StringVar(&server.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private")
+       flag.StringVar(&server.Listener.Addr, "listen", "[::]:8443", "address for auth proxy to listen on")
+       flag.StringVar(&server.Listener.InsecureAddr, "listen.http", "[::]:8080", "address for http-to-https redirector")
+
+       flag.Parse()
+
+       httpServer, err := server.Create()
+       if err != nil {
+               log.Panic(err)
+       }
+       go httpServer.ListenAndServeTLS("", "")
+
+       log.Default().Infof("listening on HTTPS at %s", server.Listener.Addr)
+
+       unsecureServer := server.CreateInsecure()
+       go unsecureServer.ListenAndServe()
+
+       log.Default().Infof("listening on HTTP at %s (redirects to HTTPS only)", server.Listener.InsecureAddr)
+
+       daemon.SdNotify(false, daemon.SdNotifyReady)
+
+       <-ctx.Done()
+       shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
+       defer shutdownCancel()
+       httpServer.Shutdown(shutdownCtx)
+       unsecureServer.Shutdown(shutdownCtx)
+}
diff --git a/http/route_action_proxy.go b/http/route_action_proxy.go
new file mode 100644 (file)
index 0000000..83dccae
--- /dev/null
@@ -0,0 +1,153 @@
+package http
+
+import (
+       "context"
+       "crypto/tls"
+       "fmt"
+       "io"
+       "net"
+       "net/http"
+       "strconv"
+       "strings"
+       "sync"
+
+       "gopkg.in/yaml.v3"
+
+       "go.fuhry.dev/runtime/mtls"
+)
+
+type staticUpstreamAction struct {
+       Host     string `yaml:"host"`
+       Port     int    `yaml:"port"`
+       Identity string `yaml:"mtls_id"`
+
+       client     *http.Client
+       clientOnce sync.Once
+}
+
+// httpClient returns an HTTP client for making requests to the backend.
+func (su *staticUpstreamAction) httpClient() (*http.Client, error) {
+       var err error
+       su.clientOnce.Do(func() {
+               transport := &http.Transport{}
+               var tlsConfig *tls.Config
+
+               if su.Identity != "" {
+                       myIdentity := mtls.DefaultIdentity()
+                       tlsConfig, err = myIdentity.TlsConfig(context.Background())
+                       if err != nil {
+                               return
+                       }
+
+                       verifier := mtls.NewPeerNameVerifier()
+                       verifier.AllowFrom(mtls.Service, su.Identity)
+                       err = verifier.ConfigureClient(tlsConfig)
+                       if err != nil {
+                               return
+                       }
+
+                       transport.TLSClientConfig = tlsConfig
+               }
+
+               client := &http.Client{
+                       Transport: transport,
+               }
+
+               su.client = client
+       })
+       if err != nil {
+               return nil, err
+       }
+       return su.client, nil
+}
+
+// Handle implements RouteAction.
+func (su *staticUpstreamAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
+       logger := LoggerFromContext(r.Context())
+
+       upstreamReq := r.Clone(r.Context())
+       upstreamReq.URL.Scheme = "http"
+       if su.Identity != "" {
+               upstreamReq.URL.Scheme = "https"
+       }
+       upstreamReq.URL.Host = net.JoinHostPort(su.Host, strconv.Itoa(su.Port))
+       upstreamReq.RequestURI = ""
+
+       // set proxy headers
+       if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
+               logger.V(3).Debugf("x-forwarded-for: %s", remoteHost)
+               upstreamReq.Header.Set("x-forwarded-for", remoteHost)
+       }
+
+       // proxy the request to the backend
+       client, err := su.httpClient()
+       if err != nil {
+               http.Error(w,
+                       fmt.Sprintf("error setting up connection to backend: %v", err),
+                       http.StatusInternalServerError)
+       }
+       response, err := client.Do(upstreamReq)
+       if err != nil {
+               http.Error(w, err.Error(), http.StatusBadGateway)
+               return
+       }
+
+       for name, value := range response.Header {
+               w.Header().Set(name, strings.Join(value, ", "))
+       }
+
+       if response.StatusCode == http.StatusSwitchingProtocols {
+               hijacker, ok := w.(http.Hijacker)
+               if !ok {
+                       http.Error(w, "websocket passthrough not supported", http.StatusMethodNotAllowed)
+                       return
+               }
+
+               upstreamWriter, ok := response.Body.(io.Writer)
+               if !ok {
+                       http.Error(w, "body doesn't support io.Writer", http.StatusMethodNotAllowed)
+                       return
+               }
+
+               w.WriteHeader(response.StatusCode)
+
+               conn, rw, err := hijacker.Hijack()
+               if err != nil {
+                       http.Error(w, err.Error(), http.StatusInternalServerError)
+                       return
+               }
+
+               wg := sync.WaitGroup{}
+               wg.Add(2)
+               pipe := func(w io.Writer, r io.Reader) {
+                       defer wg.Done()
+                       io.Copy(w, r)
+               }
+               go pipe(rw, response.Body)
+               go pipe(upstreamWriter, rw)
+
+               wg.Wait()
+               conn.Close()
+               return
+       }
+
+       w.WriteHeader(response.StatusCode)
+       io.Copy(w, response.Body)
+}
+
+func staticUpstreamActionFromYaml(node *yaml.Node) (RouteAction, error) {
+       var rawNode struct {
+               Proxy *staticUpstreamAction `yaml:"proxy"`
+       }
+
+       err := node.Decode(&rawNode)
+       if err != nil || rawNode.Proxy == nil {
+               return nil, nil
+       }
+
+       return rawNode.Proxy, nil
+}
+
+func init() {
+       AddRouteParseFunc(staticUpstreamActionFromYaml)
+}
diff --git a/http/route_action_redirect.go b/http/route_action_redirect.go
new file mode 100644 (file)
index 0000000..2197231
--- /dev/null
@@ -0,0 +1,66 @@
+package http
+
+import (
+       "net/http"
+       "net/url"
+
+       "gopkg.in/yaml.v3"
+)
+
+type RedirectAction struct {
+       StatusCode  int
+       Destination *url.URL
+}
+
+// Handle implements RouteAction
+func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
+       newUrl := *r.URL
+
+       if a.Destination.Host != "" {
+               newUrl.Host = a.Destination.Host
+       }
+       newUrl.Path = a.Destination.Path
+       if a.Destination.RawQuery != "" {
+               newUrl.RawQuery = a.Destination.RawQuery
+       }
+       if a.Destination.Scheme != "" {
+               newUrl.Scheme = a.Destination.Scheme
+       }
+       if a.Destination.Fragment != "" {
+               newUrl.Fragment = a.Destination.Fragment
+       }
+
+       status := a.StatusCode
+       if status == 0 {
+               status = http.StatusFound
+       }
+
+       w.Header().Set("location", newUrl.String())
+       w.WriteHeader(status)
+}
+
+func redirectFromRouteYaml(node *yaml.Node) (RouteAction, error) {
+       var rawNode struct {
+               Redirect *struct {
+                       Destination string `yaml:"dest"`
+                       Status      int    `yaml:"status"`
+               } `yaml:"redirect,omitempty"`
+       }
+
+       if err := node.Decode(&rawNode); err == nil && rawNode.Redirect != nil {
+               u, err := url.Parse(rawNode.Redirect.Destination)
+               if err != nil {
+                       return nil, err
+               }
+               return &RedirectAction{
+                       Destination: u,
+                       StatusCode:  rawNode.Redirect.Status,
+               }, nil
+       }
+
+       return nil, nil
+}
+
+func init() {
+       AddRouteParseFunc(redirectFromRouteYaml)
+}
diff --git a/http/route_action_s3.go b/http/route_action_s3.go
new file mode 100644 (file)
index 0000000..49bf85b
--- /dev/null
@@ -0,0 +1,132 @@
+package http
+
+import (
+       "fmt"
+       "io"
+       "net/http"
+       "strconv"
+       "strings"
+       "sync"
+
+       "github.com/minio/minio-go/v7"
+       "github.com/minio/minio-go/v7/pkg/credentials"
+       "gopkg.in/yaml.v3"
+)
+
+type S3Action struct {
+       S3Endpoint   string `yaml:"endpoint"`
+       S3AccessKey  string `yaml:"access_key"`
+       S3SecretKey  string `yaml:"secret_key"`
+       BucketName   string `yaml:"bucket"`
+       ObjectPrefix string `yaml:"prefix"`
+       StripPrefix  string `yaml:"strip_prefix"`
+
+       mc     *minio.Client
+       mcOnce sync.Once
+}
+
+// Handle implements RouteAction
+func (a *S3Action) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
+       mc, err := a.minioClient()
+       if err != nil {
+               w.WriteHeader(http.StatusInternalServerError)
+               w.Write([]byte(fmt.Sprintf("failed to init minio client: %+v", err)))
+               return
+       }
+
+       reqPath := strings.TrimPrefix(r.URL.Path, a.StripPrefix)
+       objPath := "/" + strings.Trim(strings.Join(
+               []string{
+                       strings.Trim(a.ObjectPrefix, "/"),
+                       strings.Trim(reqPath, "/"),
+               },
+               "/"), "/")
+
+       object, err := mc.GetObject(
+               r.Context(),
+               a.BucketName,
+               objPath,
+               minio.GetObjectOptions{})
+
+       if err != nil {
+               w.WriteHeader(http.StatusInternalServerError)
+               w.Write([]byte(fmt.Sprintf("failed to GetObject %q: %+v", reqPath, err)))
+               return
+       }
+
+       stat, err := object.Stat()
+       if err != nil {
+               w.WriteHeader(http.StatusInternalServerError)
+               w.Write([]byte(fmt.Sprintf("failed to stat object %q: %+v", objPath, err)))
+               return
+       }
+
+       w.Header().Set("content-type", stat.ContentType)
+       w.WriteHeader(http.StatusOK)
+
+       if strings.HasPrefix(stat.ContentType, "text/") || strings.HasPrefix(stat.ContentType, "application/") {
+               buf := make([]byte, 32*1024)
+               var out []byte
+               for {
+                       nr, err := object.Read(buf)
+                       if nr == 0 && err != nil {
+                               break
+                       }
+                       end := nr
+                       for i := nr; i > 0; i-- {
+                               if buf[i-1] != '\000' {
+                                       end = i
+                                       break
+                               }
+                       }
+                       out = append(out, buf[:end]...)
+
+                       if err == io.EOF {
+                               break
+                       }
+               }
+               w.Header().Set("content-length", strconv.Itoa(len(out)))
+               for nw := 0; nw < len(out); {
+                       c, err := w.Write(out[nw:])
+                       if err != nil {
+                               break
+                       }
+                       nw += c
+               }
+       } else {
+               io.Copy(w, object)
+       }
+}
+
+func (a *S3Action) minioClient() (mc *minio.Client, err error) {
+       a.mcOnce.Do(func() {
+               mc, err = minio.New(
+                       a.S3Endpoint,
+                       &minio.Options{
+                               Creds:        credentials.NewStaticV4(a.S3AccessKey, a.S3SecretKey, ""),
+                               Secure:       true,
+                               BucketLookup: minio.BucketLookupDNS,
+                       },
+               )
+
+               a.mc = mc
+       })
+
+       return a.mc, err
+}
+
+func s3ActionFromRouteYaml(node *yaml.Node) (RouteAction, error) {
+       var rawNode struct {
+               S3 *S3Action `yaml:"s3,omitempty"`
+       }
+
+       if err := node.Decode(&rawNode); err == nil && rawNode.S3 != nil {
+               return rawNode.S3, nil
+       }
+
+       return nil, nil
+}
+
+func init() {
+       AddRouteParseFunc(s3ActionFromRouteYaml)
+}
diff --git a/http/route_action_saml.go b/http/route_action_saml.go
new file mode 100644 (file)
index 0000000..eb0d829
--- /dev/null
@@ -0,0 +1,322 @@
+package http
+
+import (
+       "context"
+       "crypto"
+       "crypto/rsa"
+       "crypto/x509"
+       "crypto/x509/pkix"
+       "fmt"
+       "math/big"
+       "net/http"
+       "net/url"
+       "regexp"
+       "strconv"
+       "strings"
+       "sync"
+       "time"
+
+       "github.com/crewjam/saml"
+       "github.com/crewjam/saml/samlsp"
+
+       "go.fuhry.dev/runtime/mtls/certutil"
+       "go.fuhry.dev/runtime/utils/hashset"
+       "gopkg.in/yaml.v3"
+)
+
+type SAMLServiceProvider struct {
+       EntityID          string `yaml:"entity_id"`
+       EntityCertificate string `yaml:"entity_certificate"`
+       EntityPrivateKey  string `yaml:"entity_key"`
+       IDP               string `yaml:"idp"`
+
+       metadata     *saml.EntityDescriptor
+       metadataOnce sync.Once
+
+       entityCert  *x509.Certificate
+       entityKey   *rsa.PrivateKey
+       certKeyOnce sync.Once
+       spMu        sync.Mutex
+       mw          map[string]*samlsp.Middleware
+}
+
+type samlAction struct {
+       sp             *SAMLServiceProvider
+       requireAuth    bool
+       usernameHeader string
+}
+
+var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`)
+var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"})
+
+func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
+       logger := LoggerFromContext(r.Context())
+
+       // ensure client isn't trying to inject saml-related headers
+       if err := sa.checkRequest(r); err != nil {
+               http.Error(w, err.Error(), http.StatusBadRequest)
+               return
+       }
+
+       sp := sa.sp
+       if sp == nil {
+               sp = serverDefaultSamlConfig(r.Context())
+       }
+       if sp == nil {
+               http.Error(w, "SAML auth requested but no SP config present",
+                       http.StatusInternalServerError)
+       }
+       provider, err := sp.getMiddleware(r.Host)
+       if err != nil {
+               http.Error(w, err.Error(), http.StatusInternalServerError)
+               return
+       }
+
+       if r.URL.Path == "/saml/acs" {
+               provider.ServeACS(w, r)
+               return
+       }
+
+       session, sessionErr := provider.Session.GetSession(r)
+
+       if sessionErr != nil && sessionErr != samlsp.ErrNoSession {
+               http.Error(w, sessionErr.Error(), http.StatusBadRequest)
+               return
+       }
+
+       if sessionErr == samlsp.ErrNoSession && sa.requireAuth {
+               logger.V(3).Debugf("route requires a valid session, redirecting")
+
+               provider.HandleStartAuthFlow(w, r)
+               return
+       }
+
+       if swa, ok := session.(samlsp.SessionWithAttributes); ok {
+               attrs := swa.GetAttributes()
+               oboHeader := sa.usernameHeader
+               if oboHeader == "" {
+                       oboHeader = "on-behalf-of"
+               }
+               logger.V(3).Debugf("setting origin request header: %s: %q", oboHeader, attrs.Get("uid"))
+               r.Header.Set(oboHeader, attrs.Get("uid"))
+       }
+
+       if jwts, ok := session.(samlsp.JWTSessionClaims); ok {
+               r.Header.Set("x-saml-audience", jwts.StandardClaims.Audience)
+               logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience)
+
+               iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10)
+               r.Header.Set("x-saml-issued-at", iat)
+               logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat)
+
+               eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10)
+               r.Header.Set("x-saml-expires-at", eat)
+               logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat)
+
+               r.Header.Set("x-saml-subject", jwts.StandardClaims.Subject)
+               logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject)
+
+               for attr, values := range jwts.Attributes {
+                       headerName := fmt.Sprintf("x-saml-%s",
+                               samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-"))
+                       headerValue := strings.Join(values, ", ")
+                       logger.V(3).Debugf("setting origin request header: %s: %s",
+                               headerName, headerValue)
+                       r.Header.Set(headerName, headerValue)
+               }
+       } else {
+               r.Header.Set("x-saml-anonymous-auth", "1")
+       }
+
+       next(w, r)
+}
+
+func (sa *samlAction) checkRequest(r *http.Request) error {
+       for k, _ := range r.Header {
+               k = strings.ToLower(k)
+               if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) {
+                       return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k)
+               }
+       }
+       return nil
+}
+
+func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) {
+       var err error
+       sp.metadataOnce.Do(func() {
+               var idpMetadataUrl *url.URL
+               idpMetadataUrl, err = url.Parse(sp.IDP)
+               if err != nil {
+                       return
+               }
+
+               sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl)
+       })
+
+       if err != nil {
+               sp.metadataOnce = sync.Once{}
+               return nil, err
+       }
+
+       return sp.metadata, nil
+}
+
+func (sp *SAMLServiceProvider) certAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) {
+       sp.certKeyOnce.Do(func() {
+               if sp.EntityPrivateKey != "" {
+                       var loadedKey crypto.PrivateKey
+                       loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey)
+                       if err != nil {
+                               return
+                       }
+                       var ok bool
+                       pvk, ok = loadedKey.(*rsa.PrivateKey)
+                       if !ok {
+                               err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
+                               return
+                       }
+               } else {
+                       // generate new RSA private key
+                       pvk, err = rsa.GenerateKey(saml.RandReader, 2048)
+                       if err != nil {
+                               return
+                       }
+               }
+               if sp.EntityCertificate != "" {
+                       certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate)
+                       if err != nil {
+                               return
+                       }
+                       cert = certs[0]
+               } else {
+                       // generate new self-signed X509 certificate
+                       serialBytes := make([]byte, 16)
+                       saml.RandReader.Read(serialBytes)
+                       serial := big.NewInt(0)
+                       serial.SetBytes(serialBytes)
+
+                       template := &x509.Certificate{
+                               Subject: pkix.Name{
+                                       CommonName: sp.EntityID,
+                               },
+                               SerialNumber:          serial,
+                               BasicConstraintsValid: true,
+                               ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+                               KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
+                               IsCA:                  false,
+                               NotBefore:             time.Now(),
+                               NotAfter:              time.Now().Add(90 * 86400 * time.Second),
+                       }
+                       certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk)
+                       if err != nil {
+                               return
+                       }
+                       cert, err = x509.ParseCertificate(certBytes)
+                       if err != nil {
+                               return
+                       }
+               }
+
+               sp.entityCert = cert
+               sp.entityKey = pvk
+       })
+
+       if err != nil {
+               sp.metadataOnce = sync.Once{}
+               return nil, nil, err
+       }
+
+       return sp.entityCert, sp.entityKey, nil
+}
+
+func (sp *SAMLServiceProvider) getMiddleware(host string) (*samlsp.Middleware, error) {
+       sp.spMu.Lock()
+       defer sp.spMu.Unlock()
+
+       if sp.mw == nil {
+               sp.mw = make(map[string]*samlsp.Middleware)
+       }
+
+       if _, ok := sp.mw[host]; !ok {
+               mw, err := sp.newMiddleware(host)
+               if err != nil {
+                       return nil, err
+               }
+               sp.mw[host] = mw
+       }
+
+       return sp.mw[host], nil
+}
+
+func (sp *SAMLServiceProvider) newMiddleware(host string) (*samlsp.Middleware, error) {
+       idpMetadata, err := sp.Metadata()
+       if err != nil {
+               return nil, err
+       }
+       cert, key, err := sp.certAndKey()
+       if err != nil {
+               return nil, err
+       }
+       return samlsp.New(samlsp.Options{
+               EntityID: sp.EntityID,
+               URL: url.URL{
+                       Scheme: "https",
+                       Host:   host,
+               },
+               Key:         key,
+               Certificate: cert,
+               IDPMetadata: idpMetadata,
+       })
+}
+
+func samlInitHook(ctx context.Context, node *yaml.Node) (context.Context, error) {
+       var samlConfig struct {
+               SP *SAMLServiceProvider `yaml:"saml"`
+       }
+
+       if err := node.Decode(&samlConfig); err == nil && samlConfig.SP != nil {
+               ctx = context.WithValue(ctx, kSamlDefaults, samlConfig.SP)
+       }
+
+       return ctx, nil
+}
+
+func serverDefaultSamlConfig(ctx context.Context) *SAMLServiceProvider {
+       v := ctx.Value(kSamlDefaults)
+       if c, ok := v.(*SAMLServiceProvider); ok {
+               return c
+       }
+       return nil
+}
+
+func samlActionFromRouteYaml(node *yaml.Node) (RouteAction, error) {
+       var rawNode struct {
+               SP *struct {
+                       *SAMLServiceProvider
+                       Require string `yaml:"require"`
+               } `yaml:"saml,omitempty"`
+               Auth string `yaml:"auth"`
+       }
+
+       err := node.Decode(&rawNode)
+       if err != nil || rawNode.Auth != "saml" {
+               return nil, nil
+       }
+
+       require, err := strconv.ParseBool(rawNode.SP.Require)
+       if err != nil {
+               return nil, err
+       }
+
+       sa := &samlAction{
+               sp:          rawNode.SP.SAMLServiceProvider,
+               requireAuth: require,
+       }
+
+       return sa, nil
+}
+
+func init() {
+       AddServerInitHook(samlInitHook)
+       AddRouteParseFunc(samlActionFromRouteYaml)
+}
diff --git a/http/samlproxy.go b/http/samlproxy.go
deleted file mode 100644 (file)
index 489f5d2..0000000
+++ /dev/null
@@ -1,694 +0,0 @@
-package http
-
-import (
-       "context"
-       "crypto"
-       "crypto/rsa"
-       "crypto/tls"
-       "crypto/x509"
-       "crypto/x509/pkix"
-       "errors"
-       "fmt"
-       "io"
-       "math/big"
-       "net"
-       "net/http"
-       "net/url"
-       "os"
-       "regexp"
-       "strconv"
-       "strings"
-       "sync"
-       "time"
-
-       "github.com/crewjam/saml"
-       "github.com/crewjam/saml/samlsp"
-
-       "go.fuhry.dev/runtime/mtls"
-       "go.fuhry.dev/runtime/mtls/certutil"
-       "go.fuhry.dev/runtime/utils/hashset"
-       "go.fuhry.dev/runtime/utils/log"
-       "go.fuhry.dev/runtime/utils/stringmatch"
-       "gopkg.in/yaml.v3"
-)
-
-type authEnforcement uint
-
-type RouteAction interface {
-       Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)
-}
-
-type RedirectAction struct {
-       StatusCode  int
-       Destination *url.URL
-}
-
-type Route struct {
-       Auth   authEnforcement
-       Path   stringmatch.StringMatcher
-       Action RouteAction
-}
-
-type SAMLBackend struct {
-       Host           string `yaml:"host"`
-       Port           int    `yaml:"port"`
-       Identity       string `yaml:"mtls_id"`
-       UsernameHeader string `yaml:"username_header"`
-
-       client     *http.Client
-       clientOnce sync.Once
-}
-
-type SAMLVirtualHost struct {
-       *SAMLServiceProvider `yaml:"saml"`
-
-       Backend *SAMLBackend `yaml:"backend"`
-       Routes  []*Route     `yaml:"routes"`
-}
-
-type SAMLServiceProvider struct {
-       EntityID          string `yaml:"entity_id"`
-       EntityCertificate string `yaml:"entity_certificate"`
-       EntityPrivateKey  string `yaml:"entity_key"`
-       IDP               string `yaml:"idp"`
-
-       metadata     *saml.EntityDescriptor
-       metadataOnce sync.Once
-
-       entityCert  *x509.Certificate
-       entityKey   *rsa.PrivateKey
-       certKeyOnce sync.Once
-}
-
-type SAMLListener struct {
-       *SAMLServiceProvider `yaml:"saml"`
-
-       Addr         string                      `yaml:"listen"`
-       InsecureAddr string                      `yaml:"listen_insecure"`
-       Certificate  string                      `yaml:"cert"`
-       VirtualHosts map[string]*SAMLVirtualHost `yaml:"virtual_hosts"`
-}
-
-type SAMLProxy struct {
-       Listener SAMLListener `yaml:"listener"`
-
-       logger log.Logger
-}
-
-const (
-       AuthRequired authEnforcement = iota
-       AuthOptional
-)
-
-var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`)
-var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"})
-
-// Handle implements RouteAction
-func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
-       newUrl := *r.URL
-
-       if a.Destination.Host != "" {
-               newUrl.Host = a.Destination.Host
-       }
-       newUrl.Path = a.Destination.Path
-       if a.Destination.RawQuery != "" {
-               newUrl.RawQuery = a.Destination.RawQuery
-       }
-       if a.Destination.Scheme != "" {
-               newUrl.Scheme = a.Destination.Scheme
-       }
-       if a.Destination.Fragment != "" {
-               newUrl.Fragment = a.Destination.Fragment
-       }
-
-       status := a.StatusCode
-       if status == 0 {
-               status = http.StatusFound
-       }
-
-       w.Header().Set("location", newUrl.String())
-       w.WriteHeader(status)
-}
-
-// UnmarshalYAML implements yaml.Unmarshaler
-func (r *Route) UnmarshalYAML(node *yaml.Node) error {
-       var rawNode struct {
-               Auth     string                 `yaml:"auth"`
-               Path     *stringmatch.MatchRule `yaml:"path"`
-               Redirect *struct {
-                       Destination string `yaml:"dest"`
-                       Status      int    `yaml:"status"`
-               } `yaml:"redirect"`
-       }
-
-       if err := node.Decode(&rawNode); err != nil {
-               return err
-       }
-
-       switch rawNode.Auth {
-       case "required":
-               r.Auth = AuthRequired
-       case "optional":
-               r.Auth = AuthOptional
-       default:
-               return fmt.Errorf("error unmarshaling route: invalid auth enforcement string value: %s", node.Value)
-       }
-
-       if rawNode.Path != nil {
-               m, err := rawNode.Path.Matcher()
-               if err != nil {
-                       return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err)
-               }
-               r.Path = m
-       } else {
-               return errors.New("error unmarshaling route: exactly one of (`path`) must be specified")
-       }
-
-       if rawNode.Redirect != nil {
-               u, err := url.Parse(rawNode.Redirect.Destination)
-               if err != nil {
-                       return err
-               }
-               r.Action = &RedirectAction{
-                       Destination: u,
-                       StatusCode:  rawNode.Redirect.Status,
-               }
-       }
-
-       return nil
-}
-
-// RouteFromArg implements the 3rd argument to flag.Func.
-//
-// It parses a string in the format of auth:field:match_mode:value, returning a Route if
-// it parses successfully.
-func RouteFromArg(arg string) (*Route, error) {
-       parts := strings.SplitN(arg, ":", 4)
-       if len(parts) != 4 {
-               return nil, fmt.Errorf("invalid route spec: %q", arg)
-       }
-       a, f, t, v := parts[0], parts[1], parts[2], parts[3]
-       var auth authEnforcement
-       switch strings.ToLower(a) {
-       case "r", "req", "required":
-               auth = AuthRequired
-       case "o", "opt", "optional":
-               auth = AuthOptional
-       default:
-               return nil, fmt.Errorf("invalid auth setting: %q", a)
-       }
-
-       route := &Route{
-               Auth: auth,
-       }
-
-       match := stringmatch.MatchRule{
-               Mode:  t,
-               Value: v,
-       }
-       m, err := match.Matcher()
-       if err != nil {
-               return nil, err
-       }
-
-       switch strings.ToLower(f) {
-       case "p", "path":
-               route.Path = m
-       default:
-               return nil, fmt.Errorf("invalid match field: %q", f)
-       }
-
-       return route, nil
-}
-
-// Client returns an HTTP client for making requests to the backend.
-func (b *SAMLBackend) Client() (*http.Client, error) {
-       var err error
-       b.clientOnce.Do(func() {
-               transport := &http.Transport{}
-               var tlsConfig *tls.Config
-
-               if b.Identity != "" {
-                       myIdentity := mtls.DefaultIdentity()
-                       tlsConfig, err = myIdentity.TlsConfig(context.Background())
-                       if err != nil {
-                               return
-                       }
-
-                       verifier := mtls.NewPeerNameVerifier()
-                       verifier.AllowFrom(mtls.Service, b.Identity)
-                       err = verifier.ConfigureClient(tlsConfig)
-                       if err != nil {
-                               return
-                       }
-
-                       transport.TLSClientConfig = tlsConfig
-               }
-
-               client := &http.Client{
-                       Transport: transport,
-               }
-
-               b.client = client
-       })
-       if err != nil {
-               return nil, err
-       }
-       return b.client, nil
-}
-
-// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host
-// and other settings.
-func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
-       var _ yaml.Unmarshaler = &Route{}
-
-       if sp.logger == nil {
-               sp.logger = log.Default().WithPrefix("SAMLProxy")
-       }
-
-       handler, err := sp.newHandler()
-       if err != nil {
-               return nil, err
-       }
-
-       addr := sp.Listener.Addr
-       if addr == "" {
-               addr = "[::]:8443"
-       }
-
-       lm := log.NewLoggingMiddlewareWithLogger(handler, sp.logger)
-       server := &http.Server{
-               Addr:    addr,
-               Handler: lm.HandlerFunc(),
-       }
-
-       if sp.Listener.Certificate != "" {
-               cert := mtls.NewSSLCertificate(sp.Listener.Certificate)
-               tlsConfig, err := cert.TlsConfig(ctx)
-               if err != nil {
-                       return nil, err
-               }
-               server.TLSConfig = tlsConfig
-       }
-
-       return server, nil
-}
-
-func (sp *SAMLProxy) NewHTTPSRedirectorWithContext(ctx context.Context) *http.Server {
-       addr := sp.Listener.InsecureAddr
-       if addr == "" {
-               addr = "[::]:8080"
-       }
-
-       server := &http.Server{
-               Addr: addr,
-               Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-                       host := r.Host
-                       if host == "" {
-                               w.WriteHeader(http.StatusBadRequest)
-                               return
-                       }
-
-                       if _, ok := sp.Listener.VirtualHosts[host]; !ok {
-                               w.WriteHeader(http.StatusMisdirectedRequest)
-                               return
-                       }
-
-                       newUrl := *r.URL
-                       newUrl.Scheme = "https"
-                       newUrl.Host = host
-                       w.Header().Set("location", newUrl.String())
-                       w.WriteHeader(http.StatusFound)
-               }),
-       }
-
-       return server
-}
-
-func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) {
-       var err error
-       sp.metadataOnce.Do(func() {
-               var idpMetadataUrl *url.URL
-               idpMetadataUrl, err = url.Parse(sp.IDP)
-               if err != nil {
-                       return
-               }
-
-               sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl)
-       })
-
-       if err != nil {
-               sp.metadataOnce = sync.Once{}
-               return nil, err
-       }
-
-       return sp.metadata, nil
-}
-
-func (sp *SAMLServiceProvider) CertAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) {
-       sp.certKeyOnce.Do(func() {
-               if sp.EntityPrivateKey != "" {
-                       var loadedKey crypto.PrivateKey
-                       loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey)
-                       if err != nil {
-                               return
-                       }
-                       var ok bool
-                       pvk, ok = loadedKey.(*rsa.PrivateKey)
-                       if !ok {
-                               err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk)
-                               return
-                       }
-               } else {
-                       // generate new RSA private key
-                       pvk, err = rsa.GenerateKey(saml.RandReader, 2048)
-                       if err != nil {
-                               return
-                       }
-               }
-               if sp.EntityCertificate != "" {
-                       certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate)
-                       if err != nil {
-                               return
-                       }
-                       cert = certs[0]
-               } else {
-                       // generate new self-signed X509 certificate
-                       serialBytes := make([]byte, 16)
-                       saml.RandReader.Read(serialBytes)
-                       serial := big.NewInt(0)
-                       serial.SetBytes(serialBytes)
-
-                       template := &x509.Certificate{
-                               Subject: pkix.Name{
-                                       CommonName: sp.EntityID,
-                               },
-                               SerialNumber:          serial,
-                               BasicConstraintsValid: true,
-                               ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
-                               KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
-                               IsCA:                  false,
-                               NotBefore:             time.Now(),
-                               NotAfter:              time.Now().Add(90 * 86400 * time.Second),
-                       }
-                       certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk)
-                       if err != nil {
-                               return
-                       }
-                       cert, err = x509.ParseCertificate(certBytes)
-                       if err != nil {
-                               return
-                       }
-               }
-
-               sp.entityCert = cert
-               sp.entityKey = pvk
-       })
-
-       if err != nil {
-               sp.metadataOnce = sync.Once{}
-               return nil, nil, err
-       }
-
-       return sp.entityCert, sp.entityKey, nil
-}
-
-func (sp *SAMLServiceProvider) NewServiceProvider(host string) (*samlsp.Middleware, error) {
-       idpMetadata, err := sp.Metadata()
-       if err != nil {
-               return nil, err
-       }
-       cert, key, err := sp.CertAndKey()
-       if err != nil {
-               return nil, err
-       }
-       return samlsp.New(samlsp.Options{
-               EntityID: sp.EntityID,
-               URL: url.URL{
-                       Scheme: "https",
-                       Host:   host,
-               },
-               Key:         key,
-               Certificate: cert,
-               IDPMetadata: idpMetadata,
-       })
-}
-
-func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) {
-       samlSp := make(map[string]*samlsp.Middleware, 0)
-       spMu := &sync.Mutex{}
-
-       handle := func(w http.ResponseWriter, r *http.Request) {
-               // ensure host header present
-               host := r.Header.Get("Host")
-               if host == "" {
-                       host = r.Header.Get(":authority")
-               }
-               if host == "" {
-                       host = r.Host
-               }
-               if host == "" {
-                       r.Header.Write(os.Stderr)
-                       fmt.Fprintf(os.Stderr, "%v\n", r.URL.String())
-                       sp.writeError(w, http.StatusBadRequest, errors.New("missing Host header"))
-                       return
-               }
-
-               // ensure client isn't trying to inject saml-related headers
-               if err := sp.checkRequest(r); err != nil {
-                       sp.writeError(w, http.StatusBadRequest, err)
-                       return
-               }
-
-               // make sure this host is known
-               vhost, ok := sp.Listener.VirtualHosts[host]
-               if !ok {
-                       sp.writeError(w, http.StatusMisdirectedRequest,
-                               errors.New("Misdirected request: unknown virtual host"))
-
-                       return
-               }
-
-               // ensure we have SP instance
-               spMu.Lock()
-               if _, ok := samlSp[host]; !ok {
-                       samlSettings := vhost.SAMLServiceProvider
-                       if samlSettings == nil {
-                               samlSettings = sp.Listener.SAMLServiceProvider
-                       }
-                       idpMetadata, err := samlSettings.Metadata()
-                       if err != nil {
-                               sp.writeError(w, http.StatusInternalServerError, err)
-                               return
-                       }
-                       // make sure the browser isn't trying to access the IdP - this can happen if the TLS session
-                       // was reused because our certificate is also valid for the SSO URL.
-                       for _, ssoDesc := range idpMetadata.IDPSSODescriptors {
-                               for _, ssoSvc := range ssoDesc.SingleSignOnServices {
-                                       if loginUrl, err := url.Parse(ssoSvc.Location); err == nil {
-                                               if loginUrl.Host == host {
-                                                       sp.writeError(w, http.StatusMisdirectedRequest,
-                                                               errors.New("Misdirected request: this is not the IDP you're looking for"))
-
-                                                       return
-                                               }
-                                       }
-                               }
-                       }
-                       middleware, err := samlSettings.NewServiceProvider(host)
-                       if err != nil {
-                               sp.writeError(w, http.StatusInternalServerError, err)
-                               return
-                       }
-                       samlSp[host] = middleware
-               }
-               spMu.Unlock()
-
-               provider := samlSp[host]
-               if r.URL.Path == "/saml/acs" {
-                       provider.ServeACS(w, r)
-                       return
-               }
-
-               session, sessionErr := provider.Session.GetSession(r)
-
-               if sessionErr != nil && sessionErr != samlsp.ErrNoSession {
-                       sp.logger.V(2).Warningf("non-NoSession err from sp: %v", sessionErr)
-                       sp.writeError(w, http.StatusBadRequest, sessionErr)
-                       return
-               }
-
-               defaultRoute := true
-               next := sp.fulfill(vhost, session)
-
-               sp.logger.V(3).Debugf("checking for routes matching %s", r.URL)
-               for _, route := range vhost.Routes {
-                       match := false
-                       if route.Path != nil {
-                               match = route.Path.Match(r.URL.Path)
-                               sp.logger.V(3).Debugf("path %s matches %s: %t",
-                                       r.URL.Path, route.Path.String(), match)
-                       } else {
-                               sp.writeError(w, http.StatusInternalServerError,
-                                       errors.New("nothing to match on in route"))
-                       }
-
-                       if match {
-                               defaultRoute = false
-                               if sessionErr == samlsp.ErrNoSession && route.Auth == AuthRequired {
-                                       sp.logger.V(3).Debugf("route requires a valid session, redirecting")
-
-                                       provider.HandleStartAuthFlow(w, r)
-                                       return
-                               }
-
-                               if route.Action != nil {
-                                       sp.logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action)
-                                       route.Action.Handle(w, r, next)
-                               }
-                       }
-               }
-
-               if defaultRoute {
-                       sp.logger.V(3).Debugf("using default route")
-                       if sessionErr == samlsp.ErrNoSession {
-                               sp.logger.V(3).Debugf("default route requires a valid session, redirecting")
-                               provider.HandleStartAuthFlow(w, r)
-                               return
-                       }
-               }
-
-               next(w, r)
-       }
-
-       return handle, nil
-}
-
-func (sp *SAMLProxy) fulfill(vhost *SAMLVirtualHost, session samlsp.Session) http.HandlerFunc {
-       return func(w http.ResponseWriter, r *http.Request) {
-               if session != nil {
-                       sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session)
-               } else {
-                       sp.logger.V(3).Debugf("serving path %s without session", r.URL.Path)
-               }
-
-               newReq := r.Clone(r.Context())
-               newReq.URL.Scheme = "http"
-               if vhost.Backend.Identity != "" {
-                       newReq.URL.Scheme = "https"
-               }
-               newReq.URL.Host = net.JoinHostPort(vhost.Backend.Host, strconv.Itoa(vhost.Backend.Port))
-               newReq.RequestURI = ""
-
-               if swa, ok := session.(samlsp.SessionWithAttributes); ok {
-                       attrs := swa.GetAttributes()
-                       oboHeader := vhost.Backend.UsernameHeader
-                       if oboHeader == "" {
-                               oboHeader = "on-behalf-of"
-                       }
-                       sp.logger.V(3).Debugf("setting origin request header: %s: %q", oboHeader, attrs.Get("uid"))
-                       newReq.Header.Set(oboHeader, attrs.Get("uid"))
-               }
-
-               if jwts, ok := session.(samlsp.JWTSessionClaims); ok {
-                       newReq.Header.Set("x-saml-audience", jwts.StandardClaims.Audience)
-                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience)
-
-                       iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10)
-                       newReq.Header.Set("x-saml-issued-at", iat)
-                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat)
-
-                       eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10)
-                       newReq.Header.Set("x-saml-expires-at", eat)
-                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat)
-
-                       newReq.Header.Set("x-saml-subject", jwts.StandardClaims.Subject)
-                       sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject)
-
-                       for attr, values := range jwts.Attributes {
-                               headerName := fmt.Sprintf("x-saml-%s",
-                                       samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-"))
-                               headerValue := strings.Join(values, ", ")
-                               sp.logger.V(3).Debugf("setting origin request header: %s: %s",
-                                       headerName, headerValue)
-                               newReq.Header.Set(headerName, headerValue)
-                       }
-               }
-
-               // set proxy headers
-               if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil {
-                       sp.logger.V(3).Debugf("x-forwarded-for: %s", remoteHost)
-                       newReq.Header.Set("x-forwarded-for", remoteHost)
-               }
-
-               // proxy the request to the backend
-               client, err := vhost.Backend.Client()
-               if err != nil {
-                       sp.writeError(w, http.StatusInternalServerError, fmt.Errorf("error setting up connection to backend: %v", err))
-               }
-               response, err := client.Do(newReq)
-               if err != nil {
-                       sp.writeError(w, http.StatusBadGateway, err)
-                       return
-               }
-
-               for name, value := range response.Header {
-                       w.Header().Set(name, strings.Join(value, ", "))
-               }
-
-               if response.StatusCode == http.StatusSwitchingProtocols {
-                       hijacker, ok := w.(http.Hijacker)
-                       if !ok {
-                               sp.writeError(w, http.StatusMethodNotAllowed, errors.New("websocket passthrough not supported"))
-                               return
-                       }
-
-                       upstreamWriter, ok := response.Body.(io.Writer)
-                       if !ok {
-                               sp.writeError(w, http.StatusMethodNotAllowed, errors.New("body doesn't support io.Writer"))
-                               return
-                       }
-
-                       w.WriteHeader(response.StatusCode)
-
-                       conn, rw, err := hijacker.Hijack()
-                       if err != nil {
-                               sp.writeError(w, http.StatusInternalServerError, err)
-                               return
-                       }
-
-                       wg := sync.WaitGroup{}
-                       wg.Add(2)
-                       pipe := func(w io.Writer, r io.Reader) {
-                               defer wg.Done()
-                               io.Copy(w, r)
-                       }
-                       go pipe(rw, response.Body)
-                       go pipe(upstreamWriter, rw)
-
-                       wg.Wait()
-                       conn.Close()
-                       return
-               }
-
-               w.WriteHeader(response.StatusCode)
-               io.Copy(w, response.Body)
-       }
-}
-
-func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) {
-       sp.logger.V(1).Warningf("returning status: %d %s", status, err.Error())
-
-       w.WriteHeader(status)
-       w.Write([]byte(fmt.Sprintf("<h1>%d %s</h1>", status, err.Error())))
-}
-
-func (sp *SAMLProxy) checkRequest(r *http.Request) error {
-       for k, _ := range r.Header {
-               k = strings.ToLower(k)
-               if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) {
-                       return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k)
-               }
-       }
-       return nil
-}
diff --git a/http/samlproxy/main.go b/http/samlproxy/main.go
deleted file mode 100644 (file)
index 50ff0ab..0000000
+++ /dev/null
@@ -1,86 +0,0 @@
-package main
-
-import (
-       "context"
-       "flag"
-       "os"
-       "os/signal"
-       "syscall"
-       "time"
-
-       "github.com/coreos/go-systemd/daemon"
-       "gopkg.in/yaml.v3"
-
-       "go.fuhry.dev/runtime/http"
-       "go.fuhry.dev/runtime/mtls"
-       "go.fuhry.dev/runtime/utils/log"
-)
-
-func main() {
-       mtls.SetDefaultIdentity("authproxy")
-
-       sp := &http.SAMLProxy{
-               Listener: http.SAMLListener{
-                       SAMLServiceProvider: &http.SAMLServiceProvider{},
-               },
-       }
-       vhost := &http.SAMLVirtualHost{
-               Backend: &http.SAMLBackend{},
-       }
-
-       loadConfig := func(arg string) error {
-               contents, err := os.ReadFile(arg)
-               if err != nil {
-                       return err
-               }
-
-               err = yaml.Unmarshal(contents, sp)
-               return err
-       }
-       addRoute := func(arg string) error {
-               route, err := http.RouteFromArg(arg)
-               if err != nil {
-                       return err
-               }
-               vhost.Routes = append(vhost.Routes, route)
-               return nil
-       }
-
-       vhostName := flag.String("vhost", "", "HTTP(S) hostname to serve")
-       flag.Func("config", "YAML file to load configuration from", loadConfig)
-       flag.Func("route", "Route rule in the format of auth:field:matcher:value\n"+
-               "  auth: required, optional\n"+
-               "  field: path\n"+
-               "  matcher: prefix, suffix, exact, contains, regexp\n"+
-               "  value: any string", addRoute)
-       flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider")
-       flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata")
-       flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private")
-       flag.StringVar(&vhost.Backend.Host, "backend.host", "127.0.0.1", "backend host")
-       flag.IntVar(&vhost.Backend.Port, "backend.port", 0, "backend port")
-       flag.StringVar(&vhost.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend")
-       flag.StringVar(&sp.Listener.Addr, "listen", "[::]:8443", "address for auth proxy to listen on")
-       flag.StringVar(&sp.Listener.InsecureAddr, "listen.http", "[::]:8080", "address for http-to-https redirector")
-
-       flag.Parse()
-
-       sp.Listener.VirtualHosts[*vhostName] = vhost
-
-       ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
-       defer cancel()
-       server, err := sp.NewHTTPServerWithContext(ctx)
-       if err != nil {
-               log.Panic(err)
-       }
-       go server.ListenAndServeTLS("", "")
-
-       unsecureServer := sp.NewHTTPSRedirectorWithContext(ctx)
-       go unsecureServer.ListenAndServe()
-
-       daemon.SdNotify(false, daemon.SdNotifyReady)
-
-       <-ctx.Done()
-       shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
-       defer shutdownCancel()
-       server.Shutdown(shutdownCtx)
-}
diff --git a/http/server.go b/http/server.go
new file mode 100644 (file)
index 0000000..9f88be1
--- /dev/null
@@ -0,0 +1,282 @@
+package http
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "net"
+       "net/http"
+       "os"
+
+       "go.fuhry.dev/runtime/mtls"
+       "go.fuhry.dev/runtime/utils/log"
+       "go.fuhry.dev/runtime/utils/stringmatch"
+       "gopkg.in/yaml.v3"
+)
+
+type serverCtxVar int
+
+type RouteAction interface {
+       Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc)
+}
+
+type Route struct {
+       Path   stringmatch.StringMatcher
+       Action RouteAction
+}
+
+type VirtualHost struct {
+       Routes []*Route `yaml:"routes"`
+}
+
+type Listener struct {
+       Addr         string                  `yaml:"listen"`
+       InsecureAddr string                  `yaml:"listen_insecure"`
+       Certificate  string                  `yaml:"cert"`
+       VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"`
+}
+
+type Server struct {
+       Listener *Listener       `yaml:"listener"`
+       Context  context.Context `yaml:"-"`
+}
+
+type initHook func(context.Context, *yaml.Node) (context.Context, error)
+type routeParseFunc func(*yaml.Node) (RouteAction, error)
+
+const (
+       kLogger serverCtxVar = iota
+       kListener
+       kListenAddr
+       kSamlDefaults
+)
+
+var initHooks []initHook
+var routeParseFuncs []routeParseFunc
+
+func AddServerInitHook(hook initHook) {
+       initHooks = append(initHooks, hook)
+}
+
+func AddRouteParseFunc(rpf routeParseFunc) {
+       routeParseFuncs = append(routeParseFuncs, rpf)
+}
+
+func NewServer() *Server {
+       return NewServerWithContext(context.Background())
+}
+
+func NewServerWithContext(ctx context.Context) *Server {
+       logger := log.WithPrefix(fmt.Sprintf("%T", &Server{}))
+
+       return &Server{
+               Listener: &Listener{
+                       VirtualHosts: make(map[string]*VirtualHost, 0),
+               },
+               Context: context.WithValue(ctx, kLogger, logger),
+       }
+}
+
+// UnmarshalYAML implements yaml.Unmarshaler
+func (r *Route) UnmarshalYAML(node *yaml.Node) error {
+       var rawNode struct {
+               Path *stringmatch.MatchRule `yaml:"path"`
+       }
+
+       if err := node.Decode(&rawNode); err != nil {
+               return err
+       }
+
+       if rawNode.Path != nil {
+               m, err := rawNode.Path.Matcher()
+               if err != nil {
+                       return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err)
+               }
+               r.Path = m
+       } else {
+               return errors.New("error unmarshaling route: path must be specified")
+       }
+
+       for _, rpf := range routeParseFuncs {
+               action, err := rpf(node)
+               if err != nil {
+                       return err
+               }
+               if action != nil {
+                       r.Action = action
+                       break
+               }
+       }
+
+       return nil
+}
+
+// UnmarshalYAML implements yaml.Unmarshaler
+func (s *Server) UnmarshalYAML(node *yaml.Node) error {
+       lc := &struct {
+               Listener *Listener `yaml:"listener"`
+       }{}
+
+       if s.Context == nil {
+               s.Context = context.Background()
+       }
+
+       if err := node.Decode(&lc); err != nil {
+               return err
+       }
+
+       s.Listener = lc.Listener
+
+       for _, initHook := range initHooks {
+               newCtx, err := initHook(s.Context, node)
+               if err != nil {
+                       return err
+               }
+               s.Context = newCtx
+       }
+
+       return nil
+}
+
+func (s *Server) Create() (*http.Server, error) {
+       listenerCtx := context.WithValue(s.Context, kListener, s.Listener)
+       return s.Listener.NewHTTPServerWithContext(listenerCtx)
+}
+
+func (s *Server) CreateInsecure() *http.Server {
+       listenerCtx := context.WithValue(s.Context, kListener, s.Listener)
+       return s.Listener.NewHTTPSRedirectorWithContext(listenerCtx)
+}
+
+// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host
+// and other settings.
+func (l *Listener) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) {
+       if l.Addr == "" {
+               l.Addr = "[::]:8443"
+       }
+
+       logger := LoggerFromContext(ctx).WithPrefix(fmt.Sprintf("%T(%s)", l, l.Addr))
+       serverCtx := context.WithValue(ctx, kLogger, logger)
+
+       lm := log.NewLoggingMiddlewareWithLogger(
+               http.HandlerFunc(l.handle),
+               logger.AppendPrefix("access"))
+
+       server := &http.Server{
+               Addr: l.Addr,
+               BaseContext: func(l net.Listener) context.Context {
+                       return context.WithValue(serverCtx, kListenAddr, l.Addr())
+               },
+               Handler: lm.HandlerFunc(),
+       }
+
+       if l.Certificate != "" {
+               cert := mtls.NewSSLCertificate(l.Certificate)
+               tlsConfig, err := cert.TlsConfig(serverCtx)
+               if err != nil {
+                       return nil, err
+               }
+               server.TLSConfig = tlsConfig
+       }
+
+       return server, nil
+}
+
+func (l *Listener) NewHTTPSRedirectorWithContext(ctx context.Context) *http.Server {
+       if l.InsecureAddr == "" {
+               l.InsecureAddr = "[::]:8080"
+       }
+
+       server := &http.Server{
+               Addr: l.InsecureAddr,
+               Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+                       host := r.Host
+                       if host == "" {
+                               w.WriteHeader(http.StatusBadRequest)
+                               return
+                       }
+
+                       if _, ok := l.VirtualHosts[host]; !ok {
+                               w.WriteHeader(http.StatusMisdirectedRequest)
+                               return
+                       }
+
+                       newUrl := *r.URL
+                       newUrl.Scheme = "https"
+                       newUrl.Host = host
+                       w.Header().Set("location", newUrl.String())
+                       w.WriteHeader(http.StatusFound)
+               }),
+       }
+
+       return server
+}
+
+func (l *Listener) handle(w http.ResponseWriter, r *http.Request) {
+       // ensure host header present
+       if r.Host == "" {
+               r.Header.Write(os.Stderr)
+               http.Error(w, "missing Host header", http.StatusBadRequest)
+               return
+       }
+
+       // make sure this host is known
+       vhost, ok := l.VirtualHosts[r.Host]
+       if !ok {
+               http.Error(w, "Misdirected request: unknown virtual host",
+                       http.StatusMisdirectedRequest)
+
+               return
+       }
+
+       l.fulfill(w, r, vhost.Routes)
+}
+
+func (l *Listener) fulfill(w http.ResponseWriter, r *http.Request, routes []*Route) {
+       logger := LoggerFromContext(r.Context())
+       if logger == nil {
+               http.Error(w, "cannot get logger", http.StatusInternalServerError)
+       }
+
+       logger.V(3).Debugf("checking for routes matching %s", r.URL)
+       for i, route := range routes {
+               match := false
+               if route.Path != nil {
+                       match = route.Path.Match(r.URL.Path)
+                       logger.V(3).Debugf("path %s matches %s: %t",
+                               r.URL.Path, route.Path.String(), match)
+               } else {
+                       http.Error(w, "nothing to match on in route", http.StatusInternalServerError)
+               }
+
+               if match {
+                       if route.Action != nil {
+                               logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action)
+                               next := http.NotFound
+                               if len(routes) > i {
+                                       next = func(w http.ResponseWriter, r *http.Request) {
+                                               logger.V(3).Debugf("%T called next(), continuing request processing", route.Action)
+                                               l.fulfill(w, r, routes[i+1:])
+                                       }
+                               }
+                               route.Action.Handle(w, r, next)
+                               return
+                       } else {
+                               http.Error(w,
+                                       fmt.Sprintf("no action configured for route %s", route.Path.String()),
+                                       http.StatusInternalServerError)
+                       }
+               }
+       }
+
+       http.NotFound(w, r)
+}
+
+func LoggerFromContext(ctx context.Context) log.Logger {
+       l := ctx.Value(kLogger)
+       if logger, ok := l.(log.Logger); ok {
+               return logger
+       }
+
+       return nil
+}