Fail early for invalid key type and size combinations
[ganeti-github.git] / lib / ssh.py
1 #
2 #
3
4 # Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5 # All rights reserved.
6 #
7 # Redistribution and use in source and binary forms, with or without
8 # modification, are permitted provided that the following conditions are
9 # met:
10 #
11 # 1. Redistributions of source code must retain the above copyright notice,
12 # this list of conditions and the following disclaimer.
13 #
14 # 2. Redistributions in binary form must reproduce the above copyright
15 # notice, this list of conditions and the following disclaimer in the
16 # documentation and/or other materials provided with the distribution.
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
19 # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20 # TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21 # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22 # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23 # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24 # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25 # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30
31 """Module encapsulating ssh functionality.
32
33 """
34
35
36 import logging
37 import os
38 import tempfile
39
40 from collections import namedtuple
41 from functools import partial
42
43 from ganeti import utils
44 from ganeti import errors
45 from ganeti import constants
46 from ganeti import netutils
47 from ganeti import pathutils
48 from ganeti import vcluster
49 from ganeti import compat
50 from ganeti import serializer
51 from ganeti import ssconf
52
53
def GetUserFiles(user, mkdir=False, dircheck=True, kind=constants.SSHK_DSA,
                 _homedir_fn=None):
  """Determine where a user's SSH key pair and authorized_keys file live.

  @type user: string
  @param user: name of the user whose files are looked up
  @type mkdir: bool
  @param mkdir: create the C{.ssh} directory (with mode
    L{constants.SECURE_DIR_MODE}) if it does not exist yet
  @type dircheck: bool
  @param dircheck: when C{mkdir} is not set, require C{~$user/.ssh} to be an
    existing directory
  @type kind: string
  @param kind: One of L{constants.SSHK_ALL}; selects the C{id_*} file names
  @rtype: list of three strings
  @return: the paths of the private SSH key file, the public SSH key file
    and the user's C{authorized_keys} file, in that order
  @raise errors.OpExecError: if the home directory of the user cannot be
    determined, or if C{dircheck} is set, C{mkdir} is not, and
    C{~$user/.ssh} is not a directory
  @raise errors.ProgrammerError: if C{kind} is not a known key kind

  """
  if _homedir_fn is None:
    homedir_fn = utils.GetHomeDir
  else:
    homedir_fn = _homedir_fn

  home = homedir_fn(user)
  if not home:
    raise errors.OpExecError("Cannot resolve home of user '%s'" % user)

  # Map the key kind to the suffix used by OpenSSH's default file names
  for (candidate, suffix) in ((constants.SSHK_DSA, "dsa"),
                              (constants.SSHK_RSA, "rsa"),
                              (constants.SSHK_ECDSA, "ecdsa")):
    if kind == candidate:
      break
  else:
    raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)

  ssh_dir = utils.PathJoin(home, ".ssh")
  if mkdir:
    utils.EnsureDirs([(ssh_dir, constants.SECURE_DIR_MODE)])
  elif dircheck and not os.path.isdir(ssh_dir):
    raise errors.OpExecError("Path %s is not a directory" % ssh_dir)

  file_names = ("id_%s" % suffix, "id_%s.pub" % suffix, "authorized_keys")
  return [utils.PathJoin(ssh_dir, name) for name in file_names]
101
102
def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
  """Collect a user's SSH file paths for every supported key type.

  Calls L{GetUserFiles} once per entry of L{constants.SSHK_ALL}; see there
  for the meaning of the parameters.

  @rtype: tuple; (string, dict with string as key, tuple of (string, string) as
    value)
  @return: the path of the C{authorized_keys} file (asserted to be the same
    for all key types) and a dictionary mapping each key type to its
    (private key path, public key path) pair

  """
  lookup = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
                          _homedir_fn=_homedir_fn)
  per_kind = [(kind, lookup(kind=kind)) for kind in constants.SSHK_ALL]

  auth_paths = []
  key_pairs = {}
  for (kind, (priv_path, pub_path, auth_path)) in per_kind:
    auth_paths.append(auth_path)
    key_pairs[kind] = (priv_path, pub_path)

  assert len(frozenset(auth_paths)) == 1, \
    "Different paths for authorized_keys were returned"

  return (auth_paths[0], key_pairs)
124
125
def _SplitSshKey(key):
  """Normalize one C{authorized_keys} line into comparable fields.

  A line that starts directly with a known key algorithm (see
  L{constants.SSHAK_ALL}) is reduced to its first two fields, the algorithm
  and the key data, so trailing comments and whitespace do not matter when
  comparing. Anything else (e.g. lines with C{command="..."} options in
  front) cannot be split reliably and is kept whole, split at whitespace.

  @type key: string
  @param key: Key line
  @rtype: tuple (bool, list of strings)
  @return: C{(False, [algorithm, key data])} for option-less lines,
    C{(True, all whitespace-separated fields)} otherwise

  """
  fields = key.split()
  keep_whole = not (fields and fields[0] in constants.SSHAK_ALL)
  if keep_whole:
    # Cannot identify the significant parts, compare the complete line
    return (True, fields)
  return (False, fields[:2])
