diff options
Diffstat (limited to 'dns/client.go')
-rw-r--r-- | dns/client.go | 70 |
1 files changed, 69 insertions, 1 deletions
diff --git a/dns/client.go b/dns/client.go index 525444a..f39bb4b 100644 --- a/dns/client.go +++ b/dns/client.go | |||
@@ -1,7 +1,9 @@ | |||
1 | package dns | 1 | package dns |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "context" | ||
4 | "fmt" | 5 | "fmt" |
6 | "time" | ||
5 | 7 | ||
6 | "github.com/miekg/dns" | 8 | "github.com/miekg/dns" |
7 | 9 | ||
@@ -9,7 +11,9 @@ import ( | |||
9 | ) | 11 | ) |
10 | 12 | ||
11 | type DNSClient struct { | 13 | type DNSClient struct { |
12 | Server string | 14 | Server string |
15 | RecursiveResolvers []string | ||
16 | PollTimeout time.Duration | ||
13 | } | 17 | } |
14 | 18 | ||
15 | type DNSTransaction struct { | 19 | type DNSTransaction struct { |
@@ -142,3 +146,67 @@ func (c *DNSClient) SendQuery(t *DNSTransaction) ([]dns.RR, error) { | |||
142 | 146 | ||
143 | return in.Answer, nil | 147 | return in.Answer, nil |
144 | } | 148 | } |
149 | |||
150 | // TODO: Copied from the letsencrypt service, merge this into existing functions | ||
151 | func (c *DNSClient) sendReadQuery(ctx context.Context, fqdn string, rtype uint16, nameserver string) (*dns.Msg, error) { | ||
152 | udp := &dns.Client{Net: "udp"} | ||
153 | tcp := &dns.Client{Net: "tcp"} | ||
154 | |||
155 | m := &dns.Msg{} | ||
156 | m.SetQuestion(fqdn, rtype) | ||
157 | m.SetEdns0(4096, false) | ||
158 | m.RecursionDesired = true | ||
159 | |||
160 | in, _, err := udp.ExchangeContext(ctx, m, nameserver) | ||
161 | if in != nil && in.Truncated { | ||
162 | // If the TCP request succeeds, the err will reset to nil | ||
163 | in, _, err = tcp.ExchangeContext(ctx, m, nameserver) | ||
164 | } | ||
165 | |||
166 | if err != nil { | ||
167 | return nil, err | ||
168 | } | ||
169 | |||
170 | return in, err | ||
171 | } | ||
172 | |||
173 | func (c *DNSClient) WaitForDNSPropagation(ctx context.Context, fqdn, value string) error { | ||
174 | if c.RecursiveResolvers == nil { | ||
175 | return fmt.Errorf("DNSClient.WaitForDNSPropagation: RecursiveResolvers not set") | ||
176 | } | ||
177 | |||
178 | pt := c.PollTimeout | ||
179 | if pt == 0 { | ||
180 | pt = 3 * time.Second | ||
181 | } | ||
182 | |||
183 | timer := time.NewTicker(pt) | ||
184 | defer timer.Stop() | ||
185 | |||
186 | for { | ||
187 | // Give the server the initial timout to satisfy the request | ||
188 | select { | ||
189 | case <-ctx.Done(): | ||
190 | return fmt.Errorf("DNSClient.WaitForDNSPropagation: context has expired, polling terminated") | ||
191 | case <-timer.C: | ||
192 | } | ||
193 | |||
194 | ok_count := 0 | ||
195 | for _, rs := range c.RecursiveResolvers { | ||
196 | r, err := c.sendReadQuery(ctx, fqdn, dns.TypeTXT, rs) | ||
197 | if err != nil { | ||
198 | return err | ||
199 | } | ||
200 | |||
201 | if len(r.Answer) > 0 { | ||
202 | if r.Answer[0].(*dns.TXT).Txt[0] == value { | ||
203 | ok_count++ | ||
204 | } | ||
205 | } | ||
206 | } | ||
207 | |||
208 | if ok_count == len(c.RecursiveResolvers) { | ||
209 | return nil | ||
210 | } | ||
211 | } | ||
212 | } | ||