Use the SSH key parameters when generating keys
authorHrvoje Ribicic <riba@google.com>
Tue, 13 Oct 2015 16:05:18 +0000 (12:05 -0400)
committerHrvoje Ribicic <riba@google.com>
Fri, 20 Nov 2015 10:14:12 +0000 (11:14 +0100)
This patch makes sure that the parameters introduced in previous
patches propagates wherever SSH keys are generated and used, allowing
Ganeti to use different types of SSH keys. With tis patch, the key type
can be set only at cluster initialization time.

Signed-off-by: Hrvoje Ribicic <riba@google.com>
Reviewed-by: Helga Velroyen <helgav@google.com>

17 files changed:
lib/backend.py
lib/bootstrap.py
lib/client/gnt_cluster.py
lib/client/gnt_node.py
lib/cmdlib/cluster/__init__.py
lib/cmdlib/cluster/verify.py
lib/rpc_defs.py
lib/server/noded.py
lib/ssh.py
lib/tools/common.py
lib/tools/prepare_node_join.py
lib/tools/ssh_update.py
src/Ganeti/Constants.hs
test/py/ganeti.backend_unittest.py
test/py/ganeti.client.gnt_cluster_unittest.py
test/py/ganeti.ssh_unittest.py
test/py/ganeti.tools.prepare_node_join_unittest.py

index 64b55e0..19409e5 100644 (file)
@@ -967,8 +967,8 @@ def _VerifyClientCertificate(cert_file=pathutils.NODED_CLIENT_CERT_FILE):
   return (None, utils.GetCertificateDigest(cert_filename=cert_file))
 
 
-def _VerifySshSetup(node_status_list, my_name,
-                    pub_key_file=pathutils.SSH_PUB_KEYS):
+def _VerifySshSetup(node_status_list, my_name, ssh_key_type,
+                    ganeti_pub_keys_file=pathutils.SSH_PUB_KEYS):
   """Verifies the state of the SSH key files.
 
   @type node_status_list: list of tuples
@@ -977,8 +977,10 @@ def _VerifySshSetup(node_status_list, my_name,
     is_potential_master_candidate, online)
   @type my_name: str
   @param my_name: name of this node
-  @type pub_key_file: str
-  @param pub_key_file: filename of the public key file
+  @type ssh_key_type: one of L{constants.SSHK_ALL}
+  @param ssh_key_type: type of key used on nodes
+  @type ganeti_pub_keys_file: str
+  @param ganeti_pub_keys_file: filename of the public keys file
 
   """
   if node_status_list is None:
@@ -994,16 +996,16 @@ def _VerifySshSetup(node_status_list, my_name,
 
   result = []
 
-  if not os.path.exists(pub_key_file):
+  if not os.path.exists(ganeti_pub_keys_file):
     result.append("The public key file '%s' does not exist. Consider running"
                   " 'gnt-cluster renew-crypto --new-ssh-keys"
-                  " [--no-ssh-key-check]' to fix this." % pub_key_file)
+                  " [--no-ssh-key-check]' to fix this." % ganeti_pub_keys_file)
     return result
 
   pot_mc_uuids = [uuid for (uuid, _, _, _, _) in node_status_list]
   offline_nodes = [uuid for (uuid, _, _, _, online) in node_status_list
                    if not online]
-  pub_keys = ssh.QueryPubKeyFile(None)
+  pub_keys = ssh.QueryPubKeyFile(None, key_file=ganeti_pub_keys_file)
 
   if potential_master_candidate:
     # Check that the set of potential master candidates matches the
@@ -1026,14 +1028,14 @@ def _VerifySshSetup(node_status_list, my_name,
 
     (_, key_files) = \
       ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
-    (_, dsa_pub_key_filename) = key_files[constants.SSHK_DSA]
+    (_, node_pub_key_file) = key_files[ssh_key_type]
 
     my_keys = pub_keys[my_uuid]
 
-    dsa_pub_key = utils.ReadFile(dsa_pub_key_filename)
-    if dsa_pub_key.strip() not in my_keys:
+    node_pub_key = utils.ReadFile(node_pub_key_file)
+    if node_pub_key.strip() not in my_keys:
       result.append("The dsa key of node %s does not match this node's key"
-                    " in the pub key file." % (my_name))
+                    " in the pub key file." % my_name)
     if len(my_keys) != 1:
       result.append("There is more than one key for node %s in the public key"
                     " file." % my_name)
@@ -1152,8 +1154,9 @@ def VerifyNode(what, cluster_name, all_hvparams):
     result[constants.NV_CLIENT_CERT] = _VerifyClientCertificate()
 
   if constants.NV_SSH_SETUP in what:
+    node_status_list, key_type = what[constants.NV_SSH_SETUP]
     result[constants.NV_SSH_SETUP] = \
