//go:build ignore package main import ( "bytes" "fmt" "go/format" "go/types" "log" "os" "text/template" "golang.org/x/tools/go/packages" ) // TODO: Also extract the body of the String() method and remove the // `rr.Hdr.String() + ` prefix in the return statement so it can be used // in a Value() method. Detect if anything is more complex than that. type Field struct { Name string Type string } var tpl = template.Must(template.New("").Parse(`package dns // GENERATED FILE, DO NOT MODIFY // See generate_dns_types.go in the repo root. import ( "encoding/json" "fmt" "net" "github.com/miekg/dns" ) {{ range $name, $fields := . -}} type {{ $name }} struct { Name string Ttl int {{ range $fields -}} {{ .Name }} {{ .Type }} {{ end -}} } func (r *{{ $name }}) ToDNS(zone NamedZone) dns.RR { return &dns.{{ $name }}{ Hdr: makeHeader(r.Name, zone, dns.Type{{ $name }}, r.Ttl), {{ range $fields -}} {{ .Name }}: r.{{ .Name }}, {{ end }} } } func (r *{{ $name }}) FromDNS(rr dns.RR) error { rt, ok := rr.(*dns.{{ $name }}) if !ok { return fmt.Errorf("Invalid type %T for '{{ $name }}'", rr) } r.Name = rr.Header().Name r.Ttl = int(rr.Header().Ttl) {{ range $fields -}} r.{{ .Name }} = rt.{{ .Name }} {{ end }} return nil } func (r *{{ $name }}) MarshalJSON() ([]byte, error) { type Alias {{ $name }} return json.Marshal(&struct { Type string *Alias }{"{{ $name }}", (*Alias)(r)}) } func (r *{{ $name }}) UnmarshalJSON(data []byte) error { type Alias {{ $name }} if err := json.Unmarshal(data, &struct{ Type string *Alias }{Alias: (*Alias)(r)}); err != nil { return err } return nil } var _ RR = (*{{ $name }})(nil) {{ end }} func FromDNS(rr dns.RR) interface{} { switch v := rr.(type) { {{ range $name, $fields := . -}} case *dns.{{ $name }}: rv := &{{ $name }}{} rv.FromDNS(v) return rv {{ end }} } return nil } `)) // TODO: support the alises by looking for a single embedded type var disallowedTypes = map[string]bool{ "CDNSKEY": true, // Alias for DNSKEY "CDS": true, // Alias for DS "DLV": true, // Alias for DS "KEY": true, // Alias for DNSKEY "SIG": true, // Alias for RRSIG "OPT": true, // []EDNS0 "PrivateRR": true, // For testing new RRTypes "RFC3597": true, // Unknown/generic rdata "ANY": true, // No rdata } var allowedPackages = map[string]bool{ "net": true, } func main() { conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedDeps} pkgs, err := packages.Load(&conf, "github.com/miekg/dns") if err != nil { panic(err) } scope := pkgs[0].Types.Scope() localTypes := map[string][]Field{} for _, name := range scope.Names() { o := scope.Lookup(name) if o == nil || !o.Exported() { continue } // Only consider structs st, ok := o.Type().Underlying().(*types.Struct) if !ok { continue } name := o.Name() // Explicitly disallow some types that have complex embedded types if _, skip := disallowedTypes[name]; skip { continue } // There must be a type constant for this if scope.Lookup(fmt.Sprintf("Type%s", name)) == nil { continue } fields := []Field{} for i := 0; i < st.NumFields(); i++ { f := st.Field(i) // Exclude header field if f.Name() == "Hdr" { continue } // Fail if there are complex types embedded if tp, ok := f.Type().(*types.Named); ok { if _, ok := allowedPackages[tp.Obj().Pkg().Path()]; !ok { log.Fatalf("Invalid embedded complex type: %s", tp) } } // Also fail if there are complex types embedded in a slice if tp, ok := f.Type().(*types.Slice); ok { if ut, ok := tp.Elem().(*types.Named); ok { if _, ok := allowedPackages[ut.Obj().Pkg().Path()]; !ok { log.Fatalf("Invalid embedded complex type: %s", tp) } } } fields = append(fields, Field{f.Name(), f.Type().String()}) } localTypes[name] = fields } // gofmt it! buf := bytes.Buffer{} tpl.Execute(&buf, localTypes) src, err := format.Source(buf.Bytes()) if err != nil { panic(err) } fp, err := os.Create("zzz_types.go") if err != nil { panic(err) } defer fp.Close() fp.Write(src) }