147
148
def AddAuthorizedKeys(file_obj, keys):
  """Append SSH public keys to an authorized_keys file unless already present.

  Existing lines and the new keys are compared via L{_SplitSshKey}, so
  whitespace changes do not produce duplicate entries. Keys not found in the
  file are appended one per line; if the file does not end with a newline,
  one is written first.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file, or a handle that can be
    read and appended to; it is closed before returning
  @type keys: list of str
  @param keys: list of strings containing keys

  """
  pending = [(key, _SplitSshKey(key)) for key in keys]

  if isinstance(file_obj, basestring):
    handle = open(file_obj, "a+")
  else:
    handle = file_obj

  try:
    ends_with_newline = True
    for existing in handle:
      existing_fields = _SplitSshKey(existing)
      # Drop every pending key that is already in the file
      pending = [(key, fields) for (key, fields) in pending
                 if fields != existing_fields]
      ends_with_newline = existing.endswith("\n")

    if not ends_with_newline:
      handle.write("\n")
    for (key, _) in pending:
      handle.write(key.rstrip("\r\n"))
      handle.write("\n")
    handle.flush()
  finally:
    handle.close()
183
184
def HasAuthorizedKey(file_obj, key):
  """Tell whether an authorized_keys file already contains a given key.

  Lines are compared via L{_SplitSshKey}, so whitespace differences are
  ignored.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file, or an open handle; it is
    closed before returning
  @type key: str
  @param key: string containing key
  @rtype: bool
  @return: True if a matching line was found

  """
  wanted = _SplitSshKey(key)

  if isinstance(file_obj, basestring):
    handle = open(file_obj, "r")
  else:
    handle = file_obj

  try:
    return any(_SplitSshKey(line) == wanted for line in handle)
  finally:
    handle.close()
211
212
def CheckForMultipleKeys(file_obj, node_names):
  """Check if there is at most one key per host in 'authorized_keys' file.

  Every line that is neither blank nor a comment is attributed to the first
  of its whitespace-separated chunks that contains an "@" (normally the
  C{user@hostname} comment of the key). Lines without such a chunk cannot be
  attributed to a host and are ignored.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file, or an open handle; it is
    closed before returning
  @type node_names: list of str
  @param node_names: list of names of nodes of the cluster
  @rtype: dict mapping str to list of int
  @return: a dictionary mapping every C{user@hostname} whose hostname is in
    C{node_names} and which occurs more than once to the (1-based) line
    numbers of its occurrences

  """
  if isinstance(file_obj, basestring):
    f = open(file_obj, "r")
  else:
    f = file_obj

  occurrences = {}

  try:
    for (index, line) in enumerate(f, 1):
      stripped = line.strip()
      # Like sshd, skip empty lines and comments (also when indented)
      if not stripped or stripped.startswith("#"):
        continue
      user_hostnames = [chunk for chunk in stripped.split() if "@" in chunk]
      if not user_hostnames:
        # A key without a user@hostname comment cannot be attributed to a
        # host; previously this raised an IndexError
        continue
      occurrences.setdefault(user_hostnames[0], []).append(index)
  finally:
    f.close()

  bad_occurrences = {}
  for (user_hostname, occ) in occurrences.items():
    # The hostname follows the last "@"; a plain split("@") would fail to
    # unpack when the chunk contains more than one "@"
    hostname = user_hostname.rsplit("@", 1)[1]
    if hostname in node_names and len(occ) > 1:
      bad_occurrences[user_hostname] = occ

  return bad_occurrences
253
254
def AddAuthorizedKey(file_obj, key):
  """Append a single SSH public key to an authorized_keys file.

  Convenience wrapper around L{AddAuthorizedKeys}; the key is only written
  if no equivalent line is present yet.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  single_key = [key]
  AddAuthorizedKeys(file_obj, single_key)
265
266
def RemoveAuthorizedKeys(file_name, keys):
  """Drop public SSH keys from an authorized_keys file.

  The remaining lines are written to a temporary file in the same directory,
  which then replaces the original via C{os.rename}. Lines are compared via
  L{_SplitSshKey}, so whitespace differences are ignored. On any error the
  temporary file is removed and the error is re-raised.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type keys: list of str
  @param keys: list of strings containing keys

  """
  unwanted = [_SplitSshKey(key) for key in keys]

  (tmp_fd, tmp_path) = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    with os.fdopen(tmp_fd, "w") as out:
      with open(file_name, "r") as src:
        out.writelines(line for line in src
                       if _SplitSshKey(line) not in unwanted)
        out.flush()
        os.rename(tmp_path, file_name)
  except:
    # Catch everything on purpose: clean up the temporary file, re-raise
    utils.RemoveFile(tmp_path)
    raise
298
299
def RemoveAuthorizedKey(file_name, key):
  """Drop a single SSH public key from an authorized_keys file.

  Convenience wrapper around L{RemoveAuthorizedKeys}.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  single_key = [key]
  RemoveAuthorizedKeys(file_name, single_key)
