]> go.fuhry.dev Git - runtime.git/commitdiff
machines/client: sparse generic type for server fields that can be a UUID or full...
authorDan Fuhry <dan@fuhry.com>
Tue, 27 Feb 2024 20:55:50 +0000 (15:55 -0500)
committerDan Fuhry <dan@fuhry.com>
Tue, 27 Feb 2024 20:55:50 +0000 (15:55 -0500)
attestation/internal/attestation/server.go
machines/types.go

index 2fcd2f59d91afe9cfe9ae500c5317a8240da6d55..d6fccdf0565a99753fc2292248101a34da4dc1d5 100644 (file)
@@ -146,14 +146,14 @@ func verifyEK(ek *attest.EK, clientInfo *rpcClientInfo) (string, error) {
                logger.Errorf("failed to lookup host %q with machines API: %+v", udn, err)
                return "", err
        }
-       if machinesHost.ID == "" {
+       if machinesHost.ID() == "" {
                logger.Errorf("failed to lookup host %q with machines API: API call returned OK, but host ID is empty", udn, err)
                return "", fmt.Errorf("cannot get host UUID from Machines API")
        }
        logger.V(1).Infof("host %q has machines UUID %s", udn, machinesHost.ID)
 
        machinesEK := &machines.EndorsementKey{}
-       err = client.APICall("/host/"+machinesHost.ID+"/endorsement_key", nil, machinesEK)
+       err = client.APICall("/host/"+machinesHost.ID()+"/endorsement_key", nil, machinesEK)
        if err != nil {
                logger.Errorf("failed to retrieve endorsement key for host %q from machines API: %+v", udn, err)
                return "", err
@@ -163,7 +163,7 @@ func verifyEK(ek *attest.EK, clientInfo *rpcClientInfo) (string, error) {
                // logger.Errorf("host %q does not have an endorsement key on file with Machines", udn)
                // return "", fmt.Errorf("host %q does not have an endorsement key on file with Machines", udn)
                // if the host doesn't have an endorsement key on file, we TOFU.
-               return machinesHost.ID, nil
+               return machinesHost.ID(), nil
        }
 
        expectFingerprint := machinesEK.EndorsementKey.Fingerprint.SHA256.AsBytes()
@@ -199,7 +199,7 @@ func verifyEK(ek *attest.EK, clientInfo *rpcClientInfo) (string, error) {
        }
 
        logger.Noticef("host %q: fingerprint matched, proceeding with attestation", udn)
-       return machinesHost.ID, nil
+       return machinesHost.ID(), nil
 }
 
 func storeQuote(rci *rpcClientInfo, hostGUID string, quote *rpcStoreQuoteParams) error {
index 8402d03bc29dc7119d0f36fb9e9a57fdc5380fa9..2ab2e634c6dedc4427bb8b5c9ca5a2f3b930e95f 100644 (file)
@@ -1,17 +1,54 @@
 package machines
 
 import (
+       "bytes"
        "encoding/hex"
+       "encoding/json"
+       "fmt"
        "net"
+       "slices"
+       "strings"
        "time"
+
+       "github.com/miekg/dns"
+
+       "go.fuhry.dev/runtime/utils/hashset"
 )
 
+const StaleHostTTL = 1440 * time.Hour
+
 type Timestamp uint64
 type IPString string
 type HexEncoded string
+type IPv4PrefixLength uint8
+type IPv6PrefixLength uint8
+
+type ErrUnhandledRecordType string
+
+func (e ErrUnhandledRecordType) Error() string {
+       return string(e)
+}
+
+type HasID interface {
+       ID() string
+}
 
 type WithUUID struct {
-       ID string `json:"id"`
+       Id string `json:"id"`
+}
+
+func (wu *WithUUID) ID() string {
+       if wu == nil {
+               return ""
+       }
+
+       return wu.Id
+}
+
+type Sparse[T HasID] struct {
+       id  string
+       set bool
+       o   T
 }
 
 type WithCalculatedName struct {
@@ -22,15 +59,15 @@ type Host struct {
        *WithUUID
        *WithCalculatedName
 
-       Name            string    `json:"name"`
-       Owner           *User     `json:"owner,omitempty"`
-       Role            string    `json:"role"`
-       OS              string    `json:"os"`
-       CreatedAt       Timestamp `json:"created_at"`
-       LastSeen        Timestamp `json:"last_seen"`
-       LastSeenIface   *Iface    `json:"last_seen_iface"`
-       LastSeenIfaceID string    `json:"last_seen_iface"`
-       Interfaces      []*Iface  `json:"interfaces"`
+       Name          string         `json:"name"`
+       Owner         Sparse[*User]  `json:"owner"`
+       Role          string         `json:"role"`
+       OS            string         `json:"os"`
+       CreatedAt     Timestamp      `json:"created_at"`
+       LastSeen      Timestamp      `json:"last_seen"`
+       LastSeenIface Sparse[*Iface] `json:"last_seen_iface"`
+       Flags         []string       `json:"flags"`
+       Interfaces    []*Iface       `json:"interfaces"`
 }
 
 type User struct {
@@ -45,67 +82,256 @@ type User struct {
        Flags      []string  `json:"flags"`
 }
 
+type DNSRData struct {
+       // A/AAAA
+       Address IPString `json:"Address"`
+
+       // MX
+       // Priority is under SRV below
+       Mailserver string `json:"Mailserver"`
+
+       // CNAME
+       Target string `json:"Target"`
+
+       // NS
+       Nameserver string `json:"Nameserver"`
+
+       // SRV
+       Priority uint   `json:"Priority"`
+       Weight   uint   `json:"Weight"`
+       Port     uint   `json:"Port"`
+       Server   string `json:"Server"`
+
+       // TXT
+       Value string `json:"Value"`
+
+       // SSHFP
+       HostKeyAlg  int    `json:"Host Key Algorithm"`
+       FPHashAlg   int    `json:"Fingerprint Hash Algorithm"`
+       Fingerprint string `json:"Fingerprint"`
+
+       // CAA
+       Flags int    `json:"Flags"`
+       Tag   string `json:"Tag"`
+       // Value is under TXT
+}
+
+type DNSRecord struct {
+       *WithUUID
+       *WithCalculatedName
+
+       Domain Sparse[*Domain] `json:"domain"`
+       Owner  string          `json:"owner"`
+       Type   string          `json:"type"`
+       RName  string          `json:"rname"`
+       RData  *DNSRData       `json:"rdata"`
+       TTL    uint            `json:"ttl"`
+}
+
+func (r *DNSRecord) ToRR() (dns.RR, error) {
+       var msg dns.RR
+
+       hdr := dns.RR_Header{
+               Class: dns.ClassINET,
+               Ttl:   3600,
+       }
+       if r.TTL > 0 {
+               hdr.Ttl = uint32(r.TTL)
+       }
+
+       if r.Type == "A" {
+               hdr.Rrtype = dns.TypeA
+               msg = &dns.A{
+                       Hdr: hdr,
+                       A:   IPString(r.RData.Address).AsIP(),
+               }
+       } else if r.Type == "AAAA" {
+               hdr.Rrtype = dns.TypeAAAA
+               msg = &dns.AAAA{
+                       Hdr:  hdr,
+                       AAAA: IPString(r.RData.Address).AsIP(),
+               }
+       } else if r.Type == "CNAME" {
+               hdr.Rrtype = dns.TypeCNAME
+               msg = &dns.CNAME{
+                       Hdr:    hdr,
+                       Target: r.RData.Target,
+               }
+       } else if r.Type == "SRV" {
+               hdr.Rrtype = dns.TypeSRV
+               msg = &dns.SRV{
+                       Hdr:      hdr,
+                       Priority: uint16(r.RData.Priority),
+                       Weight:   uint16(r.RData.Weight),
+                       Port:     uint16(r.RData.Port),
+                       Target:   r.RData.Server,
+               }
+       } else if r.Type == "TXT" {
+               hdr.Rrtype = dns.TypeTXT
+               msg = &dns.TXT{
+                       Hdr: hdr,
+                       Txt: []string{
+                               r.RData.Value,
+                       },
+               }
+       } else if r.Type == "MX" {
+               hdr.Rrtype = dns.TypeMX
+               msg = &dns.MX{
+                       Hdr:        hdr,
+                       Preference: uint16(r.RData.Priority),
+                       Mx:         r.RData.Mailserver,
+               }
+       } else if r.Type == "SSHFP" {
+               hdr.Rrtype = dns.TypeSSHFP
+               msg = &dns.SSHFP{
+                       Hdr:         hdr,
+                       Type:        uint8(r.RData.HostKeyAlg),
+                       Algorithm:   uint8(r.RData.FPHashAlg),
+                       FingerPrint: r.RData.Fingerprint,
+               }
+       } else if r.Type == "CAA" {
+               hdr.Rrtype = dns.TypeCAA
+               msg = &dns.CAA{
+                       Hdr:   hdr,
+                       Flag:  uint8(r.RData.Flags),
+                       Tag:   r.RData.Tag,
+                       Value: r.RData.Value,
+               }
+       } else {
+               return nil, ErrUnhandledRecordType(
+                       fmt.Sprintf("don't know how to handle record type: %q", r.Type))
+       }
+
+       return msg, nil
+}
+
+func (r *DNSRecord) String() string {
+       parts := make([]string, 0)
+
+       if r.RName != "" {
+               parts = append(parts, r.RName)
+       } else {
+               parts = append(parts, "@")
+       }
+
+       if r.TTL > 0 {
+               parts = append(parts, fmt.Sprintf("%d", r.TTL))
+       }
+
+       parts = append(parts, "IN", r.Type)
+
+       switch r.Type {
+       case "A", "AAAA":
+               parts = append(parts, r.RData.Address.String())
+       case "CNAME":
+               parts = append(parts, r.RData.Target)
+       case "MX":
+               parts = append(parts, fmt.Sprintf("%d %s", r.RData.Priority, r.RData.Mailserver))
+       case "SRV":
+               parts = append(parts, fmt.Sprintf("%d %d %d %s", r.RData.Priority, r.RData.Weight, r.RData.Port, r.RData.Server))
+       case "TXT":
+               value := r.RData.Value
+               if !strings.HasPrefix(value, `"`) || !strings.HasSuffix(value, `"`) {
+                       value = fmt.Sprintf(`"%s"`, strings.ReplaceAll(value, `"`, `\"`))
+               }
+               parts = append(parts, value)
+       case "SSHFP":
+               parts = append(parts, fmt.Sprintf("%d %d %s", r.RData.HostKeyAlg, r.RData.FPHashAlg, r.RData.Fingerprint))
+       case "CAA":
+               parts = append(parts, fmt.Sprintf(`%d %s "%s"`, r.RData.Flags, r.RData.Tag, r.RData.Value))
+       default:
+               return fmt.Sprintf("; error processing record %s: unsupported record type: %s", r.RName, r.Type)
+       }
+
+       return strings.Join(parts, "\t")
+}
+
 type Iface struct {
        *WithUUID
        *WithCalculatedName
 
-       Host            *Host          `json:"host"`
-       HostID          string         `json:"host"`
-       Name            string         `json:"name"`
-       MediaType       string         `json:"type"`
-       HardwareAddress string         `json:"hardware_address"`
-       LastIPv4        IPString       `json:"last_inet4"`
-       LastIPv6        IPString       `json:"last_inet6"`
-       LastSeen        Timestamp      `json:"last_seen"`
-       NameScrubbed    string         `json:"name_scrubbed"`
-       Reservations    []*Reservation `json:"reservations"`
+       Host            Sparse[*Host]   `json:"host"`
+       Name            string          `json:"name"`
+       MediaType       string          `json:"type"`
+       HardwareAddress string          `json:"hardware_address"`
+       LastIPv4        IPString        `json:"last_inet4"`
+       LastIPv6        IPString        `json:"last_inet6"`
+       LastSeen        Timestamp       `json:"last_seen"`
+       LastDomain      Sparse[*Domain] `json:"last_domain"`
+       NameScrubbed    string          `json:"name_scrubbed"`
+       Reservations    []*Reservation  `json:"reservations"`
 }
 
 type Reservation struct {
        *WithUUID
        *WithCalculatedName
 
-       Iface         *Iface   `json:"iface"`
-       IfaceID       string   `json:"iface"`
-       AddressFamily string   `json:"af"`
-       Address       IPString `json:"address"`
-       Domain        *Domain  `json:"domain"`
-       Range         *Range   `json:"range"`
+       Iface         Sparse[*Iface] `json:"iface"`
+       AddressFamily string         `json:"af"`
+       Address       IPString       `json:"address"`
+       Domain        *Domain        `json:"domain"`
+       Range         *Range         `json:"range"`
 }
 
 type Domain struct {
        *WithUUID
        *WithCalculatedName
 
-       Name   string `json:"name"`
-       Site   *Site  `json:"site"`
-       SiteID string `json:"site"`
-       VlanID uint   `json:"vlan_id"`
+       Name   string        `json:"name"`
+       Site   *Sparse[Site] `json:"site"`
+       VlanID uint          `json:"vlan_id"`
 
-       IPv4Address       IPString `json:"inet4_address"`
-       IPv4PrefixLength  uint8    `json:"inet4_prefixlen"`
-       IPv4RouterAddress IPString `json:"inet4_routeraddr"`
+       IPv4Address       IPString         `json:"inet4_address"`
+       IPv4PrefixLength  IPv4PrefixLength `json:"inet4_prefixlen"`
+       IPv4RouterAddress IPString         `json:"inet4_routeraddr"`
 
-       IPv6Address       IPString `json:"inet6_address"`
-       IPv6PrefixLength  uint8    `json:"inet6_prefixlen"`
-       IPv6RouterAddress IPString `json:"inet6_routeraddr"`
+       IPv6Address       IPString         `json:"inet6_address"`
+       IPv6PrefixLength  IPv6PrefixLength `json:"inet6_prefixlen"`
+       IPv6RouterAddress IPString         `json:"inet6_routeraddr"`
 
        PXEServerIPv4   IPString `json:"pxe4_server"`
        PXEServerIPv6   IPString `json:"pxe6_server"`
        PXEFilenameBIOS string   `json:"pxe_filename_bios"`
        PXEFilenameUEFI string   `json:"pxe_filename_uefi"`
+       PXEFilenameIPXE string   `json:"pxe_filename_ipxe"`
 
-       Features           []string `json:"features"`
-       DefaultRange       *Range
+       Features           []string       `json:"features"`
+       DefaultRange       Sparse[*Range] `json:"default_range"`
+       Ranges             map[string]*Range
        ReverseDNSZoneIPv4 string     `json:"inet4_reverse_zone"`
        ReverseDNSZoneIPv6 string     `json:"inet6_reverse_zone"`
        GuestSeedStr       HexEncoded `json:"guest_seed"`
        GuestPassword      string     `json:"guest_password"`
+
+       // these are not actually part of the server-side schema, just included here
+       // as a debugging aid
+       LastModified Timestamp    `json:"last_modified"`
+       DNSSearch    []string     `json:"dns_search"`
+       Interfaces   []*Iface     `json:"interfaces"`
+       Records      []*DNSRecord `json:"records"`
+}
+
+type RouterAddress struct {
+       Address   IPString
+       Interface string
+}
+
+type RouterAddresses struct {
+       IPv4 map[string]*RouterAddress
+       IPv6 map[string]*RouterAddress
 }
 
 type Range struct {
        *WithUUID
        *WithCalculatedName
+
+       Name      string
+       IPv4Start IPString `json:"inet4_start"`
+       IPv4End   IPString `json:"inet4_end"`
+       IPv6Start IPString `json:"inet6_start"`
+       IPv6End   IPString `json:"inet6_end"`
+
+       Reservations []*Reservation `json:"reservations"`
 }
 
 type Site struct {
@@ -134,6 +360,66 @@ func (ip IPString) AsIP() net.IP {
        return net.ParseIP(string(ip))
 }
 
+func (ip IPString) AsInt32() int32 {
+       netip := ip.AsIP().To4()
+       if netip == nil {
+               return 0
+       }
+
+       return int32(
+               int32(netip[3])<<24 |
+                       int32(netip[2])<<16 |
+                       int32(netip[1])<<8 |
+                       int32(netip[0]))
+}
+
+func (ip IPString) IsIPv4() bool {
+       netIP := ip.AsIP()
+       return bytes.Equal(netIP[:12], []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff})
+}
+
+func (ip IPString) ReverseDNSName() string {
+       netIP := ip.AsIP()
+       slices.Reverse(netIP)
+
+       if ip.IsIPv4() {
+               return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa", netIP[0], netIP[1], netIP[2], netIP[3])
+       } else {
+               parts := make([]string, 0)
+               for _, b := range netIP {
+                       highNibble := (b >> 4) & 0xf
+                       lowNibble := b & 0xf
+                       parts = append(parts, fmt.Sprintf("%x.%x", lowNibble, highNibble))
+               }
+               parts = append(parts, "ip6", "arpa")
+               return strings.Join(parts, ".")
+       }
+}
+
+func (ip IPString) String() string {
+       return string(ip)
+}
+
+func (ip IPString) Defined() bool {
+       return ip != "" && ip.AsIP() != nil
+}
+
+func (pl IPv4PrefixLength) Mask() string {
+       return net.IP(net.CIDRMask(int(pl), 32)).String()
+}
+
+func (pl IPv4PrefixLength) IPMask() net.IPMask {
+       return net.CIDRMask(int(pl), 32)
+}
+
+func (pl IPv6PrefixLength) Mask() string {
+       return net.CIDRMask(int(pl), 128).String()
+}
+
+func (pl IPv6PrefixLength) IPMask() net.IPMask {
+       return net.CIDRMask(int(pl), 128)
+}
+
 func (h HexEncoded) AsBytes() []byte {
        ba, err := hex.DecodeString(string(h))
        if err != nil {
@@ -141,3 +427,84 @@ func (h HexEncoded) AsBytes() []byte {
        }
        return ba
 }
+
+func (sp *Sparse[T]) Set(in T) {
+       sp.o = in
+       sp.id = sp.o.ID()
+       sp.set = true
+}
+
+func (sp Sparse[T]) Get() T {
+       var undef T
+       if sp.set {
+               return sp.o
+       }
+
+       return undef
+}
+
+func (sp Sparse[T]) Defined() bool {
+       return sp.set || sp.id != ""
+}
+
+func (sp *Sparse[T]) UnmarshalJSON(in []byte) error {
+       var id string
+       obj := new(T)
+       if err := json.Unmarshal(in, &id); err == nil {
+               sp.id = id
+               return nil
+       } else if err := json.Unmarshal(in, obj); err == nil {
+               sp.Set(*obj)
+               return nil
+       }
+
+       return fmt.Errorf("failed to unmarshal %s to string or %T: %v", string(in), obj, json.Unmarshal(in, obj))
+}
+
+func (sp *Sparse[T]) MarshalJSON() ([]byte, error) {
+       if sp.set {
+               return json.Marshal(sp.o)
+       }
+       if sp.id != "" {
+               return json.Marshal(sp.id)
+       }
+
+       return []byte(`""`), nil
+}
+
+func (sp Sparse[T]) ID() string {
+       if sp.id != "" {
+               return sp.id
+       }
+       if sp.set {
+               return sp.o.ID()
+       }
+       return ""
+}
+
+func (i *Iface) ShouldPublishInDNS() bool {
+       if i == nil {
+               return false
+       }
+
+       if i.Host.ID() == "" {
+               return false
+       }
+
+       if host := i.Host.Get(); host != nil {
+               flags := hashset.FromSlice(host.Flags)
+               if flags.Contains("disabled") {
+                       return false
+               }
+
+               if flags.Contains("sticky") {
+                       return true
+               }
+       }
+
+       if i.LastSeen.AsTime().Add(StaleHostTTL).Before(time.Now()) {
+               return false
+       }
+
+       return true
+}