From 8e994b9208488178e2d9bb0d53d815452c3064b8 Mon Sep 17 00:00:00 2001 From: Dan Fuhry Date: Fri, 28 Mar 2025 23:38:07 -0400 Subject: [PATCH] Fine let's just make it a full HTTP proxy Refactor samlproxy into a general purpose proxy with pluggable actions. Add S3 bucket serving backend. Route actions can fulfill the request or modify it and call next(), basically the same idea as coredns but for http. Backwards incompatible with existing configs. --- .gitignore | 2 +- go.mod | 28 +- go.sum | 34 + http/{samlproxy => proxy}/Makefile | 0 http/proxy/main.go | 64 ++ .../systemd/http-proxy@.service} | 0 http/route_action_proxy.go | 153 ++++ http/route_action_redirect.go | 66 ++ http/route_action_s3.go | 132 ++++ http/route_action_saml.go | 322 ++++++++ http/samlproxy.go | 694 ------------------ http/samlproxy/main.go | 86 --- http/server.go | 282 +++++++ 13 files changed, 1074 insertions(+), 789 deletions(-) rename http/{samlproxy => proxy}/Makefile (100%) create mode 100644 http/proxy/main.go rename http/{samlproxy/systemd/saml-proxy@.service => proxy/systemd/http-proxy@.service} (100%) create mode 100644 http/route_action_proxy.go create mode 100644 http/route_action_redirect.go create mode 100644 http/route_action_s3.go create mode 100644 http/route_action_saml.go delete mode 100644 http/samlproxy.go delete mode 100644 http/samlproxy/main.go create mode 100644 http/server.go diff --git a/.gitignore b/.gitignore index 50a6a0b..43c001f 100644 --- a/.gitignore +++ b/.gitignore @@ -49,7 +49,7 @@ mtls/verify_tool/verify_tool ldap/health_exporter/health_exporter envoy/xds/envoy_xds/envoy_xds mtls/mtls_exporter/mtls_exporter -http/samlproxy/samlproxy +http/proxy/proxy automation/bryston_ctl/cli/cli automation/bryston_ctl/client/client automation/bryston_ctl/server/server diff --git a/go.mod b/go.mod index 8abdf6b..8b79995 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,16 @@ module go.fuhry.dev/runtime -go 1.23 +go 1.23.0 + +toolchain go1.24.1 require ( github.com/google/certificate-transparency-go v1.1.4 github.com/google/go-attestation v0.4.3 github.com/google/go-tpm v0.3.3 // indirect github.com/google/go-tspi v0.2.1-0.20190423175329-115dea689aad // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/sys v0.22.0 + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/sys v0.31.0 ) require ( @@ -39,8 +41,8 @@ require ( go.fuhry.dev/fsnotify v1.7.2 go.fuhry.dev/grpc-quic v0.1.2 golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d - golang.org/x/sync v0.7.0 - golang.org/x/term v0.22.0 + golang.org/x/sync v0.12.0 + golang.org/x/term v0.30.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c gopkg.in/ini.v1 v1.67.0 gopkg.in/yaml.v3 v3.0.1 @@ -66,11 +68,14 @@ require ( github.com/felixge/httpsnoop v1.0.1 // indirect github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect + github.com/go-ini/ini v1.67.0 // indirect github.com/go-logfmt/logfmt v0.5.1 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/goccy/go-json v0.10.5 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/gomodule/redigo v1.8.2 // indirect github.com/google/pprof v0.0.0-20230509042627-b1315fad0c5a // indirect + github.com/google/uuid v1.6.0 // indirect github.com/goph/emperror v0.17.2 // indirect github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.0 // indirect @@ -78,11 +83,17 @@ require ( github.com/jonboulle/clockwork v0.3.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.16.5 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect github.com/miekg/pkcs11 v1.1.1 // indirect + github.com/minio/crc64nvme v1.0.1 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/minio/minio-go v6.0.14+incompatible // indirect + github.com/minio/minio-go/v7 v7.0.89 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect @@ -98,6 +109,7 @@ require ( github.com/prometheus/procfs v0.12.0 // indirect github.com/quic-go/qtls-go1-20 v0.3.4 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/rs/xid v1.6.0 // indirect github.com/russellhaering/goxmldsig v1.3.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/thales-e-security/pool v0.0.2 // indirect @@ -128,8 +140,8 @@ require ( go.uber.org/atomic v1.11.0 go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.21.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/text v0.16.0 + golang.org/x/net v0.37.0 // indirect + golang.org/x/text v0.23.0 google.golang.org/genproto v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.34.1 diff --git a/go.sum b/go.sum index bedfe3b..b7cfe35 100644 --- a/go.sum +++ b/go.sum @@ -171,6 +171,8 @@ github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclK github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= +github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= @@ -188,6 +190,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -279,6 +283,8 @@ github.com/google/trillian v1.3.11/go.mod h1:0tPraVHrSDkA3BO6vKX67zgLXs6SsOAbHEi github.com/google/uuid v0.0.0-20161128191214-064e2069ce9c/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/goph/emperror v0.17.2 h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= @@ -362,6 +368,11 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.16.5 h1:IFV2oUNUzZaz+XyusxpLzpzS8Pt5rh0Z16For/djlyI= github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.10 h1:tBs3QSyvjDyFTq3uoc/9xFpCuOsJQFNPiAhYdw2skhE= +github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -408,9 +419,18 @@ github.com/miekg/pkcs11 v1.0.3-0.20190429190417-a667d056470f/go.mod h1:XsNlhZGX7 github.com/miekg/pkcs11 v1.0.3/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/miekg/pkcs11 v1.1.1 h1:Ugu9pdy6vAYku5DEpVWVFPYnzV+bxB+iRdbuFSu7TvU= github.com/miekg/pkcs11 v1.1.1/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= +github.com/minio/crc64nvme v1.0.1 h1:DHQPrYPdqK7jQG/Ls5CTBZWeex/2FMS3G5XGkycuFrY= +github.com/minio/crc64nvme v1.0.1/go.mod h1:eVfm2fAzLlxMdUGc0EEBGSMmPwmXD5XiNRpnu9J3bvg= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go v6.0.14+incompatible h1:fnV+GD28LeqdN6vT2XdGKW8Qe/IfjJDswNVuni6km9o= +github.com/minio/minio-go v6.0.14+incompatible/go.mod h1:7guKYtitv8dktvNUGrhzmNlA5wrAABTQXCoesZdFQO8= +github.com/minio/minio-go/v7 v7.0.89 h1:hx4xV5wwTUfyv8LarhJAwNecnXpoTsj9v3f3q/ZkiJU= +github.com/minio/minio-go/v7 v7.0.89/go.mod h1:2rFnGAp02p7Dddo1Fq4S2wYOfpF0MUTSeLTRC90I204= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc= github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= @@ -518,6 +538,8 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= @@ -651,6 +673,8 @@ golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -727,6 +751,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -747,6 +773,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -799,9 +827,13 @@ golang.org/x/sys v0.0.0-20210629170331-7dc0b73dc9fb/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -810,6 +842,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/http/samlproxy/Makefile b/http/proxy/Makefile similarity index 100% rename from http/samlproxy/Makefile rename to http/proxy/Makefile diff --git a/http/proxy/main.go b/http/proxy/main.go new file mode 100644 index 0000000..c9c727b --- /dev/null +++ b/http/proxy/main.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + "flag" + "os" + "os/signal" + "syscall" + "time" + + "github.com/coreos/go-systemd/daemon" + "gopkg.in/yaml.v3" + + "go.fuhry.dev/runtime/http" + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/utils/log" +) + +func main() { + mtls.SetDefaultIdentity("authproxy") + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + server := http.NewServerWithContext(ctx) + + loadConfig := func(arg string) error { + contents, err := os.ReadFile(arg) + if err != nil { + return err + } + + err = yaml.Unmarshal(contents, server) + return err + } + + flag.Func("config", "YAML file to load configuration from", loadConfig) + flag.StringVar(&server.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private") + flag.StringVar(&server.Listener.Addr, "listen", "[::]:8443", "address for auth proxy to listen on") + flag.StringVar(&server.Listener.InsecureAddr, "listen.http", "[::]:8080", "address for http-to-https redirector") + + flag.Parse() + + httpServer, err := server.Create() + if err != nil { + log.Panic(err) + } + go httpServer.ListenAndServeTLS("", "") + + log.Default().Infof("listening on HTTPS at %s", server.Listener.Addr) + + unsecureServer := server.CreateInsecure() + go unsecureServer.ListenAndServe() + + log.Default().Infof("listening on HTTP at %s (redirects to HTTPS only)", server.Listener.InsecureAddr) + + daemon.SdNotify(false, daemon.SdNotifyReady) + + <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + httpServer.Shutdown(shutdownCtx) + unsecureServer.Shutdown(shutdownCtx) +} diff --git a/http/samlproxy/systemd/saml-proxy@.service b/http/proxy/systemd/http-proxy@.service similarity index 100% rename from http/samlproxy/systemd/saml-proxy@.service rename to http/proxy/systemd/http-proxy@.service diff --git a/http/route_action_proxy.go b/http/route_action_proxy.go new file mode 100644 index 0000000..83dccae --- /dev/null +++ b/http/route_action_proxy.go @@ -0,0 +1,153 @@ +package http + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + + "gopkg.in/yaml.v3" + + "go.fuhry.dev/runtime/mtls" +) + +type staticUpstreamAction struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Identity string `yaml:"mtls_id"` + + client *http.Client + clientOnce sync.Once +} + +// httpClient returns an HTTP client for making requests to the backend. +func (su *staticUpstreamAction) httpClient() (*http.Client, error) { + var err error + su.clientOnce.Do(func() { + transport := &http.Transport{} + var tlsConfig *tls.Config + + if su.Identity != "" { + myIdentity := mtls.DefaultIdentity() + tlsConfig, err = myIdentity.TlsConfig(context.Background()) + if err != nil { + return + } + + verifier := mtls.NewPeerNameVerifier() + verifier.AllowFrom(mtls.Service, su.Identity) + err = verifier.ConfigureClient(tlsConfig) + if err != nil { + return + } + + transport.TLSClientConfig = tlsConfig + } + + client := &http.Client{ + Transport: transport, + } + + su.client = client + }) + if err != nil { + return nil, err + } + return su.client, nil +} + +// Handle implements RouteAction. +func (su *staticUpstreamAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + logger := LoggerFromContext(r.Context()) + + upstreamReq := r.Clone(r.Context()) + upstreamReq.URL.Scheme = "http" + if su.Identity != "" { + upstreamReq.URL.Scheme = "https" + } + upstreamReq.URL.Host = net.JoinHostPort(su.Host, strconv.Itoa(su.Port)) + upstreamReq.RequestURI = "" + + // set proxy headers + if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + logger.V(3).Debugf("x-forwarded-for: %s", remoteHost) + upstreamReq.Header.Set("x-forwarded-for", remoteHost) + } + + // proxy the request to the backend + client, err := su.httpClient() + if err != nil { + http.Error(w, + fmt.Sprintf("error setting up connection to backend: %v", err), + http.StatusInternalServerError) + } + response, err := client.Do(upstreamReq) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + + for name, value := range response.Header { + w.Header().Set(name, strings.Join(value, ", ")) + } + + if response.StatusCode == http.StatusSwitchingProtocols { + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "websocket passthrough not supported", http.StatusMethodNotAllowed) + return + } + + upstreamWriter, ok := response.Body.(io.Writer) + if !ok { + http.Error(w, "body doesn't support io.Writer", http.StatusMethodNotAllowed) + return + } + + w.WriteHeader(response.StatusCode) + + conn, rw, err := hijacker.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + wg := sync.WaitGroup{} + wg.Add(2) + pipe := func(w io.Writer, r io.Reader) { + defer wg.Done() + io.Copy(w, r) + } + go pipe(rw, response.Body) + go pipe(upstreamWriter, rw) + + wg.Wait() + conn.Close() + return + } + + w.WriteHeader(response.StatusCode) + io.Copy(w, response.Body) +} + +func staticUpstreamActionFromYaml(node *yaml.Node) (RouteAction, error) { + var rawNode struct { + Proxy *staticUpstreamAction `yaml:"proxy"` + } + + err := node.Decode(&rawNode) + if err != nil || rawNode.Proxy == nil { + return nil, nil + } + + return rawNode.Proxy, nil +} + +func init() { + AddRouteParseFunc(staticUpstreamActionFromYaml) +} diff --git a/http/route_action_redirect.go b/http/route_action_redirect.go new file mode 100644 index 0000000..2197231 --- /dev/null +++ b/http/route_action_redirect.go @@ -0,0 +1,66 @@ +package http + +import ( + "net/http" + "net/url" + + "gopkg.in/yaml.v3" +) + +type RedirectAction struct { + StatusCode int + Destination *url.URL +} + +// Handle implements RouteAction +func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + newUrl := *r.URL + + if a.Destination.Host != "" { + newUrl.Host = a.Destination.Host + } + newUrl.Path = a.Destination.Path + if a.Destination.RawQuery != "" { + newUrl.RawQuery = a.Destination.RawQuery + } + if a.Destination.Scheme != "" { + newUrl.Scheme = a.Destination.Scheme + } + if a.Destination.Fragment != "" { + newUrl.Fragment = a.Destination.Fragment + } + + status := a.StatusCode + if status == 0 { + status = http.StatusFound + } + + w.Header().Set("location", newUrl.String()) + w.WriteHeader(status) +} + +func redirectFromRouteYaml(node *yaml.Node) (RouteAction, error) { + var rawNode struct { + Redirect *struct { + Destination string `yaml:"dest"` + Status int `yaml:"status"` + } `yaml:"redirect,omitempty"` + } + + if err := node.Decode(&rawNode); err == nil && rawNode.Redirect != nil { + u, err := url.Parse(rawNode.Redirect.Destination) + if err != nil { + return nil, err + } + return &RedirectAction{ + Destination: u, + StatusCode: rawNode.Redirect.Status, + }, nil + } + + return nil, nil +} + +func init() { + AddRouteParseFunc(redirectFromRouteYaml) +} diff --git a/http/route_action_s3.go b/http/route_action_s3.go new file mode 100644 index 0000000..49bf85b --- /dev/null +++ b/http/route_action_s3.go @@ -0,0 +1,132 @@ +package http + +import ( + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "gopkg.in/yaml.v3" +) + +type S3Action struct { + S3Endpoint string `yaml:"endpoint"` + S3AccessKey string `yaml:"access_key"` + S3SecretKey string `yaml:"secret_key"` + BucketName string `yaml:"bucket"` + ObjectPrefix string `yaml:"prefix"` + StripPrefix string `yaml:"strip_prefix"` + + mc *minio.Client + mcOnce sync.Once +} + +// Handle implements RouteAction +func (a *S3Action) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + mc, err := a.minioClient() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf("failed to init minio client: %+v", err))) + return + } + + reqPath := strings.TrimPrefix(r.URL.Path, a.StripPrefix) + objPath := "/" + strings.Trim(strings.Join( + []string{ + strings.Trim(a.ObjectPrefix, "/"), + strings.Trim(reqPath, "/"), + }, + "/"), "/") + + object, err := mc.GetObject( + r.Context(), + a.BucketName, + objPath, + minio.GetObjectOptions{}) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf("failed to GetObject %q: %+v", reqPath, err))) + return + } + + stat, err := object.Stat() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf("failed to stat object %q: %+v", objPath, err))) + return + } + + w.Header().Set("content-type", stat.ContentType) + w.WriteHeader(http.StatusOK) + + if strings.HasPrefix(stat.ContentType, "text/") || strings.HasPrefix(stat.ContentType, "application/") { + buf := make([]byte, 32*1024) + var out []byte + for { + nr, err := object.Read(buf) + if nr == 0 && err != nil { + break + } + end := nr + for i := nr; i > 0; i-- { + if buf[i-1] != '\000' { + end = i + break + } + } + out = append(out, buf[:end]...) + + if err == io.EOF { + break + } + } + w.Header().Set("content-length", strconv.Itoa(len(out))) + for nw := 0; nw < len(out); { + c, err := w.Write(out[nw:]) + if err != nil { + break + } + nw += c + } + } else { + io.Copy(w, object) + } +} + +func (a *S3Action) minioClient() (mc *minio.Client, err error) { + a.mcOnce.Do(func() { + mc, err = minio.New( + a.S3Endpoint, + &minio.Options{ + Creds: credentials.NewStaticV4(a.S3AccessKey, a.S3SecretKey, ""), + Secure: true, + BucketLookup: minio.BucketLookupDNS, + }, + ) + + a.mc = mc + }) + + return a.mc, err +} + +func s3ActionFromRouteYaml(node *yaml.Node) (RouteAction, error) { + var rawNode struct { + S3 *S3Action `yaml:"s3,omitempty"` + } + + if err := node.Decode(&rawNode); err == nil && rawNode.S3 != nil { + return rawNode.S3, nil + } + + return nil, nil +} + +func init() { + AddRouteParseFunc(s3ActionFromRouteYaml) +} diff --git a/http/route_action_saml.go b/http/route_action_saml.go new file mode 100644 index 0000000..eb0d829 --- /dev/null +++ b/http/route_action_saml.go @@ -0,0 +1,322 @@ +package http + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + + "go.fuhry.dev/runtime/mtls/certutil" + "go.fuhry.dev/runtime/utils/hashset" + "gopkg.in/yaml.v3" +) + +type SAMLServiceProvider struct { + EntityID string `yaml:"entity_id"` + EntityCertificate string `yaml:"entity_certificate"` + EntityPrivateKey string `yaml:"entity_key"` + IDP string `yaml:"idp"` + + metadata *saml.EntityDescriptor + metadataOnce sync.Once + + entityCert *x509.Certificate + entityKey *rsa.PrivateKey + certKeyOnce sync.Once + spMu sync.Mutex + mw map[string]*samlsp.Middleware +} + +type samlAction struct { + sp *SAMLServiceProvider + requireAuth bool + usernameHeader string +} + +var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`) +var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"}) + +func (sa *samlAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + logger := LoggerFromContext(r.Context()) + + // ensure client isn't trying to inject saml-related headers + if err := sa.checkRequest(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + sp := sa.sp + if sp == nil { + sp = serverDefaultSamlConfig(r.Context()) + } + if sp == nil { + http.Error(w, "SAML auth requested but no SP config present", + http.StatusInternalServerError) + } + provider, err := sp.getMiddleware(r.Host) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if r.URL.Path == "/saml/acs" { + provider.ServeACS(w, r) + return + } + + session, sessionErr := provider.Session.GetSession(r) + + if sessionErr != nil && sessionErr != samlsp.ErrNoSession { + http.Error(w, sessionErr.Error(), http.StatusBadRequest) + return + } + + if sessionErr == samlsp.ErrNoSession && sa.requireAuth { + logger.V(3).Debugf("route requires a valid session, redirecting") + + provider.HandleStartAuthFlow(w, r) + return + } + + if swa, ok := session.(samlsp.SessionWithAttributes); ok { + attrs := swa.GetAttributes() + oboHeader := sa.usernameHeader + if oboHeader == "" { + oboHeader = "on-behalf-of" + } + logger.V(3).Debugf("setting origin request header: %s: %q", oboHeader, attrs.Get("uid")) + r.Header.Set(oboHeader, attrs.Get("uid")) + } + + if jwts, ok := session.(samlsp.JWTSessionClaims); ok { + r.Header.Set("x-saml-audience", jwts.StandardClaims.Audience) + logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience) + + iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10) + r.Header.Set("x-saml-issued-at", iat) + logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat) + + eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10) + r.Header.Set("x-saml-expires-at", eat) + logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat) + + r.Header.Set("x-saml-subject", jwts.StandardClaims.Subject) + logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject) + + for attr, values := range jwts.Attributes { + headerName := fmt.Sprintf("x-saml-%s", + samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-")) + headerValue := strings.Join(values, ", ") + logger.V(3).Debugf("setting origin request header: %s: %s", + headerName, headerValue) + r.Header.Set(headerName, headerValue) + } + } else { + r.Header.Set("x-saml-anonymous-auth", "1") + } + + next(w, r) +} + +func (sa *samlAction) checkRequest(r *http.Request) error { + for k, _ := range r.Header { + k = strings.ToLower(k) + if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) { + return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k) + } + } + return nil +} + +func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) { + var err error + sp.metadataOnce.Do(func() { + var idpMetadataUrl *url.URL + idpMetadataUrl, err = url.Parse(sp.IDP) + if err != nil { + return + } + + sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl) + }) + + if err != nil { + sp.metadataOnce = sync.Once{} + return nil, err + } + + return sp.metadata, nil +} + +func (sp *SAMLServiceProvider) certAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) { + sp.certKeyOnce.Do(func() { + if sp.EntityPrivateKey != "" { + var loadedKey crypto.PrivateKey + loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey) + if err != nil { + return + } + var ok bool + pvk, ok = loadedKey.(*rsa.PrivateKey) + if !ok { + err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk) + return + } + } else { + // generate new RSA private key + pvk, err = rsa.GenerateKey(saml.RandReader, 2048) + if err != nil { + return + } + } + if sp.EntityCertificate != "" { + certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate) + if err != nil { + return + } + cert = certs[0] + } else { + // generate new self-signed X509 certificate + serialBytes := make([]byte, 16) + saml.RandReader.Read(serialBytes) + serial := big.NewInt(0) + serial.SetBytes(serialBytes) + + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: sp.EntityID, + }, + SerialNumber: serial, + BasicConstraintsValid: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, + IsCA: false, + NotBefore: time.Now(), + NotAfter: time.Now().Add(90 * 86400 * time.Second), + } + certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk) + if err != nil { + return + } + cert, err = x509.ParseCertificate(certBytes) + if err != nil { + return + } + } + + sp.entityCert = cert + sp.entityKey = pvk + }) + + if err != nil { + sp.metadataOnce = sync.Once{} + return nil, nil, err + } + + return sp.entityCert, sp.entityKey, nil +} + +func (sp *SAMLServiceProvider) getMiddleware(host string) (*samlsp.Middleware, error) { + sp.spMu.Lock() + defer sp.spMu.Unlock() + + if sp.mw == nil { + sp.mw = make(map[string]*samlsp.Middleware) + } + + if _, ok := sp.mw[host]; !ok { + mw, err := sp.newMiddleware(host) + if err != nil { + return nil, err + } + sp.mw[host] = mw + } + + return sp.mw[host], nil +} + +func (sp *SAMLServiceProvider) newMiddleware(host string) (*samlsp.Middleware, error) { + idpMetadata, err := sp.Metadata() + if err != nil { + return nil, err + } + cert, key, err := sp.certAndKey() + if err != nil { + return nil, err + } + return samlsp.New(samlsp.Options{ + EntityID: sp.EntityID, + URL: url.URL{ + Scheme: "https", + Host: host, + }, + Key: key, + Certificate: cert, + IDPMetadata: idpMetadata, + }) +} + +func samlInitHook(ctx context.Context, node *yaml.Node) (context.Context, error) { + var samlConfig struct { + SP *SAMLServiceProvider `yaml:"saml"` + } + + if err := node.Decode(&samlConfig); err == nil && samlConfig.SP != nil { + ctx = context.WithValue(ctx, kSamlDefaults, samlConfig.SP) + } + + return ctx, nil +} + +func serverDefaultSamlConfig(ctx context.Context) *SAMLServiceProvider { + v := ctx.Value(kSamlDefaults) + if c, ok := v.(*SAMLServiceProvider); ok { + return c + } + return nil +} + +func samlActionFromRouteYaml(node *yaml.Node) (RouteAction, error) { + var rawNode struct { + SP *struct { + *SAMLServiceProvider + Require string `yaml:"require"` + } `yaml:"saml,omitempty"` + Auth string `yaml:"auth"` + } + + err := node.Decode(&rawNode) + if err != nil || rawNode.Auth != "saml" { + return nil, nil + } + + require, err := strconv.ParseBool(rawNode.SP.Require) + if err != nil { + return nil, err + } + + sa := &samlAction{ + sp: rawNode.SP.SAMLServiceProvider, + requireAuth: require, + } + + return sa, nil +} + +func init() { + AddServerInitHook(samlInitHook) + AddRouteParseFunc(samlActionFromRouteYaml) +} diff --git a/http/samlproxy.go b/http/samlproxy.go deleted file mode 100644 index 489f5d2..0000000 --- a/http/samlproxy.go +++ /dev/null @@ -1,694 +0,0 @@ -package http - -import ( - "context" - "crypto" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "fmt" - "io" - "math/big" - "net" - "net/http" - "net/url" - "os" - "regexp" - "strconv" - "strings" - "sync" - "time" - - "github.com/crewjam/saml" - "github.com/crewjam/saml/samlsp" - - "go.fuhry.dev/runtime/mtls" - "go.fuhry.dev/runtime/mtls/certutil" - "go.fuhry.dev/runtime/utils/hashset" - "go.fuhry.dev/runtime/utils/log" - "go.fuhry.dev/runtime/utils/stringmatch" - "gopkg.in/yaml.v3" -) - -type authEnforcement uint - -type RouteAction interface { - Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) -} - -type RedirectAction struct { - StatusCode int - Destination *url.URL -} - -type Route struct { - Auth authEnforcement - Path stringmatch.StringMatcher - Action RouteAction -} - -type SAMLBackend struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - Identity string `yaml:"mtls_id"` - UsernameHeader string `yaml:"username_header"` - - client *http.Client - clientOnce sync.Once -} - -type SAMLVirtualHost struct { - *SAMLServiceProvider `yaml:"saml"` - - Backend *SAMLBackend `yaml:"backend"` - Routes []*Route `yaml:"routes"` -} - -type SAMLServiceProvider struct { - EntityID string `yaml:"entity_id"` - EntityCertificate string `yaml:"entity_certificate"` - EntityPrivateKey string `yaml:"entity_key"` - IDP string `yaml:"idp"` - - metadata *saml.EntityDescriptor - metadataOnce sync.Once - - entityCert *x509.Certificate - entityKey *rsa.PrivateKey - certKeyOnce sync.Once -} - -type SAMLListener struct { - *SAMLServiceProvider `yaml:"saml"` - - Addr string `yaml:"listen"` - InsecureAddr string `yaml:"listen_insecure"` - Certificate string `yaml:"cert"` - VirtualHosts map[string]*SAMLVirtualHost `yaml:"virtual_hosts"` -} - -type SAMLProxy struct { - Listener SAMLListener `yaml:"listener"` - - logger log.Logger -} - -const ( - AuthRequired authEnforcement = iota - AuthOptional -) - -var samlAttributeReplaceRegexp = regexp.MustCompile(`[^a-z0-9]+`) -var restrictedHeaders = hashset.FromSlice([]string{"on-behalf-of"}) - -// Handle implements RouteAction -func (a *RedirectAction) Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - newUrl := *r.URL - - if a.Destination.Host != "" { - newUrl.Host = a.Destination.Host - } - newUrl.Path = a.Destination.Path - if a.Destination.RawQuery != "" { - newUrl.RawQuery = a.Destination.RawQuery - } - if a.Destination.Scheme != "" { - newUrl.Scheme = a.Destination.Scheme - } - if a.Destination.Fragment != "" { - newUrl.Fragment = a.Destination.Fragment - } - - status := a.StatusCode - if status == 0 { - status = http.StatusFound - } - - w.Header().Set("location", newUrl.String()) - w.WriteHeader(status) -} - -// UnmarshalYAML implements yaml.Unmarshaler -func (r *Route) UnmarshalYAML(node *yaml.Node) error { - var rawNode struct { - Auth string `yaml:"auth"` - Path *stringmatch.MatchRule `yaml:"path"` - Redirect *struct { - Destination string `yaml:"dest"` - Status int `yaml:"status"` - } `yaml:"redirect"` - } - - if err := node.Decode(&rawNode); err != nil { - return err - } - - switch rawNode.Auth { - case "required": - r.Auth = AuthRequired - case "optional": - r.Auth = AuthOptional - default: - return fmt.Errorf("error unmarshaling route: invalid auth enforcement string value: %s", node.Value) - } - - if rawNode.Path != nil { - m, err := rawNode.Path.Matcher() - if err != nil { - return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err) - } - r.Path = m - } else { - return errors.New("error unmarshaling route: exactly one of (`path`) must be specified") - } - - if rawNode.Redirect != nil { - u, err := url.Parse(rawNode.Redirect.Destination) - if err != nil { - return err - } - r.Action = &RedirectAction{ - Destination: u, - StatusCode: rawNode.Redirect.Status, - } - } - - return nil -} - -// RouteFromArg implements the 3rd argument to flag.Func. -// -// It parses a string in the format of auth:field:match_mode:value, returning a Route if -// it parses successfully. -func RouteFromArg(arg string) (*Route, error) { - parts := strings.SplitN(arg, ":", 4) - if len(parts) != 4 { - return nil, fmt.Errorf("invalid route spec: %q", arg) - } - a, f, t, v := parts[0], parts[1], parts[2], parts[3] - var auth authEnforcement - switch strings.ToLower(a) { - case "r", "req", "required": - auth = AuthRequired - case "o", "opt", "optional": - auth = AuthOptional - default: - return nil, fmt.Errorf("invalid auth setting: %q", a) - } - - route := &Route{ - Auth: auth, - } - - match := stringmatch.MatchRule{ - Mode: t, - Value: v, - } - m, err := match.Matcher() - if err != nil { - return nil, err - } - - switch strings.ToLower(f) { - case "p", "path": - route.Path = m - default: - return nil, fmt.Errorf("invalid match field: %q", f) - } - - return route, nil -} - -// Client returns an HTTP client for making requests to the backend. -func (b *SAMLBackend) Client() (*http.Client, error) { - var err error - b.clientOnce.Do(func() { - transport := &http.Transport{} - var tlsConfig *tls.Config - - if b.Identity != "" { - myIdentity := mtls.DefaultIdentity() - tlsConfig, err = myIdentity.TlsConfig(context.Background()) - if err != nil { - return - } - - verifier := mtls.NewPeerNameVerifier() - verifier.AllowFrom(mtls.Service, b.Identity) - err = verifier.ConfigureClient(tlsConfig) - if err != nil { - return - } - - transport.TLSClientConfig = tlsConfig - } - - client := &http.Client{ - Transport: transport, - } - - b.client = client - }) - if err != nil { - return nil, err - } - return b.client, nil -} - -// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host -// and other settings. -func (sp *SAMLProxy) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { - var _ yaml.Unmarshaler = &Route{} - - if sp.logger == nil { - sp.logger = log.Default().WithPrefix("SAMLProxy") - } - - handler, err := sp.newHandler() - if err != nil { - return nil, err - } - - addr := sp.Listener.Addr - if addr == "" { - addr = "[::]:8443" - } - - lm := log.NewLoggingMiddlewareWithLogger(handler, sp.logger) - server := &http.Server{ - Addr: addr, - Handler: lm.HandlerFunc(), - } - - if sp.Listener.Certificate != "" { - cert := mtls.NewSSLCertificate(sp.Listener.Certificate) - tlsConfig, err := cert.TlsConfig(ctx) - if err != nil { - return nil, err - } - server.TLSConfig = tlsConfig - } - - return server, nil -} - -func (sp *SAMLProxy) NewHTTPSRedirectorWithContext(ctx context.Context) *http.Server { - addr := sp.Listener.InsecureAddr - if addr == "" { - addr = "[::]:8080" - } - - server := &http.Server{ - Addr: addr, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := r.Host - if host == "" { - w.WriteHeader(http.StatusBadRequest) - return - } - - if _, ok := sp.Listener.VirtualHosts[host]; !ok { - w.WriteHeader(http.StatusMisdirectedRequest) - return - } - - newUrl := *r.URL - newUrl.Scheme = "https" - newUrl.Host = host - w.Header().Set("location", newUrl.String()) - w.WriteHeader(http.StatusFound) - }), - } - - return server -} - -func (sp *SAMLServiceProvider) Metadata() (*saml.EntityDescriptor, error) { - var err error - sp.metadataOnce.Do(func() { - var idpMetadataUrl *url.URL - idpMetadataUrl, err = url.Parse(sp.IDP) - if err != nil { - return - } - - sp.metadata, err = samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataUrl) - }) - - if err != nil { - sp.metadataOnce = sync.Once{} - return nil, err - } - - return sp.metadata, nil -} - -func (sp *SAMLServiceProvider) CertAndKey() (cert *x509.Certificate, pvk *rsa.PrivateKey, err error) { - sp.certKeyOnce.Do(func() { - if sp.EntityPrivateKey != "" { - var loadedKey crypto.PrivateKey - loadedKey, err = certutil.LoadPrivateKeyFromPEM(sp.EntityPrivateKey) - if err != nil { - return - } - var ok bool - pvk, ok = loadedKey.(*rsa.PrivateKey) - if !ok { - err = fmt.Errorf("loaded private key is %T, not *rsa.PrivateKey", pvk) - return - } - } else { - // generate new RSA private key - pvk, err = rsa.GenerateKey(saml.RandReader, 2048) - if err != nil { - return - } - } - if sp.EntityCertificate != "" { - certs, err := certutil.LoadCertificatesFromPEM(sp.EntityCertificate) - if err != nil { - return - } - cert = certs[0] - } else { - // generate new self-signed X509 certificate - serialBytes := make([]byte, 16) - saml.RandReader.Read(serialBytes) - serial := big.NewInt(0) - serial.SetBytes(serialBytes) - - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: sp.EntityID, - }, - SerialNumber: serial, - BasicConstraintsValid: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, - IsCA: false, - NotBefore: time.Now(), - NotAfter: time.Now().Add(90 * 86400 * time.Second), - } - certBytes, err := x509.CreateCertificate(saml.RandReader, template, template, &pvk.PublicKey, pvk) - if err != nil { - return - } - cert, err = x509.ParseCertificate(certBytes) - if err != nil { - return - } - } - - sp.entityCert = cert - sp.entityKey = pvk - }) - - if err != nil { - sp.metadataOnce = sync.Once{} - return nil, nil, err - } - - return sp.entityCert, sp.entityKey, nil -} - -func (sp *SAMLServiceProvider) NewServiceProvider(host string) (*samlsp.Middleware, error) { - idpMetadata, err := sp.Metadata() - if err != nil { - return nil, err - } - cert, key, err := sp.CertAndKey() - if err != nil { - return nil, err - } - return samlsp.New(samlsp.Options{ - EntityID: sp.EntityID, - URL: url.URL{ - Scheme: "https", - Host: host, - }, - Key: key, - Certificate: cert, - IDPMetadata: idpMetadata, - }) -} - -func (sp *SAMLProxy) newHandler() (http.HandlerFunc, error) { - samlSp := make(map[string]*samlsp.Middleware, 0) - spMu := &sync.Mutex{} - - handle := func(w http.ResponseWriter, r *http.Request) { - // ensure host header present - host := r.Header.Get("Host") - if host == "" { - host = r.Header.Get(":authority") - } - if host == "" { - host = r.Host - } - if host == "" { - r.Header.Write(os.Stderr) - fmt.Fprintf(os.Stderr, "%v\n", r.URL.String()) - sp.writeError(w, http.StatusBadRequest, errors.New("missing Host header")) - return - } - - // ensure client isn't trying to inject saml-related headers - if err := sp.checkRequest(r); err != nil { - sp.writeError(w, http.StatusBadRequest, err) - return - } - - // make sure this host is known - vhost, ok := sp.Listener.VirtualHosts[host] - if !ok { - sp.writeError(w, http.StatusMisdirectedRequest, - errors.New("Misdirected request: unknown virtual host")) - - return - } - - // ensure we have SP instance - spMu.Lock() - if _, ok := samlSp[host]; !ok { - samlSettings := vhost.SAMLServiceProvider - if samlSettings == nil { - samlSettings = sp.Listener.SAMLServiceProvider - } - idpMetadata, err := samlSettings.Metadata() - if err != nil { - sp.writeError(w, http.StatusInternalServerError, err) - return - } - // make sure the browser isn't trying to access the IdP - this can happen if the TLS session - // was reused because our certificate is also valid for the SSO URL. - for _, ssoDesc := range idpMetadata.IDPSSODescriptors { - for _, ssoSvc := range ssoDesc.SingleSignOnServices { - if loginUrl, err := url.Parse(ssoSvc.Location); err == nil { - if loginUrl.Host == host { - sp.writeError(w, http.StatusMisdirectedRequest, - errors.New("Misdirected request: this is not the IDP you're looking for")) - - return - } - } - } - } - middleware, err := samlSettings.NewServiceProvider(host) - if err != nil { - sp.writeError(w, http.StatusInternalServerError, err) - return - } - samlSp[host] = middleware - } - spMu.Unlock() - - provider := samlSp[host] - if r.URL.Path == "/saml/acs" { - provider.ServeACS(w, r) - return - } - - session, sessionErr := provider.Session.GetSession(r) - - if sessionErr != nil && sessionErr != samlsp.ErrNoSession { - sp.logger.V(2).Warningf("non-NoSession err from sp: %v", sessionErr) - sp.writeError(w, http.StatusBadRequest, sessionErr) - return - } - - defaultRoute := true - next := sp.fulfill(vhost, session) - - sp.logger.V(3).Debugf("checking for routes matching %s", r.URL) - for _, route := range vhost.Routes { - match := false - if route.Path != nil { - match = route.Path.Match(r.URL.Path) - sp.logger.V(3).Debugf("path %s matches %s: %t", - r.URL.Path, route.Path.String(), match) - } else { - sp.writeError(w, http.StatusInternalServerError, - errors.New("nothing to match on in route")) - } - - if match { - defaultRoute = false - if sessionErr == samlsp.ErrNoSession && route.Auth == AuthRequired { - sp.logger.V(3).Debugf("route requires a valid session, redirecting") - - provider.HandleStartAuthFlow(w, r) - return - } - - if route.Action != nil { - sp.logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action) - route.Action.Handle(w, r, next) - } - } - } - - if defaultRoute { - sp.logger.V(3).Debugf("using default route") - if sessionErr == samlsp.ErrNoSession { - sp.logger.V(3).Debugf("default route requires a valid session, redirecting") - provider.HandleStartAuthFlow(w, r) - return - } - } - - next(w, r) - } - - return handle, nil -} - -func (sp *SAMLProxy) fulfill(vhost *SAMLVirtualHost, session samlsp.Session) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if session != nil { - sp.logger.V(3).Debugf("valid saml session(%T): %+v", session, session) - } else { - sp.logger.V(3).Debugf("serving path %s without session", r.URL.Path) - } - - newReq := r.Clone(r.Context()) - newReq.URL.Scheme = "http" - if vhost.Backend.Identity != "" { - newReq.URL.Scheme = "https" - } - newReq.URL.Host = net.JoinHostPort(vhost.Backend.Host, strconv.Itoa(vhost.Backend.Port)) - newReq.RequestURI = "" - - if swa, ok := session.(samlsp.SessionWithAttributes); ok { - attrs := swa.GetAttributes() - oboHeader := vhost.Backend.UsernameHeader - if oboHeader == "" { - oboHeader = "on-behalf-of" - } - sp.logger.V(3).Debugf("setting origin request header: %s: %q", oboHeader, attrs.Get("uid")) - newReq.Header.Set(oboHeader, attrs.Get("uid")) - } - - if jwts, ok := session.(samlsp.JWTSessionClaims); ok { - newReq.Header.Set("x-saml-audience", jwts.StandardClaims.Audience) - sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-audience", jwts.StandardClaims.Audience) - - iat := strconv.FormatInt(jwts.StandardClaims.IssuedAt, 10) - newReq.Header.Set("x-saml-issued-at", iat) - sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-issued-at", iat) - - eat := strconv.FormatInt(jwts.StandardClaims.ExpiresAt, 10) - newReq.Header.Set("x-saml-expires-at", eat) - sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-expires-at", eat) - - newReq.Header.Set("x-saml-subject", jwts.StandardClaims.Subject) - sp.logger.V(3).Debugf("setting origin request header: %s: %s", "x-saml-subject", jwts.StandardClaims.Subject) - - for attr, values := range jwts.Attributes { - headerName := fmt.Sprintf("x-saml-%s", - samlAttributeReplaceRegexp.ReplaceAllString(strings.ToLower(attr), "-")) - headerValue := strings.Join(values, ", ") - sp.logger.V(3).Debugf("setting origin request header: %s: %s", - headerName, headerValue) - newReq.Header.Set(headerName, headerValue) - } - } - - // set proxy headers - if remoteHost, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { - sp.logger.V(3).Debugf("x-forwarded-for: %s", remoteHost) - newReq.Header.Set("x-forwarded-for", remoteHost) - } - - // proxy the request to the backend - client, err := vhost.Backend.Client() - if err != nil { - sp.writeError(w, http.StatusInternalServerError, fmt.Errorf("error setting up connection to backend: %v", err)) - } - response, err := client.Do(newReq) - if err != nil { - sp.writeError(w, http.StatusBadGateway, err) - return - } - - for name, value := range response.Header { - w.Header().Set(name, strings.Join(value, ", ")) - } - - if response.StatusCode == http.StatusSwitchingProtocols { - hijacker, ok := w.(http.Hijacker) - if !ok { - sp.writeError(w, http.StatusMethodNotAllowed, errors.New("websocket passthrough not supported")) - return - } - - upstreamWriter, ok := response.Body.(io.Writer) - if !ok { - sp.writeError(w, http.StatusMethodNotAllowed, errors.New("body doesn't support io.Writer")) - return - } - - w.WriteHeader(response.StatusCode) - - conn, rw, err := hijacker.Hijack() - if err != nil { - sp.writeError(w, http.StatusInternalServerError, err) - return - } - - wg := sync.WaitGroup{} - wg.Add(2) - pipe := func(w io.Writer, r io.Reader) { - defer wg.Done() - io.Copy(w, r) - } - go pipe(rw, response.Body) - go pipe(upstreamWriter, rw) - - wg.Wait() - conn.Close() - return - } - - w.WriteHeader(response.StatusCode) - io.Copy(w, response.Body) - } -} - -func (sp *SAMLProxy) writeError(w http.ResponseWriter, status int, err error) { - sp.logger.V(1).Warningf("returning status: %d %s", status, err.Error()) - - w.WriteHeader(status) - w.Write([]byte(fmt.Sprintf("

%d %s

", status, err.Error()))) -} - -func (sp *SAMLProxy) checkRequest(r *http.Request) error { - for k, _ := range r.Header { - k = strings.ToLower(k) - if strings.HasPrefix(k, "x-saml-") || restrictedHeaders.Contains(k) { - return fmt.Errorf("downstream attempted to overwrite restricted header: %q", k) - } - } - return nil -} diff --git a/http/samlproxy/main.go b/http/samlproxy/main.go deleted file mode 100644 index 50ff0ab..0000000 --- a/http/samlproxy/main.go +++ /dev/null @@ -1,86 +0,0 @@ -package main - -import ( - "context" - "flag" - "os" - "os/signal" - "syscall" - "time" - - "github.com/coreos/go-systemd/daemon" - "gopkg.in/yaml.v3" - - "go.fuhry.dev/runtime/http" - "go.fuhry.dev/runtime/mtls" - "go.fuhry.dev/runtime/utils/log" -) - -func main() { - mtls.SetDefaultIdentity("authproxy") - - sp := &http.SAMLProxy{ - Listener: http.SAMLListener{ - SAMLServiceProvider: &http.SAMLServiceProvider{}, - }, - } - vhost := &http.SAMLVirtualHost{ - Backend: &http.SAMLBackend{}, - } - - loadConfig := func(arg string) error { - contents, err := os.ReadFile(arg) - if err != nil { - return err - } - - err = yaml.Unmarshal(contents, sp) - return err - } - addRoute := func(arg string) error { - route, err := http.RouteFromArg(arg) - if err != nil { - return err - } - vhost.Routes = append(vhost.Routes, route) - return nil - } - - vhostName := flag.String("vhost", "", "HTTP(S) hostname to serve") - flag.Func("config", "YAML file to load configuration from", loadConfig) - flag.Func("route", "Route rule in the format of auth:field:matcher:value\n"+ - " auth: required, optional\n"+ - " field: path\n"+ - " matcher: prefix, suffix, exact, contains, regexp\n"+ - " value: any string", addRoute) - flag.StringVar(&sp.Listener.EntityID, "saml.entity-id", "", "entity ID of SAML service provider") - flag.StringVar(&sp.Listener.IDP, "saml.idp.url", "", "URL to IdP metadata") - flag.StringVar(&sp.Listener.Certificate, "ssl-cert", "", "SSL certificate name to use from /etc/ssl/private") - flag.StringVar(&vhost.Backend.Host, "backend.host", "127.0.0.1", "backend host") - flag.IntVar(&vhost.Backend.Port, "backend.port", 0, "backend port") - flag.StringVar(&vhost.Backend.Identity, "backend.mtls-id", "", "backend mTLS identity; omit to disable TLS to backend") - flag.StringVar(&sp.Listener.Addr, "listen", "[::]:8443", "address for auth proxy to listen on") - flag.StringVar(&sp.Listener.InsecureAddr, "listen.http", "[::]:8080", "address for http-to-https redirector") - - flag.Parse() - - sp.Listener.VirtualHosts[*vhostName] = vhost - - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer cancel() - server, err := sp.NewHTTPServerWithContext(ctx) - if err != nil { - log.Panic(err) - } - go server.ListenAndServeTLS("", "") - - unsecureServer := sp.NewHTTPSRedirectorWithContext(ctx) - go unsecureServer.ListenAndServe() - - daemon.SdNotify(false, daemon.SdNotifyReady) - - <-ctx.Done() - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer shutdownCancel() - server.Shutdown(shutdownCtx) -} diff --git a/http/server.go b/http/server.go new file mode 100644 index 0000000..9f88be1 --- /dev/null +++ b/http/server.go @@ -0,0 +1,282 @@ +package http + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + + "go.fuhry.dev/runtime/mtls" + "go.fuhry.dev/runtime/utils/log" + "go.fuhry.dev/runtime/utils/stringmatch" + "gopkg.in/yaml.v3" +) + +type serverCtxVar int + +type RouteAction interface { + Handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) +} + +type Route struct { + Path stringmatch.StringMatcher + Action RouteAction +} + +type VirtualHost struct { + Routes []*Route `yaml:"routes"` +} + +type Listener struct { + Addr string `yaml:"listen"` + InsecureAddr string `yaml:"listen_insecure"` + Certificate string `yaml:"cert"` + VirtualHosts map[string]*VirtualHost `yaml:"virtual_hosts"` +} + +type Server struct { + Listener *Listener `yaml:"listener"` + Context context.Context `yaml:"-"` +} + +type initHook func(context.Context, *yaml.Node) (context.Context, error) +type routeParseFunc func(*yaml.Node) (RouteAction, error) + +const ( + kLogger serverCtxVar = iota + kListener + kListenAddr + kSamlDefaults +) + +var initHooks []initHook +var routeParseFuncs []routeParseFunc + +func AddServerInitHook(hook initHook) { + initHooks = append(initHooks, hook) +} + +func AddRouteParseFunc(rpf routeParseFunc) { + routeParseFuncs = append(routeParseFuncs, rpf) +} + +func NewServer() *Server { + return NewServerWithContext(context.Background()) +} + +func NewServerWithContext(ctx context.Context) *Server { + logger := log.WithPrefix(fmt.Sprintf("%T", &Server{})) + + return &Server{ + Listener: &Listener{ + VirtualHosts: make(map[string]*VirtualHost, 0), + }, + Context: context.WithValue(ctx, kLogger, logger), + } +} + +// UnmarshalYAML implements yaml.Unmarshaler +func (r *Route) UnmarshalYAML(node *yaml.Node) error { + var rawNode struct { + Path *stringmatch.MatchRule `yaml:"path"` + } + + if err := node.Decode(&rawNode); err != nil { + return err + } + + if rawNode.Path != nil { + m, err := rawNode.Path.Matcher() + if err != nil { + return fmt.Errorf("error unmarshaling route: invalid path matcher: %v", err) + } + r.Path = m + } else { + return errors.New("error unmarshaling route: path must be specified") + } + + for _, rpf := range routeParseFuncs { + action, err := rpf(node) + if err != nil { + return err + } + if action != nil { + r.Action = action + break + } + } + + return nil +} + +// UnmarshalYAML implements yaml.Unmarshaler +func (s *Server) UnmarshalYAML(node *yaml.Node) error { + lc := &struct { + Listener *Listener `yaml:"listener"` + }{} + + if s.Context == nil { + s.Context = context.Background() + } + + if err := node.Decode(&lc); err != nil { + return err + } + + s.Listener = lc.Listener + + for _, initHook := range initHooks { + newCtx, err := initHook(s.Context, node) + if err != nil { + return err + } + s.Context = newCtx + } + + return nil +} + +func (s *Server) Create() (*http.Server, error) { + listenerCtx := context.WithValue(s.Context, kListener, s.Listener) + return s.Listener.NewHTTPServerWithContext(listenerCtx) +} + +func (s *Server) CreateInsecure() *http.Server { + listenerCtx := context.WithValue(s.Context, kListener, s.Listener) + return s.Listener.NewHTTPSRedirectorWithContext(listenerCtx) +} + +// NewHTTPServerWithContext creates an http.Server using the proxy's virtual host +// and other settings. +func (l *Listener) NewHTTPServerWithContext(ctx context.Context) (*http.Server, error) { + if l.Addr == "" { + l.Addr = "[::]:8443" + } + + logger := LoggerFromContext(ctx).WithPrefix(fmt.Sprintf("%T(%s)", l, l.Addr)) + serverCtx := context.WithValue(ctx, kLogger, logger) + + lm := log.NewLoggingMiddlewareWithLogger( + http.HandlerFunc(l.handle), + logger.AppendPrefix("access")) + + server := &http.Server{ + Addr: l.Addr, + BaseContext: func(l net.Listener) context.Context { + return context.WithValue(serverCtx, kListenAddr, l.Addr()) + }, + Handler: lm.HandlerFunc(), + } + + if l.Certificate != "" { + cert := mtls.NewSSLCertificate(l.Certificate) + tlsConfig, err := cert.TlsConfig(serverCtx) + if err != nil { + return nil, err + } + server.TLSConfig = tlsConfig + } + + return server, nil +} + +func (l *Listener) NewHTTPSRedirectorWithContext(ctx context.Context) *http.Server { + if l.InsecureAddr == "" { + l.InsecureAddr = "[::]:8080" + } + + server := &http.Server{ + Addr: l.InsecureAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := r.Host + if host == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + if _, ok := l.VirtualHosts[host]; !ok { + w.WriteHeader(http.StatusMisdirectedRequest) + return + } + + newUrl := *r.URL + newUrl.Scheme = "https" + newUrl.Host = host + w.Header().Set("location", newUrl.String()) + w.WriteHeader(http.StatusFound) + }), + } + + return server +} + +func (l *Listener) handle(w http.ResponseWriter, r *http.Request) { + // ensure host header present + if r.Host == "" { + r.Header.Write(os.Stderr) + http.Error(w, "missing Host header", http.StatusBadRequest) + return + } + + // make sure this host is known + vhost, ok := l.VirtualHosts[r.Host] + if !ok { + http.Error(w, "Misdirected request: unknown virtual host", + http.StatusMisdirectedRequest) + + return + } + + l.fulfill(w, r, vhost.Routes) +} + +func (l *Listener) fulfill(w http.ResponseWriter, r *http.Request, routes []*Route) { + logger := LoggerFromContext(r.Context()) + if logger == nil { + http.Error(w, "cannot get logger", http.StatusInternalServerError) + } + + logger.V(3).Debugf("checking for routes matching %s", r.URL) + for i, route := range routes { + match := false + if route.Path != nil { + match = route.Path.Match(r.URL.Path) + logger.V(3).Debugf("path %s matches %s: %t", + r.URL.Path, route.Path.String(), match) + } else { + http.Error(w, "nothing to match on in route", http.StatusInternalServerError) + } + + if match { + if route.Action != nil { + logger.V(3).Debugf("route has action %T, dispatching: %+v", route.Action, route.Action) + next := http.NotFound + if len(routes) > i { + next = func(w http.ResponseWriter, r *http.Request) { + logger.V(3).Debugf("%T called next(), continuing request processing", route.Action) + l.fulfill(w, r, routes[i+1:]) + } + } + route.Action.Handle(w, r, next) + return + } else { + http.Error(w, + fmt.Sprintf("no action configured for route %s", route.Path.String()), + http.StatusInternalServerError) + } + } + } + + http.NotFound(w, r) +} + +func LoggerFromContext(ctx context.Context) log.Logger { + l := ctx.Value(kLogger) + if logger, ok := l.(log.Logger); ok { + return logger + } + + return nil +} -- 2.50.1