-      _VerifySshSetup(what[constants.NV_SSH_SETUP], my_name)
+      _VerifySshSetup(node_status_list, my_name, key_type)
     if constants.NV_SSH_CLUTTER in what:
       result[constants.NV_SSH_CLUTTER] = \
         _VerifySshClutter(what[constants.NV_SSH_SETUP], my_name)
@@ -1774,8 +1777,8 @@ def RemoveNodeSshKey(node_uuid, node_name,
   return result_msgs
 
 
-def _GenerateNodeSshKey(node_uuid, node_name, ssh_port_map,
-                        pub_key_file=pathutils.SSH_PUB_KEYS,
+def _GenerateNodeSshKey(node_uuid, node_name, ssh_port_map, ssh_key_type,
+                        ssh_key_bits, pub_key_file=pathutils.SSH_PUB_KEYS,
                         ssconf_store=None,
                         noded_cert_file=pathutils.NODED_CERT_FILE,
                         run_cmd_fn=ssh.RunSshCmdWithStdin,
@@ -1788,6 +1791,10 @@ def _GenerateNodeSshKey(node_uuid, node_name, ssh_port_map,
   @param node_name: name of the node whose key is remove
   @type ssh_port_map: dict of str to int
   @param ssh_port_map: mapping of node names to their SSH port
+  @type ssh_key_type: One of L{constants.SSHK_ALL}
+  @param ssh_key_type: the type of SSH key to be generated
+  @type ssh_key_bits: int
+  @param ssh_key_bits: the length of the key to be generated
 
   """
   if not ssconf_store:
@@ -1802,7 +1809,7 @@ def _GenerateNodeSshKey(node_uuid, node_name, ssh_port_map,
   data = {}
   _InitSshUpdateData(data, noded_cert_file, ssconf_store)
   cluster_name = data[constants.SSHS_CLUSTER_NAME]
-  data[constants.SSHS_GENERATE] = {constants.SSHS_SUFFIX: suffix}
+  data[constants.SSHS_GENERATE] = (ssh_key_type, ssh_key_bits, suffix)
 
   run_cmd_fn(cluster_name, node_name, pathutils.SSH_UPDATE,
              ssh_port_map.get(node_name), data,
@@ -1877,8 +1884,8 @@ def _ReplaceMasterKeyOnMaster(root_keyfiles):
 
 
 def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
-                 potential_master_candidates,
-                 pub_key_file=pathutils.SSH_PUB_KEYS,
+                 potential_master_candidates, ssh_key_type, ssh_key_bits,
+                 ganeti_pub_keys_file=pathutils.SSH_PUB_KEYS,
                  ssconf_store=None,
                  noded_cert_file=pathutils.NODED_CERT_FILE,
                  run_cmd_fn=ssh.RunSshCmdWithStdin):
@@ -1892,8 +1899,12 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
   @type master_candidate_uuids: list of str
   @param master_candidate_uuids: list of UUIDs of master candidates or
     master node
-  @type pub_key_file: str
-  @param pub_key_file: file path of the the public key file
+  @type ssh_key_type: One of L{constants.SSHK_ALL}
+  @param ssh_key_type: the type of SSH key to be generated
+  @type ssh_key_bits: int
+  @param ssh_key_bits: the length of the key to be generated
+  @type ganeti_pub_keys_file: str
+  @param ganeti_pub_keys_file: file path of the the public key file
   @type noded_cert_file: str
   @param noded_cert_file: path of the noded SSL certificate file
   @type run_cmd_fn: function
@@ -1915,8 +1926,8 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
 
   (_, root_keyfiles) = \
     ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
-  (_, dsa_pub_keyfile) = root_keyfiles[constants.SSHK_DSA]
-  old_master_key = utils.ReadFile(dsa_pub_keyfile)
+  (_, node_pub_keyfile) = root_keyfiles[ssh_key_type]
+  old_master_key = utils.ReadFile(node_pub_keyfile)
 
   node_uuid_name_map = zip(node_uuids, node_names)
 
@@ -1935,7 +1946,8 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
     master_candidate = node_uuid in master_candidate_uuids
     potential_master_candidate = node_name in potential_master_candidates
 
-    keys_by_uuid = ssh.QueryPubKeyFile([node_uuid], key_file=pub_key_file)
+    keys_by_uuid = ssh.QueryPubKeyFile([node_uuid],
+                                       key_file=ganeti_pub_keys_file)
     if not keys_by_uuid:
       raise errors.SshUpdateError("No public key of node %s (UUID %s) found,"
                                   " not generating a new key."
@@ -1943,7 +1955,7 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
 
     if master_candidate:
       logging.debug("Fetching old SSH key from node '%s'.", node_name)
-      old_pub_key = ssh.ReadRemoteSshPubKeys(dsa_pub_keyfile,
+      old_pub_key = ssh.ReadRemoteSshPubKeys(node_pub_keyfile,
                                              node_name, cluster_name,
                                              ssh_port_map[node_name],
                                              False, # ask_key
@@ -1968,15 +1980,15 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
                       " key. Not deleting that key on the node.", node_name)
 
     logging.debug("Generating new SSH key for node '%s'.", node_name)
-    _GenerateNodeSshKey(node_uuid, node_name, ssh_port_map,
-                        pub_key_file=pub_key_file,
+    _GenerateNodeSshKey(node_uuid, node_name, ssh_port_map, ssh_key_type,
+                        ssh_key_bits, pub_key_file=ganeti_pub_keys_file,
                         ssconf_store=ssconf_store,
                         noded_cert_file=noded_cert_file,
                         run_cmd_fn=run_cmd_fn)
 
     try:
       logging.debug("Fetching newly created SSH key from node '%s'.", node_name)
-      pub_key = ssh.ReadRemoteSshPubKeys(dsa_pub_keyfile,
+      pub_key = ssh.ReadRemoteSshPubKeys(node_pub_keyfile,
                                          node_name, cluster_name,
                                          ssh_port_map[node_name],
                                          False, # ask_key
@@ -1986,8 +1998,8 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
                                   " (UUID %s)" % (node_name, node_uuid))
 
     if potential_master_candidate:
-      ssh.RemovePublicKey(node_uuid, key_file=pub_key_file)
-      ssh.AddPublicKey(node_uuid, pub_key, key_file=pub_key_file)
+      ssh.RemovePublicKey(node_uuid, key_file=ganeti_pub_keys_file)
+      ssh.AddPublicKey(node_uuid, pub_key, key_file=ganeti_pub_keys_file)
 
     logging.debug("Add ssh key of node '%s'.", node_name)
     node_errors = AddNodeSshKey(
@@ -1995,7 +2007,8 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
         to_authorized_keys=master_candidate,
         to_public_keys=potential_master_candidate,
         get_public_keys=True,
-        pub_key_file=pub_key_file, ssconf_store=ssconf_store,
+        pub_key_file=ganeti_pub_keys_file,
+        ssconf_store=ssconf_store,
         noded_cert_file=noded_cert_file,
         run_cmd_fn=run_cmd_fn)
     if node_errors:
@@ -2004,12 +2017,14 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
   # Renewing the master node's key
 
   # Preserve the old keys for now
-  old_master_keys_by_uuid = _GetOldMasterKeys(master_node_uuid, pub_key_file)
+  old_master_keys_by_uuid = _GetOldMasterKeys(master_node_uuid,
+                                              ganeti_pub_keys_file)
 
   # Generate a new master key with a suffix, don't touch the old one for now
   logging.debug("Generate new ssh key of master.")
   _GenerateNodeSshKey(master_node_uuid, master_node_name, ssh_port_map,
-                      pub_key_file=pub_key_file,
+                      ssh_key_type, ssh_key_bits,
+                      pub_key_file=ganeti_pub_keys_file,
                       ssconf_store=ssconf_store,
                       noded_cert_file=noded_cert_file,
                       run_cmd_fn=run_cmd_fn,
@@ -2018,16 +2033,16 @@ def RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
   new_master_key_dict = _GetNewMasterKey(root_keyfiles, master_node_uuid)
 
   # Replace master key in the master nodes' public key file
-  ssh.RemovePublicKey(master_node_uuid, key_file=pub_key_file)
+  ssh.RemovePublicKey(master_node_uuid, key_file=ganeti_pub_keys_file)
   for pub_key in new_master_key_dict[master_node_uuid]:
-    ssh.AddPublicKey(master_node_uuid, pub_key, key_file=pub_key_file)
+    ssh.AddPublicKey(master_node_uuid, pub_key, key_file=ganeti_pub_keys_file)
 
   # Add new master key to all node's public and authorized keys
   logging.debug("Add new master key to all nodes.")
   node_errors = AddNodeSshKey(
       master_node_uuid, master_node_name, potential_master_candidates,
       to_authorized_keys=True, to_public_keys=True,
-      get_public_keys=False, pub_key_file=pub_key_file,
+      get_public_keys=False, pub_key_file=ganeti_pub_keys_file,
       ssconf_store=ssconf_store, noded_cert_file=noded_cert_file,
       run_cmd_fn=run_cmd_fn)
   if node_errors:
index 69f75dd..370b4c7 100644 (file)
@@ -714,7 +714,7 @@ def InitCluster(cluster_name, mac_prefix, # pylint: disable=R0913, R0914
     utils.AddHostToEtcHosts(hostname.name, hostname.ip)
 
   if modify_ssh_setup:
-    ssh.InitSSHSetup()
+    ssh.InitSSHSetup(ssh_key_type, ssh_key_bits)
 
   if default_iallocator is not None:
     alloc_script = utils.FindFile(default_iallocator,
@@ -817,7 +817,7 @@ def InitCluster(cluster_name, mac_prefix, # pylint: disable=R0913, R0914
 
   master_uuid = cfg.GetMasterNode()
   if modify_ssh_setup:
-    ssh.InitPubKeyFile(master_uuid)
+    ssh.InitPubKeyFile(master_uuid, ssh_key_type)
   # set up the inter-node password and certificate
   _InitGanetiServerSetup(hostname.name, cfg)
 
index 2de389e..16120a7 100644 (file)
@@ -1216,8 +1216,9 @@ def _BuildGanetiPubKeys(options, pub_key_file=pathutils.SSH_PUB_KEYS, cl=None,
   if not cl:
     cl = GetClient()
 
-  (cluster_name, master_node, modify_ssh_setup) = \
-    cl.QueryConfigValues(["cluster_name", "master_node", "modify_ssh_setup"])
+  (cluster_name, master_node, modify_ssh_setup, ssh_key_type) = \
+    cl.QueryConfigValues(["cluster_name", "master_node", "modify_ssh_setup",
+                          "ssh_key_type"])
 
   # In case Ganeti is not supposed to modify the SSH setup, simply exit and do
   # not update this file.
@@ -1242,7 +1243,7 @@ def _BuildGanetiPubKeys(options, pub_key_file=pathutils.SSH_PUB_KEYS, cl=None,
 
   _, pub_key_filename, _ = \
     ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False,
-                     kind=constants.SSHK_DSA, _homedir_fn=homedir_fn)
+                     kind=ssh_key_type, _homedir_fn=homedir_fn)
 
   # get the key file of the master node
   pub_key = utils.ReadFile(pub_key_filename)
index b1ce8dc..4c26231 100644 (file)
@@ -230,12 +230,17 @@ def _SetupSSH(options, cluster_name, node, ssh_port, cl):
   (_, cert_pem) = \
     utils.ExtractX509Certificate(utils.ReadFile(pathutils.NODED_CERT_FILE))
 
+  (ssh_key_type, ssh_key_bits) = \
+    cl.QueryConfigValues(["ssh_key_type", "ssh_key_bits"])
+
   data = {
     constants.SSHS_CLUSTER_NAME: cluster_name,
     constants.SSHS_NODE_DAEMON_CERTIFICATE: cert_pem,
     constants.SSHS_SSH_HOST_KEY: host_keys,
     constants.SSHS_SSH_ROOT_KEY: root_keys,
     constants.SSHS_SSH_AUTHORIZED_KEYS: candidate_keys,
+    constants.SSHS_SSH_KEY_TYPE: ssh_key_type,
+    constants.SSHS_SSH_KEY_BITS: ssh_key_bits,
     }
 
   ssh.RunSshCmdWithStdin(cluster_name, node, pathutils.PREPARE_NODE_JOIN,
@@ -244,9 +249,9 @@ def _SetupSSH(options, cluster_name, node, ssh_port, cl):
                          use_cluster_key=False, ask_key=options.ssh_key_check,
                          strict_host_check=options.ssh_key_check)
 
-  (_, dsa_pub_keyfile) = root_keyfiles[constants.SSHK_DSA]
-  pub_key = ssh.ReadRemoteSshPubKeys(dsa_pub_keyfile, node, cluster_name,
-                                     ssh_port, options.ssh_key_check,
+  (_, pub_keyfile) = root_keyfiles[ssh_key_type]
+  pub_key = ssh.ReadRemoteSshPubKeys(pub_keyfile, node, cluster_name, ssh_port,
+                                     options.ssh_key_check,
                                      options.ssh_key_check)
   # Unfortunately, we have to add the key with the node name rather than
   # the node's UUID here, because at this point, we do not have a UUID yet.
index 51474d6..ed6a3b8 100644 (file)
@@ -172,11 +172,16 @@ class LUClusterRenewCrypto(NoHooksLU):
     node_uuids = [uuid for (uuid, _) in nodes_uuid_names]
     potential_master_candidates = self.cfg.GetPotentialMasterCandidates()
     master_candidate_uuids = self.cfg.GetMasterCandidateUuids()
+
+    cluster_info = self.cfg.GetClusterInfo()
+
     result = self.rpc.call_node_ssh_keys_renew(
       [master_uuid],
       node_uuids, node_names,
       master_candidate_uuids,
-      potential_master_candidates)
+      potential_master_candidates,
+      cluster_info.ssh_key_type,
+      cluster_info.ssh_key_bits)
     result[master_uuid].Raise("Could not renew the SSH keys of all nodes")
 
   def Exec(self, feedback_fn):
index e6fc13f..772ea9a 100644 (file)
@@ -1873,7 +1873,8 @@ class LUClusterVerifyGroup(LogicalUnit, _VerifyErrors):
       }
 
     if self.cfg.GetClusterInfo().modify_ssh_setup:
-      node_verify_param[constants.NV_SSH_SETUP] = self._PrepareSshSetupCheck()
+      node_verify_param[constants.NV_SSH_SETUP] = \
+        (self._PrepareSshSetupCheck(), self.cfg.GetClusterInfo().ssh_key_type)
       if self.op.verify_clutter:
         node_verify_param[constants.NV_SSH_CLUTTER] = True
 
index 021807c..8163735 100644 (file)
@@ -565,7 +565,9 @@ _NODE_CALLS = [
     ("node_uuids", None, "UUIDs of the nodes whose key is renewed"),
     ("node_names", None, "Names of the nodes whose key is renewed"),
     ("master_candidate_uuids", None, "List of UUIDs of master candidates."),
-    ("potential_master_candidates", None, "Potential master candidates")],
+    ("potential_master_candidates", None, "Potential master candidates"),
+    ("ssh_key_type", None, "The type of key to generate"),
+    ("ssh_key_bits", None, "The length of the key to generate")],
     None, None, "Renew all SSH key pairs of all nodes nodes."),
   ]
 
index c73897c..871b4e1 100644 (file)
@@ -945,10 +945,10 @@ class NodeRequestHandler(http.server.HttpServerHandler):
 
     """
     (node_uuids, node_names, master_candidate_uuids,
-     potential_master_candidates) = params
-    return backend.RenewSshKeys(node_uuids, node_names,
-                                master_candidate_uuids,
-                                potential_master_candidates)
+     potential_master_candidates, ssh_key_type, ssh_key_bits) = params
+    return backend.RenewSshKeys(node_uuids, node_names, master_candidate_uuids,
+                                potential_master_candidates, ssh_key_type,
+                                ssh_key_bits)
 
   @staticmethod
   def perspective_node_ssh_key_remove(params):
index 59ecbf9..d2684fc 100644 (file)
@@ -677,15 +677,18 @@ def QueryPubKeyFile(target_uuids, key_file=pathutils.SSH_PUB_KEYS,
   return result
 
 
-def InitSSHSetup(error_fn=errors.OpPrereqError, _homedir_fn=None,
-                 _suffix=""):
+def InitSSHSetup(key_type, key_bits, error_fn=errors.OpPrereqError,
+                 _homedir_fn=None, _suffix=""):
   """Setup the SSH configuration for the node.
 
   This generates a dsa keypair for root, adds the pub key to the
   permitted hosts and adds the hostkey to its own known hosts.
 
+  @param key_type: the type of SSH keypair to be generated
+  @param key_bits: the key length, in bits, to be used
+
   """
-  priv_key, _, auth_keys = GetUserFiles(constants.SSH_LOGIN_USER,
+  priv_key, _, auth_keys = GetUserFiles(constants.SSH_LOGIN_USER, kind=key_type,
                                         mkdir=True, _homedir_fn=_homedir_fn)
 
   new_priv_key_name = priv_key + _suffix
@@ -696,7 +699,7 @@ def InitSSHSetup(error_fn=errors.OpPrereqError, _homedir_fn=None,
       utils.CreateBackup(name)
     utils.RemoveFile(name)
 
-  result = utils.RunCmd(["ssh-keygen", "-t", "dsa",
+  result = utils.RunCmd(["ssh-keygen", "-b", str(key_bits), "-t", key_type,
                          "-f", new_priv_key_name,
                          "-q", "-N", ""])
   if result.failed:
@@ -706,16 +709,18 @@ def InitSSHSetup(error_fn=errors.OpPrereqError, _homedir_fn=None,
   AddAuthorizedKey(auth_keys, utils.ReadFile(new_pub_key_name))
 
 
-def InitPubKeyFile(master_uuid, key_file=pathutils.SSH_PUB_KEYS):
+def InitPubKeyFile(master_uuid, key_type, key_file=pathutils.SSH_PUB_KEYS):
   """Creates the public key file and adds the master node's SSH key.
 
   @type master_uuid: str
   @param master_uuid: the master node's UUID
+  @type key_type: one of L{constants.SSHK_ALL}
+  @param key_type: the type of ssh key to be used
   @type key_file: str
   @param key_file: name of the file containing the public keys
 
   """
-  _, pub_key, _ = GetUserFiles(constants.SSH_LOGIN_USER)
+  _, pub_key, _ = GetUserFiles(constants.SSH_LOGIN_USER, kind=key_type)
   ClearPubKeyFile(key_file=key_file)
   key = utils.ReadFile(pub_key)
   AddPublicKey(master_uuid, key, key_file=key_file)
@@ -1069,7 +1074,7 @@ def RunSshCmdWithStdin(cluster_name, node, basecmd, port, data,
 
 def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
                          strict_host_check):
-  """Fetches the public DSA SSH key from a node via SSH.
+  """Fetches a public SSH key from a node via SSH.
 
   @type pub_key_file: string
   @param pub_key_file: a tuple consisting of the file name of the public DSA key
@@ -1087,7 +1092,7 @@ def ReadRemoteSshPubKeys(pub_key_file, node, cluster_name, port, ask_key,
 
   result = utils.RunCmd(ssh_cmd)
   if result.failed:
-    raise errors.OpPrereqError("Could not fetch a public DSA SSH key from node"
+    raise errors.OpPrereqError("Could not fetch a public SSH key (%s) from node"
                                " '%s': ran command '%s', failure reason: '%s'."
-                               % (node, cmd, result.fail_reason))
+                               % (pub_key_file, node, cmd, result.fail_reason))
   return result.stdout
index a9149f6..ca8288a 100644 (file)
@@ -191,11 +191,13 @@ def LoadData(raw, data_check):
   return serializer.LoadAndVerifyJson(raw, data_check)
 
 
-def GenerateRootSshKeys(error_fn, _suffix="", _homedir_fn=None):
+def GenerateRootSshKeys(key_type, key_bits, error_fn, _suffix="",
+                        _homedir_fn=None):
   """Generates root's SSH keys for this node.
 
   """
-  ssh.InitSSHSetup(error_fn=error_fn, _homedir_fn=_homedir_fn, _suffix=_suffix)
+  ssh.InitSSHSetup(key_type, key_bits, error_fn=error_fn,
+                   _homedir_fn=_homedir_fn, _suffix=_suffix)
 
 
 def GenerateClientCertificate(
index 82a35dc..fa45a58 100644 (file)
@@ -50,7 +50,7 @@ from ganeti.tools import common
 _SSH_KEY_LIST_ITEM = \
   ht.TAnd(ht.TIsLength(3),
           ht.TItems([
-            ht.TElemOf(constants.SSHK_ALL),
+            ht.TSshKeyType,
             ht.Comment("public")(ht.TNonEmptyString),
             ht.Comment("private")(ht.TNonEmptyString),
           ]))
@@ -64,6 +64,8 @@ _DATA_CHECK = ht.TStrictDict(False, True, {
   constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
   constants.SSHS_SSH_AUTHORIZED_KEYS:
     ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString)),
+  constants.SSHS_SSH_KEY_TYPE: ht.TSshKeyType,
+  constants.SSHS_SSH_KEY_BITS: ht.TPositive,
   })
 
 
@@ -172,7 +174,10 @@ def UpdateSshRoot(data, dry_run, _homedir_fn=None):
   if dry_run:
     logging.info("This is a dry run, not replacing the SSH keys.")
   else:
-    common.GenerateRootSshKeys(error_fn=JoinError, _homedir_fn=_homedir_fn)
+    ssh_key_type = data.get(constants.SSHS_SSH_KEY_TYPE)
+    ssh_key_bits = data.get(constants.SSHS_SSH_KEY_BITS)
+    common.GenerateRootSshKeys(ssh_key_type, ssh_key_bits, error_fn=JoinError,
+                               _homedir_fn=_homedir_fn)
 
   if authorized_keys:
     if dry_run:
index f9d1b6d..b37972e 100644 (file)
@@ -62,7 +62,13 @@ _DATA_CHECK = ht.TStrictDict(False, True, {
     ht.TItems(
       [ht.TElemOf(constants.SSHS_ACTIONS),
        ht.TDictOf(ht.TNonEmptyString, ht.TListOf(ht.TNonEmptyString))]),
-  constants.SSHS_GENERATE: ht.TDictOf(ht.TNonEmptyString, ht.TString),
+  constants.SSHS_GENERATE:
+    ht.TItems(
+      [ht.TSshKeyType, # The type of key to generate
+       ht.TPositive, # The number of bits in the key
+       ht.TString]), # The suffix
+  constants.SSHS_SSH_KEY_TYPE: ht.TSshKeyType,
+  constants.SSHS_SSH_KEY_BITS: ht.TPositive,
   })
 
 
@@ -190,11 +196,12 @@ def GenerateRootSshKeys(data, dry_run):
   """
   generate_info = data.get(constants.SSHS_GENERATE)
   if generate_info:
-    suffix = generate_info[constants.SSHS_SUFFIX]
+    key_type, key_bits, suffix = generate_info
     if dry_run:
       logging.info("This is a dry run, not generating any files.")
     else:
-      common.GenerateRootSshKeys(SshUpdateError, _suffix=suffix)
+      common.GenerateRootSshKeys(key_type, key_bits, SshUpdateError,
+                                 _suffix=suffix)
 
 
 def Main():
index 1a6ceca..c9ca540 100644 (file)
@@ -4730,6 +4730,12 @@ sshsSshPublicKeys = "public_keys"
 sshsNodeDaemonCertificate :: String
 sshsNodeDaemonCertificate = "node_daemon_certificate"
 
+sshsSshKeyType :: String
+sshsSshKeyType = "ssh_key_type"
+
+sshsSshKeyBits :: String
+sshsSshKeyBits = "ssh_key_bits"
+
 -- Number of maximum retries when contacting nodes per SSH
 -- during SSH update operations.
 sshsMaxRetries :: Integer
index 1a0a51d..e441621 100755 (executable)
@@ -1052,6 +1052,7 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
     backend._GenerateNodeSshKey(
         test_node_uuid, test_node_name,
         self._ssh_file_manager.GetSshPortMap(self._SSH_PORT),
+        "rsa", 2048,
         pub_key_file=self._pub_key_file,
         ssconf_store=self._ssconf_mock,
         noded_cert_file=self.noded_cert_file,
@@ -1656,8 +1657,8 @@ class TestVerifySshSetup(testutils.GanetiTestCase):
     self._read_file_mock = self._read_file_patcher.start()
     self._read_file_mock.return_value = self._NODE1_KEYS[0]
     self.tmpdir = tempfile.mkdtemp()
-    self.pub_key_file = os.path.join(self.tmpdir, "pub_key_file")
-    open(self.pub_key_file, "w").close()
+    self.pub_keys_file = os.path.join(self.tmpdir, "pub_keys_file")
+    open(self.pub_keys_file, "w").close()
 
   def tearDown(self):
     super(testutils.GanetiTestCase, self).tearDown()
@@ -1672,7 +1673,8 @@ class TestVerifySshSetup(testutils.GanetiTestCase):
     self._query_mock.return_value = self._PUB_KEY_RESULT
     result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                      self._NODE1_NAME,
-                                     pub_key_file=self.pub_key_file)
+                                     "dsa",
+                                     ganeti_pub_keys_file=self.pub_keys_file)
     self.assertEqual(result, [])
 
   def testMissingKey(self):
@@ -1683,7 +1685,8 @@ class TestVerifySshSetup(testutils.GanetiTestCase):
     self._query_mock.return_value = pub_key_missing
     result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                      self._NODE1_NAME,
-                                     pub_key_file=self.pub_key_file)
+                                     "dsa",
+                                     ganeti_pub_keys_file=self.pub_keys_file)
     self.assertTrue(self._NODE2_UUID in result[0])
 
   def testUnknownKey(self):
@@ -1694,7 +1697,8 @@ class TestVerifySshSetup(testutils.GanetiTestCase):
     self._query_mock.return_value = pub_key_missing
     result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                      self._NODE1_NAME,
-                                     pub_key_file=self.pub_key_file)
+                                     "dsa",
+                                     ganeti_pub_keys_file=self.pub_keys_file)
     self.assertTrue("unkownnodeuuid" in result[0])
 
   def testMissingMasterCandidate(self):
@@ -1705,7 +1709,8 @@ class TestVerifySshSetup(testutils.GanetiTestCase):
     self._query_mock.return_value = self._PUB_KEY_RESULT
     result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                      self._NODE1_NAME,
-                                     pub_key_file=self.pub_key_file)
+                                     "dsa",
+                                     ganeti_pub_keys_file=self.pub_keys_file)
     self.assertTrue(self._NODE1_UUID in result[0])
 
   def testSuperfluousNormalNode(self):
@@ -1716,7 +1721,8 @@ class TestVerifySshSetup(testutils.GanetiTestCase):
     self._query_mock.return_value = self._PUB_KEY_RESULT
     result = backend._VerifySshSetup(self._NODE_STATUS_LIST,
                                      self._NODE1_NAME,
-                                     pub_key_file=self.pub_key_file)
+                                     "dsa",
+                                     ganeti_pub_keys_file=self.pub_keys_file)
     self.assertTrue(self._NODE3_UUID in result[0])
 
 
index 595864a..c2cb9f5 100755 (executable)
@@ -382,6 +382,7 @@ class TestBuildGanetiPubKeys(testutils.GanetiTestCase):
   _PUB_KEY = "master_public_key"
   _MODIFY_SSH_SETUP = True
   _AUTH_KEYS = "a\nb\nc"
+  _SSH_KEY_TYPE = "dsa"
 
   def _setUpFakeKeys(self):
     os.makedirs(os.path.join(self.tmpdir, ".ssh"))
@@ -412,7 +413,8 @@ class TestBuildGanetiPubKeys(testutils.GanetiTestCase):
     self.mock_cl = mock.Mock()
     self.mock_cl.QueryConfigValues = mock.Mock()
     self.mock_cl.QueryConfigValues.return_value = \
-      (self._CLUSTER_NAME, self._MASTER_NODE_NAME, self._MODIFY_SSH_SETUP)
+      (self._CLUSTER_NAME, self._MASTER_NODE_NAME, self._MODIFY_SSH_SETUP,
+       self._SSH_KEY_TYPE)
 
     self._get_online_nodes_mock = mock.Mock()
     self._get_online_nodes_mock.return_value = \
index 9ec2397..b13dda1 100755 (executable)
@@ -279,6 +279,30 @@ class TestSshKeys(testutils.GanetiTestCase):
       "ssh-dss AAAAB3asdfasdfaYTUCB laracroft@test\n"
       "ssh-dss AasdfliuobaosfMAAACB frodo@test\n")
 
+  def testOtherKeyTypes(self):
+    key_rsa = "ssh-rsa AAAAimnottypingallofthathere0jfJs22 test@test"
+    key_ed25519 = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOlcZ6cpQTGow0LZECRHWn9"\
+                  "7Yvn16J5un501T/RcbfuF fast@secure"
+    key_ecdsa = "ecdsa-sha2-nistp256 AAAAE2VjZHNtoolongk/TNhVbEg= secure@secure"
+
+    def _ToFileContent(keys):
+      return '\n'.join(keys) + '\n'
+
+    ssh.AddAuthorizedKeys(self.tmpname, [key_rsa, key_ed25519, key_ecdsa])
+    self.assertFileContent(self.tmpname,
+                           _ToFileContent([self.KEY_A, self.KEY_B, key_rsa,
+                                           key_ed25519, key_ecdsa]))
+
+    ssh.RemoveAuthorizedKey(self.tmpname, key_ed25519)
+    self.assertFileContent(self.tmpname,
+                           _ToFileContent([self.KEY_A, self.KEY_B, key_rsa,
+                                           key_ecdsa]))
+
+    ssh.RemoveAuthorizedKey(self.tmpname, key_rsa)
+    ssh.RemoveAuthorizedKey(self.tmpname, key_ecdsa)
+    self.assertFileContent(self.tmpname,
+                           _ToFileContent([self.KEY_A, self.KEY_B]))
+
 
 class TestPublicSshKeys(testutils.GanetiTestCase):
   """Test case for the handling of the list of public ssh keys."""
@@ -450,13 +474,14 @@ class TestGetUserFiles(testutils.GanetiTestCase):
     return self.tmpdir
 
   def testNewKeysOverrideOldKeys(self):
-    ssh.InitSSHSetup(_homedir_fn=self._GetTempHomedir)
+    ssh.InitSSHSetup("dsa", 1024, _homedir_fn=self._GetTempHomedir)
     self.assertFileContentNotEqual(self.priv_filename, self._PRIV_KEY)
     self.assertFileContentNotEqual(self.pub_filename, self._PUB_KEY)
 
   def testSuffix(self):
     suffix = "_pinkbunny"
-    ssh.InitSSHSetup(_homedir_fn=self._GetTempHomedir, _suffix=suffix)
+    ssh.InitSSHSetup("dsa", 1024, _homedir_fn=self._GetTempHomedir,
+                     _suffix=suffix)
     self.assertFileContent(self.priv_filename, self._PRIV_KEY)
     self.assertFileContent(self.pub_filename, self._PUB_KEY)
     self.assertTrue(os.path.exists(self.priv_filename + suffix))
index a76db15..7901199 100755 (executable)
@@ -164,6 +164,8 @@ class TestUpdateSshDaemon(unittest.TestCase):
         (constants.SSHK_ECDSA, "ecdsapriv", "ecdsapub"),
         (constants.SSHK_RSA, "rsapriv", "rsapub"),
         ],
+      constants.SSHS_SSH_KEY_TYPE: "dsa",
+      constants.SSHS_SSH_KEY_BITS: 1024,
       }
     runcmd_fn = compat.partial(self._RunCmd, failcmd)
     if failcmd:
@@ -228,7 +230,9 @@ class TestUpdateSshRoot(unittest.TestCase):
     data = {
       constants.SSHS_SSH_ROOT_KEY: [
         (constants.SSHK_DSA, "privatedsa", "ssh-dss pubdsa"),
-        ]
+        ],
+      constants.SSHS_SSH_KEY_TYPE: "dsa",
+      constants.SSHS_SSH_KEY_BITS: 1024,
       }
 
     prepare_node_join.UpdateSshRoot(data, False,