310
311
312 def _AddPublicKeyProcessLine(new_uuid, new_key, line_uuid, line_key, found):
313 """Processes one line of the public key file when adding a key.
314
315 This is a sub function that can be called within the
316 C{_ManipulatePublicKeyFile} function. It processes one line of the public
317 key file, checks if this line contains the key to add already and if so,
318 notes the occurrence in the return value.
319
320 @type new_uuid: string
321 @param new_uuid: the node UUID of the node whose key is added
322 @type new_key: string
323 @param new_key: the SSH key to be added
324 @type line_uuid: the UUID of the node whose line in the public key file
325 is processed in this function call
326 @param line_key: the SSH key of the node whose line in the public key
327 file is processed in this function call
328 @type found: boolean
329 @param found: whether or not the (UUID, key) pair of the node whose key
330 is being added was found in the public key file already.
331 @rtype: (boolean, string)
332 @return: a possibly updated value of C{found} and the processed line
333
334 """
335 if line_uuid == new_uuid and line_key == new_key:
336 logging.debug("SSH key of node '%s' already in key file.", new_uuid)
337 found = True
338 return (found, "%s %s\n" % (line_uuid, line_key))
339
340
341 def _AddPublicKeyElse(new_uuid, new_key):
342 """Adds a new SSH key to the key file if it did not exist already.
343
344 This is an auxiliary function for C{_ManipulatePublicKeyFile} which
345 is carried out when a new key is added to the public key file and
346 after processing the whole file, we found out that the key does
347 not exist in the file yet but needs to be appended at the end.
348
349 @type new_uuid: string
350 @param new_uuid: the UUID of the node whose key is added
351 @type new_key: string
352 @param new_key: the SSH key to be added
353 @rtype: string
354 @return: a new line to be added to the file
355
356 """
357 return "%s %s\n" % (new_uuid, new_key)
358
359
360 def _RemovePublicKeyProcessLine(
361 target_uuid, _target_key,
362 line_uuid, line_key, found):
363 """Processes a line in the public key file when aiming for removing a key.
364
365 This is an auxiliary function for C{_ManipulatePublicKeyFile} when we
366 are removing a key from the public key file. This particular function
367 only checks if the current line contains the UUID of the node in
368 question and writes the line to the temporary file otherwise.
369
370 @type target_uuid: string
371 @param target_uuid: UUID of the node whose key is being removed
372 @type _target_key: string
373 @param _target_key: SSH key of the node (not used)
374 @type line_uuid: string
375 @param line_uuid: UUID of the node whose line is processed in this call
376 @type line_key: string
377 @param line_key: SSH key of the nodes whose line is processed in this call
378 @type found: boolean
379 @param found: whether or not the UUID was already found.
380 @rtype: (boolean, string)
381 @return: a tuple, indicating if the target line was found and the processed
382 line; the line is 'None', if the original line is removed
383
384 """
385 if line_uuid != target_uuid:
386 return (found, "%s %s\n" % (line_uuid, line_key))
387 else:
388 return (True, None)
389
390
391 def _RemovePublicKeyElse(
392 target_uuid, _target_key):
393 """Logs when we tried to remove a key that does not exist.
394
395 This is an auxiliary function for C{_ManipulatePublicKeyFile} which is
396 run after we have processed the complete public key file and did not find
397 the key to be removed.
398
399 @type target_uuid: string
400 @param target_uuid: the UUID of the node whose key was supposed to be removed
401 @type _target_key: string
402 @param _target_key: the key of the node which was supposed to be removed
403 (not used)
404 @rtype: string
405 @return: in this case, always None
406
407 """
408 logging.debug("Trying to remove key of node '%s' which is not in list"
409 " of public keys.", target_uuid)
410 return None
411
412
413 def _ReplaceNameByUuidProcessLine(
414 node_name, _key, line_identifier, line_key, found, node_uuid=None):
415 """Replaces a node's name with its UUID on a matching line in the key file.
416
417 This is an auxiliary function for C{_ManipulatePublicKeyFile} which processes
418 a line of the ganeti public key file. If the line in question matches the
419 node's name, the name will be replaced by the node's UUID.
420
421 @type node_name: string
422 @param node_name: name of the node to be replaced by the UUID
423 @type _key: string
424 @param _key: SSH key of the node (not used)
425 @type line_identifier: string
426 @param line_identifier: an identifier of a node in a line of the public key
427 file. This can be either a node name or a node UUID, depending on if it
428 got replaced already or not.
429 @type line_key: string
430 @param line_key: SSH key of the node whose line is processed
431 @type found: boolean
432 @param found: whether or not the line matches the node's name
433 @type node_uuid: string
434 @param node_uuid: the node's UUID which will replace the node name
435 @rtype: (boolean, string)
436 @return: a tuple indicating whether the target line was found and the
437 processed line
438
439 """
440 if node_name == line_identifier:
441 return (True, "%s %s\n" % (node_uuid, line_key))
442 else:
443 return (found, "%s %s\n" % (line_identifier, line_key))
444
445
446 def _ReplaceNameByUuidElse(
447 node_uuid, node_name, _key):
448 """Logs a debug message when we try to replace a key that is not there.
449
450 This is an implementation of the auxiliary C{process_else_fn} function for
451 the C{_ManipulatePubKeyFile} function when we use it to replace a line
452 in the public key file that is indexed by the node's name instead of the
453 node's UUID.
454
455 @type node_uuid: string
456 @param node_uuid: the node's UUID
457 @type node_name: string
458 @param node_name: the node's UUID
459 @type _key: string (not used)
460 @param _key: the node's SSH key (not used)
461 @rtype: string
462 @return: in this case, always None
463
464 """
465 logging.debug("Trying to replace node name '%s' with UUID '%s', but"
466 " no line with that name was found.", node_name, node_uuid)
467 return None
468
469
470 def _ParseKeyLine(line, error_fn):
471 """Parses a line of the public key file.
472
473 @type line: string
474 @param line: line of the public key file
475 @type error_fn: function
476 @param error_fn: function to process error messages
477 @rtype: tuple (string, string)
478 @return: a tuple containing the UUID of the node and a string containing
479 the SSH key and possible more parameters for the key
480
481 """
482 if len(line.rstrip()) == 0:
483 return (None, None)
484 chunks = line.split(" ")
485 if len(chunks) < 2:
486 raise error_fn("Error parsing public SSH key file. Line: '%s'"
487 % line)
488 uuid = chunks[0]
489 key = " ".join(chunks[1:]).rstrip()
490 return (uuid, key)
491
492
def _ManipulatePubKeyFile(target_identifier, target_key,
                          key_file=pathutils.SSH_PUB_KEYS,
                          error_fn=errors.ProgrammerError,
                          process_line_fn=None, process_else_fn=None):
  """Manipulates the list of public SSH keys of the cluster.

  This is a general function to manipulate the public key file. It needs
  two auxiliary functions C{process_line_fn} and C{process_else_fn} to
  work. Generally, the public key file is processed as follows:
    1) The function processes each line of the original ganeti public key file,
    applies the C{process_line_fn} function on it, which returns a possibly
    manipulated line and an indicator whether the line in question was found.
    If a line is returned, it is added to a list of lines for later writing
    to the file.
    2) If all lines are processed and the 'found' variable is False, the
    second auxiliary function C{process_else_fn} is called to possibly
    add more lines to the list of lines.
    3) Finally, the list of lines is assembled to a string and written
    to the public key file via L{utils.WriteFile}, replacing its content.

  If the public key file does not exist, we create it. This is necessary for
  a smooth transition after an upgrade.

  @type target_identifier: str
  @param target_identifier: identifier of the node whose key is added; in most
    cases this is the node's UUID, but in some it is the node's host name
  @type target_key: str
  @param target_key: string containing a public SSH key (a complete line
    possibly including more parameters than just the key)
  @type key_file: str
  @param key_file: filename of the file of public node keys (optional
    parameter for testing)
  @type error_fn: function
  @param error_fn: Function that returns an exception, used to customize
    exception types depending on the calling context
  @type process_line_fn: function
  @param process_line_fn: function to process one line of the public key
    file; called as C{fn(target_identifier, target_key, line_uuid, line_key,
    found)} and returning a C{(found, line or None)} tuple
  @type process_else_fn: function
  @param process_else_fn: function to be called if no line of the key file
    matches the target; called as C{fn(target_identifier, target_key)} and
    returning a line to append or None
  @raise errors.SshUpdateError: if a missing key file cannot be created

  """
  assert process_else_fn is not None
  assert process_line_fn is not None

  old_lines = []
  if os.path.exists(key_file):
    # open() is outside any try/finally on purpose: if it fails, its IOError
    # must propagate instead of an AttributeError from closing None
    with open(key_file, "r") as f_orig:
      old_lines = f_orig.readlines()
  else:
    try:
      open(key_file, "w").close()
    except IOError as e:
      raise errors.SshUpdateError("Cannot create public key file: %s" % e)

  found = False
  new_lines = []
  for line in old_lines:
    (uuid, key) = _ParseKeyLine(line, error_fn)
    if not uuid:
      # Blank line, dropped from the rewritten file
      continue
    (new_found, new_line) = process_line_fn(target_identifier, target_key,
                                            uuid, key, found)
    if new_found:
      found = True
    if new_line is not None:
      new_lines.append(new_line)
  if not found:
    new_line = process_else_fn(target_identifier, target_key)
    if new_line is not None:
      new_lines.append(new_line)
  new_file_content = "".join(new_lines)
  utils.WriteFile(key_file, data=new_file_content)
