Fail early for invalid key type and size combinations
authorHrvoje Ribicic <riba@google.com>
Fri, 6 Nov 2015 01:35:51 +0000 (02:35 +0100)
committerHrvoje Ribicic <riba@google.com>
Fri, 20 Nov 2015 10:14:19 +0000 (11:14 +0100)
The ssh-keygen utility permits only some combinations of key types and
bit sizes. As many more things can go wrong late in the renewal
process, this patch introduces prerequisite checks mimicking those of
ssh-keygen.

Signed-off-by: Hrvoje Ribicic <riba@google.com>
Reviewed-by: Helga Velroyen <helgav@google.com>

lib/client/gnt_cluster.py
lib/cmdlib/cluster/__init__.py
lib/ssh.py
test/py/ganeti.ssh_unittest.py

index 93ab24c..962be74 100644 (file)
@@ -304,10 +304,8 @@ def InitCluster(opts, args):
   else:
     ssh_key_type = constants.SSH_DEFAULT_KEY_TYPE
 
-  if opts.ssh_key_bits:
-    ssh_key_bits = opts.ssh_key_bits
-  else:
-    ssh_key_bits = constants.SSH_DEFAULT_KEY_BITS
+  ssh_key_bits = ssh.DetermineKeyBits(ssh_key_type, opts.ssh_key_bits, None,
+                                      None)
 
   bootstrap.InitCluster(cluster_name=args[0],
                         secondary_ip=opts.secondary_ip,
index 5658646..3147f96 100644 (file)
@@ -87,6 +87,23 @@ class LUClusterRenewCrypto(NoHooksLU):
     self.share_locks = ShareAll()
     self.share_locks[locking.LEVEL_NODE] = 0
 
+  def CheckPrereq(self):
+    """Check prerequisites.
+
+    Notably the compatibility of specified key bits and key type.
+
+    """
+    cluster_info = self.cfg.GetClusterInfo()
+
+    self.ssh_key_type = self.op.ssh_key_type
+    if self.ssh_key_type is None:
+      self.ssh_key_type = cluster_info.ssh_key_type
+
+    self.ssh_key_bits = ssh.DetermineKeyBits(self.ssh_key_type,
+                                             self.op.ssh_key_bits,
+                                             cluster_info.ssh_key_type,
+                                             cluster_info.ssh_key_bits)
+
   def _RenewNodeSslCertificates(self, feedback_fn):
     """Renews the nodes' SSL certificates.
 
@@ -167,28 +184,20 @@ class LUClusterRenewCrypto(NoHooksLU):
 
     cluster_info = self.cfg.GetClusterInfo()
 
-    new_ssh_key_type = self.op.ssh_key_type
-    if new_ssh_key_type is None:
-      new_ssh_key_type = cluster_info.ssh_key_type
-
-    new_ssh_key_bits = self.op.ssh_key_bits
-    if new_ssh_key_bits is None:
-      new_ssh_key_bits = cluster_info.ssh_key_bits
-
     result = self.rpc.call_node_ssh_keys_renew(
       [master_uuid],
       node_uuids, node_names,
       master_candidate_uuids,
       potential_master_candidates,
       cluster_info.ssh_key_type, # Old key type
-      new_ssh_key_type,          # New key type
-      new_ssh_key_bits)          # New key bits
+      self.ssh_key_type,         # New key type
+      self.ssh_key_bits)         # New key bits
     result[master_uuid].Raise("Could not renew the SSH keys of all nodes")
 
     # After the keys have been successfully swapped, time to commit the change
     # in key type
-    cluster_info.ssh_key_type = new_ssh_key_type
-    cluster_info.ssh_key_bits = new_ssh_key_bits
+    cluster_info.ssh_key_type = self.ssh_key_type
+    cluster_info.ssh_key_bits = self.ssh_key_bits
     self.cfg.Update(cluster_info, feedback_fn)
 
   def Exec(self, feedback_fn):
index d2684fc..7b27214 100644 (file)
@@ -37,6 +37,7 @@ import logging
 import os
 import tempfile
 
+from collections import namedtuple
 from functools import partial
 
 from ganeti import utils
@@ -1094,5 +1095,44 @@ def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
   if result.failed:
     raise errors.OpPrereqError("Could not fetch a public SSH key (%s) from node"
                                " '%s': ran command '%s', failure reason: '%s'."
-                               % (pub_key_file, node, cmd, result.fail_reason))
+                               % (pub_key_file, node, cmd, result.fail_reason),
+                               errors.ECODE_INVAL)
   return result.stdout
+
+
+KeyBitInfo = namedtuple('KeyBitInfo', ['default', 'validation_fn'])
+SSH_KEY_VALID_BITS = {
+  constants.SSHK_DSA: KeyBitInfo(1024, lambda b: b == 1024),
+  constants.SSHK_RSA: KeyBitInfo(2048, lambda b: b >= 768),
+  constants.SSHK_ECDSA: KeyBitInfo(384, lambda b: b in [256, 384, 521]),
+}
+
+
+def DetermineKeyBits(key_type, key_bits, old_key_type, old_key_bits):
+  """Checks the key bits to be used for a given key type, or provides defaults.
+
+  @type key_type: one of L{constants.SSHK_ALL}
+  @param key_type: The key type to use.
+  @type key_bits: positive int or None
+  @param key_bits: The number of bits to use, if supplied by user.
+  @type old_key_type: one of L{constants.SSHK_ALL} or None
+  @param old_key_type: The previously used key type, if any.
+  @type old_key_bits: positive int or None
+  @param old_key_bits: The previously used number of bits, if any.
+
+  @rtype: positive int
+  @return: The number of bits to use.
+
+  """
+  if key_bits is None:
+    if old_key_type is not None and old_key_type == key_type:
+      key_bits = old_key_bits
+    else:
+      key_bits = SSH_KEY_VALID_BITS[key_type].default
+
+  if not SSH_KEY_VALID_BITS[key_type].validation_fn(key_bits):
+    raise errors.OpPrereqError("Invalid key type and bit size combination:"
+                               " %s with %s bits" % (key_type, key_bits),
+                               errors.ECODE_INVAL)
+
+  return key_bits
index b13dda1..265adec 100755 (executable)
@@ -488,5 +488,37 @@ class TestGetUserFiles(testutils.GanetiTestCase):
     self.assertTrue(os.path.exists(self.priv_filename + suffix + ".pub"))
 
 
+class TestDetermineKeyBits():
+  def testCompleteness(self):
+    self.assertEquals(constants.SSHK_ALL, ssh.SSH_KEY_VALID_BITS.keys())
+
+  def testAdoptDefault(self):
+    self.assertEquals(2048, DetermineKeyBits("rsa", None, None, None))
+    self.assertEquals(1024, DetermineKeyBits("dsa", None, None, None))
+
+  def testAdoptOldKeySize(self):
+    self.assertEquals(4098, DetermineKeyBits("rsa", None, "rsa", 4098))
+    self.assertEquals(2048, DetermineKeyBits("rsa", None, "dsa", 1024))
+
+  def testDsaSpecificValues(self):
+    self.assertRaises(errors.OpPrereqError, DetermineKeyBits, "dsa", 2048,
+                      None, None)
+    self.assertRaises(errors.OpPrereqError, DetermineKeyBits, "dsa", 512,
+                      None, None)
+    self.assertEquals(1024, DetermineKeyBits("dsa", None, None, None))
+
+  def testEcdsaSpecificValues(self):
+    self.assertRaises(errors.OpPrereqError, DetermineKeyBits, "ecdsa", 2048,
+                      None, None)
+    for b in [256, 384, 521]:
+      self.assertEquals(b, DetermineKeyBits("ecdsa", b, None, None))
+
+  def testRsaSpecificValues(self):
+    self.assertRaises(errors.OpPrereqError, DetermineKeyBits, "dsa", 766,
+                      None, None)
+    for b in [768, 769, 2048, 2049, 4096]:
+      self.assertEquals(b, DetermineKeyBits("rsa", b, None, None))
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()