diff options
Diffstat (limited to 'sshagent.py')
-rw-r--r-- | sshagent.py | 33 |
1 files changed, 19 insertions, 14 deletions
diff --git a/sshagent.py b/sshagent.py index 2a43f5f..32dc2f4 100644 --- a/sshagent.py +++ b/sshagent.py | |||
@@ -10,6 +10,7 @@ import os | |||
10 | import socket | 10 | import socket |
11 | import struct | 11 | import struct |
12 | 12 | ||
13 | from structutils import int_to_bytes | ||
13 | from structutils import pack_string, pack_int | 14 | from structutils import pack_string, pack_int |
14 | from structutils import unpack_int, unpack_string, unpack_mp_int | 15 | from structutils import unpack_int, unpack_string, unpack_mp_int |
15 | 16 | ||
@@ -22,13 +23,14 @@ class SSHAgent(object): | |||
22 | SSH2_AGENT_SIGN_RESPONSE = 14 | 23 | SSH2_AGENT_SIGN_RESPONSE = 14 |
23 | SSH2_AGENTC_SIGN_REQUEST = 13 | 24 | SSH2_AGENTC_SIGN_REQUEST = 13 |
24 | 25 | ||
25 | def __init__(self, socket_path): | 26 | def __init__(self, socket_path=None): |
26 | default_path = os.environ.get('SSH_AUTH_SOCK') | 27 | default_path = os.environ.get('SSH_AUTH_SOCK') |
27 | socket_path = default_path if not socket_path else socket_path | 28 | socket_path = default_path if not socket_path else socket_path |
28 | 29 | ||
29 | if not socket_path: | 30 | if not socket_path: |
30 | raise ValueError("Could not find an ssh agent.") | 31 | raise ValueError("Could not find an ssh agent.") |
31 | 32 | ||
33 | self.socket_path = socket_path | ||
32 | self.socket = None | 34 | self.socket = None |
33 | 35 | ||
34 | def connect(self): | 36 | def connect(self): |
@@ -43,16 +45,29 @@ class SSHAgent(object): | |||
43 | to_send = ''.join([chr(SSHAgent.SSH2_AGENTC_SIGN_REQUEST), | 45 | to_send = ''.join([chr(SSHAgent.SSH2_AGENTC_SIGN_REQUEST), |
44 | key, data, flags]) | 46 | key, data, flags]) |
45 | pkt_length = len(to_send) | 47 | pkt_length = len(to_send) |
46 | packet = pack_int(pkg_length) + to_send | 48 | packet = pack_int(pkt_length) + to_send |
47 | 49 | ||
48 | return packet | 50 | return packet |
49 | 51 | ||
52 | def sign(self, data, key): | ||
53 | if not self.socket: | ||
54 | self.connect() | ||
55 | |||
56 | packet = self._build_packet(data, key) | ||
57 | |||
58 | remaining = 0 | ||
59 | while remaining < len(packet): | ||
60 | sent = self.socket.send(packet[remaining:]) | ||
61 | remaining += sent | ||
62 | |||
63 | return self._parse_response() | ||
64 | |||
50 | def _parse_response(self): | 65 | def _parse_response(self): |
51 | response_length = unpack_int(self.socket.recv(4, socket.MSG_WAITALL))[0] | 66 | response_length = unpack_int(self.socket.recv(4, socket.MSG_WAITALL))[0] |
52 | if response_length == 1: | 67 | if response_length == 1: |
53 | raise ValueError("Agent failed") | 68 | raise ValueError("Agent failed") |
54 | 69 | ||
55 | response = auth_sock.recv(response_length, socket.MSG_WAITALL) | 70 | response = self.socket.recv(response_length, socket.MSG_WAITALL) |
56 | 71 | ||
57 | status = ord(response[0]) | 72 | status = ord(response[0]) |
58 | if status != SSHAgent.SSH2_AGENT_SIGN_RESPONSE: | 73 | if status != SSHAgent.SSH2_AGENT_SIGN_RESPONSE: |
@@ -62,14 +77,4 @@ class SSHAgent(object): | |||
62 | _, remainder = unpack_string(remainder) | 77 | _, remainder = unpack_string(remainder) |
63 | response, _ = unpack_mp_int(remainder) | 78 | response, _ = unpack_mp_int(remainder) |
64 | 79 | ||
65 | return response | 80 | return int_to_bytes(response) |
66 | |||
67 | def sign(self, data, key): | ||
68 | packet = self._build_packet(data, key) | ||
69 | |||
70 | remaining = 0 | ||
71 | while remaining < len(packet): | ||
72 | sent = self.socket.send(packet[remaining:]) | ||
73 | remaining += sent | ||
74 | |||
75 | return self._parse_response() | ||