571
572
def AddPublicKey(new_uuid, new_key, key_file=pathutils.SSH_PUB_KEYS,
                 error_fn=errors.ProgrammerError):
  """Record a node's SSH key in the cluster's public key file.

  A line C{"<uuid> <key>"} is appended only if that exact (UUID, key) pair
  is not in the file yet.

  @see: _ManipulatePubKeyFile for parameter descriptions.

  """
  _ManipulatePubKeyFile(new_uuid, new_key,
                        key_file=key_file,
                        error_fn=error_fn,
                        process_line_fn=_AddPublicKeyProcessLine,
                        process_else_fn=_AddPublicKeyElse)
584
585
def RemovePublicKey(target_uuid, key_file=pathutils.SSH_PUB_KEYS,
                    error_fn=errors.ProgrammerError):
  """Drop all lines of a node from the cluster's public key file.

  Every line whose identifier equals C{target_uuid} is removed, whatever
  key it holds; a missing node is only noted in the debug log.

  @see: _ManipulatePubKeyFile for parameter descriptions.

  """
  _ManipulatePubKeyFile(target_uuid, None,
                        key_file=key_file,
                        error_fn=error_fn,
                        process_line_fn=_RemovePublicKeyProcessLine,
                        process_else_fn=_RemovePublicKeyElse)
597
598
def ReplaceNameByUuid(node_uuid, node_name, key_file=pathutils.SSH_PUB_KEYS,
                      error_fn=errors.ProgrammerError):
  """Replaces a host name with the node's corresponding UUID.

  When a node is added to the cluster, we don't know its UUID yet. So first
  its SSH key gets added to the public key file and in a second step, the
  node's name gets replaced with the node's UUID as soon as we know the UUID.
  If no line carries the node's name, no line is added and a debug message
  is logged.

  @type node_uuid: string
  @param node_uuid: the node's UUID to replace the node's name
  @type node_name: string
  @param node_name: the node's name to be replaced by the node's UUID

  @see: _ManipulatePubKeyFile for the other parameter descriptions.

  """
  # The line callback is invoked as fn(node_name, None, line_identifier,
  # line_key, found); the UUID goes into the trailing "node_uuid" keyword
  # of _ReplaceNameByUuidProcessLine.
  process_line_fn = partial(_ReplaceNameByUuidProcessLine, node_uuid=node_uuid)
  # The fallback is invoked as fn(node_name, None), while
  # _ReplaceNameByUuidElse takes (node_uuid, node_name, _key). The UUID must
  # therefore be bound as the first positional argument; binding it as a
  # keyword collided with the first positional argument and raised a
  # TypeError whenever the node name was not present in the file.
  process_else_fn = partial(_ReplaceNameByUuidElse, node_uuid)
  _ManipulatePubKeyFile(node_name, None, key_file=key_file,
                        process_line_fn=process_line_fn,
                        process_else_fn=process_else_fn,
                        error_fn=error_fn)
