summaryrefslogtreecommitdiff
path: root/generate_dns_types.go
diff options
context:
space:
mode:
Diffstat (limited to 'generate_dns_types.go')
-rw-r--r--generate_dns_types.go192
1 files changed, 192 insertions, 0 deletions
diff --git a/generate_dns_types.go b/generate_dns_types.go
new file mode 100644
index 0000000..282c630
--- /dev/null
+++ b/generate_dns_types.go
@@ -0,0 +1,192 @@
1//+build ignore
2
3package main
4
5import (
6 "fmt"
7 "go/types"
8 "log"
9 "os"
10 "text/template"
11
12 "golang.org/x/tools/go/packages"
13)
14
15type Field struct {
16 Name string
17 Type string
18}
19
20var tpl = template.Must(template.New("").Parse(`package dns
21
22// GENERATED FILE, DO NOT MODIFY
23// See generate_dns_types.go in the repo root.
24
25import (
26 "encoding/json"
27 "fmt"
28 "net"
29
30 "github.com/miekg/dns"
31
32 "code.crute.me/mcrute/go_ddns_manager/bind"
33)
34
35{{ range $name, $fields := . -}}
36type {{ $name }} struct {
37 Name string
38 Ttl int
39 {{ range $fields -}}
40 {{ .Name }} {{ .Type }}
41 {{ end -}}
42}
43
44func (r *{{ $name }}) ToDNS(zone *bind.Zone) dns.RR {
45 return &dns.{{ $name }}{
46 Hdr: makeHeader(r.Name, zone, dns.Type{{ $name }}, r.Ttl),
47 {{ range $fields -}}
48 {{ .Name }}: r.{{ .Name }},
49 {{ end }}
50 }
51}
52
53func (r *{{ $name }}) FromDNS(rr dns.RR) error {
54 rt, ok := rr.(*dns.{{ $name }})
55 if !ok {
56 return fmt.Errorf("Invalid type %T for '{{ $name }}'", rr)
57 }
58
59 r.Name = rr.Header().Name
60 r.Ttl = int(rr.Header().Ttl)
61 {{ range $fields -}}
62 r.{{ .Name }} = rt.{{ .Name }}
63 {{ end }}
64
65 return nil
66}
67
68func (r *{{ $name }}) MarshalJSON() ([]byte, error) {
69 type Alias {{ $name }}
70 return json.Marshal(&struct {
71 Type string
72 *Alias
73 }{"{{ $name }}", (*Alias)(r)})
74}
75
76func (r *{{ $name }}) UnmarshalJSON(data []byte) error {
77 type Alias {{ $name }}
78 if err := json.Unmarshal(data, &struct{
79 Type string
80 *Alias
81 }{Alias: (*Alias)(r)}); err != nil {
82 return err
83 }
84 return nil
85}
86
87var _ RR = (*{{ $name }})(nil)
88
89{{ end }}
90
91
92func FromDNS(rr dns.RR) interface{} {
93 switch v := rr.(type) {
94 {{ range $name, $fields := . -}}
95 case *dns.{{ $name }}:
96 rv := &{{ $name }}{}
97 rv.FromDNS(v)
98 return rv
99 {{ end }}
100 }
101 return nil
102}
103`))
104
105var disallowedTypes = map[string]bool{
106 "CDNSKEY": true,
107 "CDS": true,
108 "DLV": true,
109 "KEY": true,
110 "OPT": true,
111 "SIG": true,
112 "PrivateRR": true,
113 "RFC3597": true,
114 "ANY": true,
115}
116
117var allowedPackages = map[string]bool{
118 "net": true,
119}
120
121func main() {
122 conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedDeps}
123 pkgs, err := packages.Load(&conf, "github.com/miekg/dns")
124 if err != nil {
125 panic(err)
126 }
127
128 scope := pkgs[0].Types.Scope()
129 localTypes := map[string][]Field{}
130
131 for _, name := range scope.Names() {
132 o := scope.Lookup(name)
133 if o == nil || !o.Exported() {
134 continue
135 }
136
137 // Only consider structs
138 st, ok := o.Type().Underlying().(*types.Struct)
139 if !ok {
140 continue
141 }
142
143 name := o.Name()
144
145 // Explicitly disallow some types that have complex embedded types
146 if _, skip := disallowedTypes[name]; skip {
147 continue
148 }
149
150 // There must be a type constant for this
151 if scope.Lookup(fmt.Sprintf("Type%s", name)) == nil {
152 continue
153 }
154
155 fields := []Field{}
156 for i := 0; i < st.NumFields(); i++ {
157 f := st.Field(i)
158
159 // Exclude header field
160 if f.Name() == "Hdr" {
161 continue
162 }
163
164 // Fail if there are complex types embedded
165 if tp, ok := f.Type().(*types.Named); ok {
166 if _, ok := allowedPackages[tp.Obj().Pkg().Path()]; !ok {
167 log.Fatalf("Invalid embedded complex type: %s", tp)
168 }
169 }
170
171 // Also fail if there are complex types embedded in a slice
172 if tp, ok := f.Type().(*types.Slice); ok {
173 if ut, ok := tp.Elem().(*types.Named); ok {
174 if _, ok := allowedPackages[ut.Obj().Pkg().Path()]; !ok {
175 log.Fatalf("Invalid embedded complex type: %s", tp)
176 }
177 }
178 }
179
180 fields = append(fields, Field{f.Name(), f.Type().String()})
181 }
182
183 localTypes[name] = fields
184 }
185
186 fp, err := os.Create("zzz_types.go")
187 if err != nil {
188 panic(err)
189 }
190 defer fp.Close()
191 tpl.Execute(fp, localTypes)
192}