Fail early for invalid key type and size combinations
[ganeti-github.git] / lib / ssh.py
index d2684fc..7b27214 100644 (file)
@@ -37,6 +37,7 @@ import logging
 import os
 import tempfile
 
+from collections import namedtuple
 from functools import partial
 
 from ganeti import utils
@@ -1094,5 +1095,44 @@ def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
   if result.failed:
     raise errors.OpPrereqError("Could not fetch a public SSH key (%s) from node"
                                " '%s': ran command '%s', failure reason: '%s'."
-                               % (pub_key_file, node, cmd, result.fail_reason))
+                               % (pub_key_file, node, cmd, result.fail_reason),
+                               errors.ECODE_INVAL)
   return result.stdout
+
+
+KeyBitInfo = namedtuple('KeyBitInfo', ['default', 'validation_fn'])
+SSH_KEY_VALID_BITS = {
+  constants.SSHK_DSA: KeyBitInfo(1024, lambda b: b == 1024),
+  constants.SSHK_RSA: KeyBitInfo(2048, lambda b: b >= 768),
+  constants.SSHK_ECDSA: KeyBitInfo(384, lambda b: b in [256, 384, 521]),
+}
+
+
+def DetermineKeyBits(key_type, key_bits, old_key_type, old_key_bits):
+  """Checks the key bits to be used for a given key type, or provides defaults.
+
+  @type key_type: one of L{constants.SSHK_ALL}
+  @param key_type: The key type to use.
+  @type key_bits: positive int or None
+  @param key_bits: The number of bits to use, if supplied by user.
+  @type old_key_type: one of L{constants.SSHK_ALL} or None
+  @param old_key_type: The previously used key type, if any.
+  @type old_key_bits: positive int or None
+  @param old_key_bits: The previously used number of bits, if any.
+
+  @rtype: positive int
+  @return: The number of bits to use.
+
+  """
+  if key_bits is None:
+    if old_key_type is not None and old_key_type == key_type:
+      key_bits = old_key_bits
+    else:
+      key_bits = SSH_KEY_VALID_BITS[key_type].default
+
+  if not SSH_KEY_VALID_BITS[key_type].validation_fn(key_bits):
+    raise errors.OpPrereqError("Invalid key type and bit size combination:"
+                               " %s with %s bits" % (key_type, key_bits),
+                               errors.ECODE_INVAL)
+
+  return key_bits