621
622
def ClearPubKeyFile(key_file=pathutils.SSH_PUB_KEYS, mode=0o600):
  """Truncate the public key file to zero length.

  @type key_file: str
  @param key_file: path of the public key file
  @type mode: int
  @param mode: file permissions passed on to L{utils.WriteFile}

  """
  utils.WriteFile(key_file, mode=mode, data="")
628
629
def OverridePubKeyFile(key_map, key_file=pathutils.SSH_PUB_KEYS):
  """Replace the public key file's content with the given keys.

  One C{"<uuid> <key>"} line is written per key, in the iteration order of
  C{key_map}.

  @type key_map: dict from str to list of str
  @param key_map: dictionary mapping uuids to lists of SSH keys
  @type key_file: str
  @param key_file: path of the public key file

  """
  content = "".join("%s %s\n" % (uuid, key)
                    for (uuid, keys) in key_map.items()
                    for key in keys)
  utils.WriteFile(key_file, data=content)
643
644
def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
                    error_fn=errors.ProgrammerError):
  """Retrieves a map of keys for the requested node UUIDs.

  @type target_uuids: str, list of str or None
  @param target_uuids: UUID of the node to retrieve the key for, a list
    of UUIDs of nodes to retrieve the keys for, or C{None} to retrieve the
    keys of all nodes in the file
  @type key_file: str
  @param key_file: filename of the file of public node keys (optional
    parameter for testing)
  @type error_fn: function
  @param error_fn: Function that returns an exception, used to customize
    exception types depending on the calling context
  @rtype: dict mapping strings to list of strings
  @return: dictionary mapping node uuids to their ssh keys

  """
  all_keys = target_uuids is None
  # Treat byte and unicode strings alike as a single UUID (values from
  # deserialized JSON may be unicode under Python 2); a plain "str" check
  # would treat a unicode UUID as a sequence of characters.
  if isinstance(target_uuids, basestring):
    target_uuids = [target_uuids]
  # Materialize once: O(1) lookups, and one-shot iterables work too
  wanted = None if all_keys else frozenset(target_uuids)

  result = {}
  with open(key_file, "r") as f:
    for line in f:
      (uuid, key) = _ParseKeyLine(line, error_fn)
      if not uuid:
        continue
      if all_keys or uuid in wanted:
        result.setdefault(uuid, []).append(key)
  return result
679
680
def InitSSHSetup(key_type, key_bits, error_fn=errors.OpPrereqError,
                 _homedir_fn=None, _suffix=""):
  """Generate a fresh SSH key pair for the login user and authorize it.

  Existing files at the target paths are backed up with
  L{utils.CreateBackup} and removed. C{ssh-keygen} then creates a new key
  pair without a passphrase, and its public key is added to the user's
  C{authorized_keys} file.

  @type key_type: string
  @param key_type: the type of SSH keypair to be generated, one of
    L{constants.SSHK_ALL}
  @type key_bits: int
  @param key_bits: the key length, in bits, to be used
  @type error_fn: function
  @param error_fn: function returning the exception raised when key
    generation fails
  @param _homedir_fn: passed on to L{GetUserFiles}
  @type _suffix: string
  @param _suffix: appended to the generated key file names
  @raise error_fn: if C{ssh-keygen} fails

  """
  (priv_key, _, auth_keys) = GetUserFiles(constants.SSH_LOGIN_USER,
                                          kind=key_type, mkdir=True,
                                          _homedir_fn=_homedir_fn)

  private_path = priv_key + _suffix
  public_path = private_path + ".pub"

  for path in (private_path, public_path):
    if os.path.exists(path):
      utils.CreateBackup(path)
    utils.RemoveFile(path)

  keygen_cmd = ["ssh-keygen", "-b", str(key_bits), "-t", key_type,
                "-f", private_path,
                "-q", "-N", ""]
  result = utils.RunCmd(keygen_cmd)
  if result.failed:
    raise error_fn("Could not generate ssh keypair, error %s" %
                   result.output)

  AddAuthorizedKey(auth_keys, utils.ReadFile(public_path))
711
712
def InitPubKeyFile(master_uuid, key_type, key_file=pathutils.SSH_PUB_KEYS):
  """Creates the public key file and adds the master node's SSH key.

  @type master_uuid: str
  @param master_uuid: the master node's UUID
  @type key_type: one of L{constants.SSHK_ALL}
  @param key_type: the type of ssh key to be used
  @type key_file: str
  @param key_file: name of the file containing the public keys

  """
  _, pub_key, _ = GetUserFiles(constants.SSH_LOGIN_USER, kind=key_type)
  # Read the master's key before touching the key file, so that a missing
  # or unreadable public key does not leave an emptied key file behind
  key = utils.ReadFile(pub_key)
  ClearPubKeyFile(key_file=key_file)
  AddPublicKey(master_uuid, key, key_file=key_file)
728
729
class SshRunner:
  """Builds and runs ssh/scp command lines for reaching cluster nodes.

  """
  def __init__(self, cluster_name):
    """Initializes this class.

    Whether ssh is told to use IPv4 or IPv6 is decided here, from the
    primary IP family reported by L{ssconf.SimpleStore}.

    @type cluster_name: str
    @param cluster_name: name of the cluster; used as C{HostKeyAlias} when
      the cluster-wide host key is requested

    """
    self.cluster_name = cluster_name
    primary_family = ssconf.SimpleStore().GetPrimaryIPFamily()
    self.ipv6 = (primary_family == netutils.IP6Address.family)

  def _BuildSshOptions(self, batch, ask_key, use_cluster_key,
                       strict_host_check, private_key=None, quiet=True,
                       port=None):
    """Assemble the list of ssh command line options.

    @param batch: run ssh non-interactively (C{BatchMode=yes})
    @param ask_key: let ssh ask for confirmation of unknown host keys
      (C{StrictHostKeyChecking=ask}); cannot be combined with C{batch}
    @param use_cluster_key: if True, use the cluster name as the
      C{HostKeyAlias}
    @param strict_host_check: reject unknown host keys
      (C{StrictHostKeyChecking=yes}) instead of accepting them (C{=no});
      ignored when C{ask_key} is set
    @param private_key: use this private key instead of the default
    @param quiet: whether to enable -q to ssh
    @param port: the SSH port to use, or None to use the default

    @rtype: list
    @return: the list of options ready to use in L{utils.process.RunCmd}
    @raise errors.ProgrammerError: if both C{batch} and C{ask_key} are set

    """
    opts = [
      "-oEscapeChar=none",
      "-oHashKnownHosts=no",
      "-oGlobalKnownHostsFile=%s" % pathutils.SSH_KNOWN_HOSTS_FILE,
      "-oUserKnownHostsFile=/dev/null",
      "-oCheckHostIp=no",
      ]

    if use_cluster_key:
      opts.append("-oHostKeyAlias=%s" % self.cluster_name)
    if quiet:
      opts.append("-q")
    if private_key:
      opts.append("-i%s" % private_key)
    if port:
      opts.append("-oPort=%d" % port)

    # Batch mode never prompts, so asking about host keys is impossible there
    if batch and ask_key:
      raise errors.ProgrammerError("SSH call requested conflicting options")

    if batch:
      opts.append("-oBatchMode=yes")

    # From here on, ask_key can only be set in non-batch mode
    if ask_key:
      host_key_opt = "-oStrictHostKeyChecking=ask"
    elif strict_host_check:
      host_key_opt = "-oStrictHostKeyChecking=yes"
    else:
      host_key_opt = "-oStrictHostKeyChecking=no"
    opts.append(host_key_opt)

    opts.append("-6" if self.ipv6 else "-4")
    return opts

  def BuildCmd(self, hostname, user, command, batch=True, ask_key=False,
               tty=False, use_cluster_key=True, strict_host_check=True,
               private_key=None, quiet=True, port=None):
    """Compose the argv for running a command on a remote node via ssh.

    Environment variables returned by L{vcluster.EnvironmentForHost} for the
    host are exported (shell-quoted) in front of the command.

    @param hostname: the target host, string
    @param user: user to auth as
    @param command: the command to run remotely
    @param batch: if true, ssh will run in batch mode with no prompting
    @param ask_key: if true, ssh will run with
      StrictHostKeyChecking=ask, so that we can connect to an
      unknown host (not valid in batch mode)
    @param tty: if true, force pseudo-terminal allocation (C{-t -t})
    @param use_cluster_key: whether to expect and use the
      cluster-global SSH key
    @param strict_host_check: whether to check the host's SSH key at all
    @param private_key: use this private key instead of the default
    @param quiet: whether to enable -q to ssh
    @param port: the SSH port on which the node's daemon is running

    @rtype: list
    @return: the ssh call to run 'command' on the remote host.

    """
    argv = [constants.SSH]
    argv += self._BuildSshOptions(batch, ask_key, use_cluster_key,
                                  strict_host_check, private_key,
                                  quiet=quiet, port=port)
    if tty:
      argv += ["-t", "-t"]

    argv.append("%s@%s" % (user, hostname))

    host_env = vcluster.EnvironmentForHost(hostname)
    argv += ["export %s=%s;" % (utils.ShellQuote(name), utils.ShellQuote(value))
             for (name, value) in host_env.items()]

    argv.append(command)
    return argv

  def Run(self, *args, **kwargs):
    """Build an ssh command with L{BuildCmd} and execute it.

    All arguments are passed through to L{BuildCmd} unchanged.

    @rtype: L{utils.process.RunResult}
    @return: the result as from L{utils.process.RunCmd()}

    """
    ssh_argv = self.BuildCmd(*args, **kwargs)
    return utils.RunCmd(ssh_argv)

  def CopyFileToNode(self, node, port, filename):
    """Copy a local file to another node with scp.

    The remote path is derived with L{vcluster.ExchangeNodeRoot}; IPv6
    addresses are reformatted with L{netutils.FormatAddress} first.

    @param node: node in the cluster
    @param port: the SSH port of the target node
    @param filename: absolute pathname of a local file

    @rtype: boolean
    @return: the success of the operation

    """
    if not os.path.isabs(filename):
      logging.error("File %s must be an absolute path", filename)
      return False

    if not os.path.isfile(filename):
      logging.error("File %s does not exist", filename)
      return False

    scp_cmd = [constants.SCP, "-p"]
    # batch, no ask_key, cluster host key, strict host key checking
    scp_cmd.extend(self._BuildSshOptions(True, False, True, True, port=port))
    scp_cmd.append(filename)

    target = node
    if netutils.IP6Address.IsValid(node):
      target = netutils.FormatAddress((node, None))

    remote_path = vcluster.ExchangeNodeRoot(target, filename)
    scp_cmd.append("%s:%s" % (target, remote_path))

    result = utils.RunCmd(scp_cmd)
    if result.failed:
      logging.error("Copy to node %s failed (%s) error '%s',"
                    " command was '%s'",
                    target, result.fail_reason, result.output, result.cmd)

    return not result.failed

  def VerifyNodeHostname(self, node, ssh_port):
    """Verify hostname consistency via SSH.

    Runs a small shell snippet on the node (printing C{$GANETI_HOSTNAME} if
    set, C{hostname --fqdn} otherwise) and compares its output with the name
    we connected to. This helps to detect conflicting entries in ssh
    known_hosts files and mismatches between DNS/hosts entries and the local
    machine names.

    @param node: nodename of a host to check; can be short or
      full qualified hostname
    @param ssh_port: the port of a SSH daemon running on the node

    @return: (success, detail), where:
        - success: True/False
        - detail: string with details

    """
    remote_cmd = ("if test -z \"$GANETI_HOSTNAME\"; then"
                  " hostname --fqdn;"
                  "else"
                  " echo \"$GANETI_HOSTNAME\";"
                  "fi")
    res = self.Run(node, constants.SSH_LOGIN_USER, remote_cmd,
                   quiet=False, port=ssh_port)

    if res.failed:
      out_text = res.output
      if out_text:
        detail = ": %s" % out_text
      else:
        detail = ": %s (no output)" % res.fail_reason
      msg = "ssh problem" + detail
      logging.error("Command %s failed: %s", res.cmd, msg)
      return False, msg

    remote_name = res.stdout.strip()
    if remote_name and remote_name == node:
      return True, "host matches"

    if node.startswith(remote_name + "."):
      reason = "hostname not FQDN"
    else:
      reason = "hostname mismatch"
    return False, ("%s: expected %s but got %s" %
                   (reason, node, remote_name))
955
956
def WriteKnownHostsFile(cfg, file_name):
  """Write a known_hosts file holding the cluster-wide host keys.

  One line per available host key (RSA first, then DSA) is written, keyed
  by the cluster name; key types whose key is empty are left out.

  @param cfg: configuration object providing the cluster name and host keys
  @type file_name: str
  @param file_name: path of the known_hosts file to write (mode 0600)

  """
  entries = []
  for (get_key, line_fmt) in ((cfg.GetRsaHostKey, "%s ssh-rsa %s\n"),
                              (cfg.GetDsaHostKey, "%s ssh-dss %s\n")):
    if get_key():
      entries.append(line_fmt % (cfg.GetClusterName(), get_key()))

  utils.WriteFile(file_name, mode=0o600, data="".join(entries))
968
969
def _EnsureCorrectGanetiVersion(cmd):
  """Prefixes a command with steps that select this node's Ganeti version.

  Before a command is run on a node via SSH, it makes sense in some
  situations to ensure that the node is running the same Ganeti version
  as the rest of the cluster. The returned list therefore starts with a
  C{test -d} on the version directory below L{pathutils.PKGLIBDIR}. It
  then re-points the C{ganeti/lib} and C{ganeti/share} symlinks below
  L{pathutils.SYSCONFDIR} at this version's directories, and ends with
  C{cmd} itself.

  @type cmd: list of strings
  @param cmd: the command, as an argument list, to run after the
      preparatory steps
  @rtype: list of lists of strings
  @return: the preparatory commands followed by C{cmd}

  """
  logging.debug("Ensure correct Ganeti version: %s", cmd)

  version = constants.DIR_VERSION
  lib_target = os.path.join(pathutils.PKGLIBDIR, version)
  share_target = os.path.join(pathutils.SHAREDIR, version)
  lib_link = os.path.join(pathutils.SYSCONFDIR, "ganeti/lib")
  share_link = os.path.join(pathutils.SYSCONFDIR, "ganeti/share")

  commands = [["test", "-d", lib_target]]
  for (target, link) in [(lib_target, lib_link), (share_target, share_link)]:
    if constants.HAS_GNU_LN:
      # GNU ln's -T replaces an existing link name in one step
      commands.append(["ln", "-s", "-f", "-T", target, link])
    else:
      # Without -T, remove the old link first and then create the new one
      commands.append(["rm", "-f", link])
      commands.append(["ln", "-s", "-f", target, link])

  commands.append(cmd)
  return commands
1007
1008
def RunSshCmdWithStdin(cluster_name, node, basecmd, port, data,
                       debug=False, verbose=False, use_cluster_key=False,
                       ask_key=False, strict_host_check=False,
                       ensure_version=False):
  """Runs a command on a remote node via SSH, passing JSON data on stdin.

  @type cluster_name: string
  @param cluster_name: Cluster name
  @type node: string
  @param node: Node name
  @type basecmd: string
  @param basecmd: Base command (path on the remote machine)
  @type port: int
  @param port: The SSH port of the remote machine or None for the default
  @param data: JSON-serializable input data for script (passed to stdin)
  @type debug: bool
  @param debug: Pass C{--debug} to the remote script
  @type verbose: bool
  @param verbose: Pass C{--verbose} to the remote script
  @type use_cluster_key: bool
  @param use_cluster_key: See L{ssh.SshRunner.BuildCmd}
  @type ask_key: bool
  @param ask_key: See L{ssh.SshRunner.BuildCmd}
  @type strict_host_check: bool
  @param strict_host_check: See L{ssh.SshRunner.BuildCmd}
  @type ensure_version: bool
  @param ensure_version: If set, run the steps produced by
      L{_EnsureCorrectGanetiVersion} before the script
  @raise errors.OpExecError: if the remote command fails

  """
  # Forward our own --debug/--verbose settings to the external script
  flag_settings = [("--debug", debug), ("--verbose", verbose)]
  script_cmd = [basecmd] + [flag for (flag, wanted) in flag_settings if wanted]

  if ensure_version:
    cmd_sequence = _EnsureCorrectGanetiVersion(script_cmd)
  else:
    cmd_sequence = [script_cmd]

  if port is None:
    port = netutils.GetDaemonPort(constants.SSH)

  runner = SshRunner(cluster_name)
  remote_shell_cmd = utils.ShellQuoteArgs(
    utils.ShellCombineCommands(cmd_sequence))
  ssh_cmd = runner.BuildCmd(node, constants.SSH_LOGIN_USER, remote_shell_cmd,
                            batch=False, ask_key=ask_key, quiet=False,
                            strict_host_check=strict_host_check,
                            use_cluster_key=use_cluster_key,
                            port=port)

  # Stage the serialized data in a temporary file, rewind it and hand it to
  # the command as its input descriptor; the file is closed on exit
  with tempfile.TemporaryFile() as stdin_file:
    stdin_file.write(serializer.DumpJson(data))
    stdin_file.seek(0)
    result = utils.RunCmd(ssh_cmd, interactive=True, input_fd=stdin_file)

  if result.failed:
    raise errors.OpExecError("Command '%s' failed: %s" %
                             (result.cmd, result.fail_reason))
1074
1075
def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
                         strict_host_check):
  """Fetches the contents of a public SSH key file from a node via SSH.

  The file is read by running C{cat} on the node as
  L{constants.SSH_LOGIN_USER}, without using the cluster's SSH key.

  @type pub_key_file: string
  @param pub_key_file: path, on the remote node, of the public key file
  @type node: string
  @param node: name of the node to read the key from
  @type cluster_name: string
  @param cluster_name: cluster name, used to construct the L{SshRunner}
  @param port: SSH port of the remote node (passed unchanged to
      L{SshRunner.BuildCmd})
  @type ask_key: bool
  @param ask_key: see L{SshRunner.BuildCmd}
  @type strict_host_check: bool
  @param strict_host_check: see L{SshRunner.BuildCmd}
  @rtype: string
  @return: the standard output of the remote C{cat}, i.e. the file contents
  @raise errors.OpPrereqError: if the remote command fails

  """
  ssh_runner = SshRunner(cluster_name)

  cmd = ["cat", pub_key_file]
  quoted_cmd = utils.ShellQuoteArgs(cmd)
  ssh_cmd = ssh_runner.BuildCmd(node, constants.SSH_LOGIN_USER, quoted_cmd,
                                batch=False, ask_key=ask_key, quiet=False,
                                strict_host_check=strict_host_check,
                                use_cluster_key=False,
                                port=port)

  result = utils.RunCmd(ssh_cmd)
  if result.failed:
    # Report the remote command as a shell string; interpolating the list
    # itself produced a Python repr nested inside the quotes
    raise errors.OpPrereqError("Could not fetch a public SSH key (%s) from node"
                               " '%s': ran command '%s', failure reason: '%s'."
                               % (pub_key_file, node, quoted_cmd,
                                  result.fail_reason),
                               errors.ECODE_INVAL)
  return result.stdout
1101
1102
# Per SSH key type: the bit size used when the caller requests none
# ("default") and a predicate accepting or rejecting a requested size
# ("validation_fn").
KeyBitInfo = namedtuple("KeyBitInfo", ["default", "validation_fn"])

SSH_KEY_VALID_BITS = {
  # DSA: exactly 1024 bits
  constants.SSHK_DSA: KeyBitInfo(default=1024,
                                 validation_fn=lambda bits: bits == 1024),
  # RSA: at least 768 bits, 2048 unless requested otherwise
  constants.SSHK_RSA: KeyBitInfo(default=2048,
                                 validation_fn=lambda bits: bits >= 768),
  # ECDSA: one of 256, 384 or 521 bits, 384 unless requested otherwise
  constants.SSHK_ECDSA: KeyBitInfo(
    default=384,
    validation_fn=lambda bits: bits in (256, 384, 521)),
  }
1109
1110
def DetermineKeyBits(key_type, key_bits, old_key_type, old_key_bits):
  """Checks the key bits to be used for a given key type, or provides defaults.

  If C{key_bits} is not given, the previously used size is reused when the
  key type is unchanged and that size is known. Otherwise the default size
  for C{key_type} from L{SSH_KEY_VALID_BITS} is used. The resulting size is
  then validated against C{key_type}.

  @type key_type: one of L{constants.SSHK_ALL}
  @param key_type: The key type to use.
  @type key_bits: positive int or None
  @param key_bits: The number of bits to use, if supplied by user.
  @type old_key_type: one of L{constants.SSHK_ALL} or None
  @param old_key_type: The previously used key type, if any.
  @type old_key_bits: positive int or None
  @param old_key_bits: The previously used number of bits, if any.

  @rtype: positive int
  @return: The number of bits to use.
  @raise errors.OpPrereqError: if C{key_type} is unknown, or the resulting
      number of bits is not valid for C{key_type}

  """
  bit_info = SSH_KEY_VALID_BITS.get(key_type)
  if bit_info is None:
    # Fail early with a proper error instead of a bare KeyError
    raise errors.OpPrereqError("Unknown SSH key type: %s" % (key_type, ),
                               errors.ECODE_INVAL)

  if key_bits is None:
    # Only reuse the old size if it is actually known; a None here would
    # otherwise reach the validation function and be rejected
    if old_key_type == key_type and old_key_bits is not None:
      key_bits = old_key_bits
    else:
      key_bits = bit_info.default

  if not bit_info.validation_fn(key_bits):
    raise errors.OpPrereqError("Invalid key type and bit size combination:"
                               " %s with %s bits" % (key_type, key_bits),
                               errors.ECODE_INVAL)

  return key_bits