Introduce bulk-adding of SSH keys
author Helga Velroyen <helgav@google.com>
Wed, 11 Nov 2015 15:54:31 +0000 (16:54 +0100)
committer Helga Velroyen <helgav@google.com>
Tue, 17 Nov 2015 15:35:53 +0000 (16:35 +0100)
This patch introduces a backend function to add a set of
SSH keys to the nodes (rather than one key at a time).
The bulk-adding function has the same structure
as the original one, but is adapted to work with a set
of keys rather than one key.

This patch also adds a unit test for testing the
bulk-adding of keys.

Note that this patch only adds the bulk-adding function
but does not use it yet. In the following patches of
this series, we will add more unit tests and at the
end integrate the bulk-adding function into
renew-crypto.

Signed-off-by: Helga Velroyen <helgav@google.com>
Reviewed-by: Klaus Aehlig <aehlig@google.com>

lib/backend.py
test/py/ganeti.backend_unittest.py

index f891ef6..9c42d5d 100644 (file)
@@ -62,6 +62,7 @@ import time
 import zlib
 import copy
 import contextlib
+import collections
 
 from ganeti import errors
 from ganeti import http
@@ -1566,6 +1567,175 @@ def AddNodeSshKey(node_uuid, node_name,
   return node_errors
 
 
+# Per-node input record for AddNodeSshKeyBulk: node identity plus action flags
+SshAddNodeInfo = collections.namedtuple(
+  "SshAddNodeInfo",
+  ["uuid",                # UUID of the node whose key is to be added
+   "name",                # name of that node
+   "to_authorized_keys",  # add the node's key to the authorized keys
+   "to_public_keys",      # send the key to the potential master candidates
+   "get_public_keys"])    # send all keys of the public key file to the node
+
+
+def AddNodeSshKeyBulk(node_list,
+                      potential_master_candidates,
+                      pub_key_file=pathutils.SSH_PUB_KEYS,
+                      ssconf_store=None,
+                      noded_cert_file=pathutils.NODED_CERT_FILE,
+                      run_cmd_fn=ssh.RunSshCmdWithStdin):
+  """Distributes the public SSH keys of a set of nodes across the cluster.
+
+  Note that this function should only be executed on the master node, which
+  then will copy the new nodes' keys to all nodes in the cluster via SSH.
+
+  Also note: at least one of the flags C{to_authorized_keys},
+  C{to_public_keys}, and C{get_public_keys} has to be set to C{True} for
+  at least one node for the function to actually perform any actions.
+
+  @type node_list: list of SshAddNodeInfo tuples
+  @param node_list: list of tuples containing the necessary node information for
+    adding their keys
+  @type potential_master_candidates: list of str
+  @param potential_master_candidates: list of node names of potential master
+    candidates; this should match the list of uuids in the public key file
+  @rtype: list of tuples (str, str)
+  @return: pairs of node name and error message, one for each potential
+    master candidate whose SSH key files could not be updated
+  @raise errors.SshUpdateError: if a node of C{node_list} has no key in the
+    public key file or if updating the key files of such a node fails
+
+  """
+  # whether there are any keys to be added or retrieved at all
+  to_authorized_keys = any(info.to_authorized_keys for info in node_list)
+  to_public_keys = any(info.to_public_keys for info in node_list)
+  get_public_keys = any(info.get_public_keys for info in node_list)
+
+  # the function would not do anything if none of those flags was true
+  assert (to_authorized_keys or to_public_keys or get_public_keys)
+
+  if not ssconf_store:
+    ssconf_store = ssconf.SimpleStore()
+
+  for node_info in node_list:
+    # Check and fix sanity of key file
+    keys_by_name = ssh.QueryPubKeyFile([node_info.name], key_file=pub_key_file)
+    keys_by_uuid = ssh.QueryPubKeyFile([node_info.uuid], key_file=pub_key_file)
+
+    if (not keys_by_name or node_info.name not in keys_by_name) \
+        and (not keys_by_uuid or node_info.uuid not in keys_by_uuid):
+      raise errors.SshUpdateError(
+        "No keys found for the new node '%s' (UUID %s) in the list of public"
+        " SSH keys, neither for the name or the UUID" %
+        (node_info.name, node_info.uuid))
+
+    if node_info.name in keys_by_name:
+      # Replace the name by UUID in the file as the name should only be used
+      # temporarily
+      ssh.ReplaceNameByUuid(node_info.uuid, node_info.name,
+                            error_fn=errors.SshUpdateError,
+                            key_file=pub_key_file)
+
+  # Retrieve updated map of UUIDs to keys
+  keys_by_uuid = ssh.QueryPubKeyFile(
+      [node_info.uuid for node_info in node_list], key_file=pub_key_file)
+
+  # Update the master node's key files
+  (auth_key_file, _) = \
+    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=False, dircheck=False)
+  for node_info in node_list:
+    if node_info.to_authorized_keys:
+      ssh.AddAuthorizedKeys(auth_key_file, keys_by_uuid[node_info.uuid])
+
+  base_data = {}
+  _InitSshUpdateData(base_data, noded_cert_file, ssconf_store)
+  cluster_name = base_data[constants.SSHS_CLUSTER_NAME]
+
+  ssh_port_map = ssconf_store.GetSshPortMap()
+
+  # Update the target nodes themselves
+  for node_info in node_list:
+    logging.debug("Updating SSH key files of target node '%s'.", node_info.name)
+    if node_info.get_public_keys:
+      node_data = {}
+      _InitSshUpdateData(node_data, noded_cert_file, ssconf_store)
+      all_keys = ssh.QueryPubKeyFile(None, key_file=pub_key_file)
+      node_data[constants.SSHS_SSH_PUBLIC_KEYS] = \
+        (constants.SSHS_OVERRIDE, all_keys)
+
+      try:
+        utils.RetryByNumberOfTimes(
+            constants.SSHS_MAX_RETRIES,
+            errors.SshUpdateError,
+            run_cmd_fn, cluster_name, node_info.name, pathutils.SSH_UPDATE,
+            ssh_port_map.get(node_info.name), node_data,
+            debug=False, verbose=False, use_cluster_key=False,
+            ask_key=False, strict_host_check=False)
+      except errors.SshUpdateError as e:
+        # Clean up the master's public key file if adding key fails
+        if node_info.to_public_keys:
+          ssh.RemovePublicKey(node_info.uuid, key_file=pub_key_file)
+        raise e
+
+  # Update the key files of all online nodes except the master
+  keys_by_uuid_auth = ssh.QueryPubKeyFile(
+      [node_info.uuid for node_info in node_list
+       if node_info.to_authorized_keys],
+      key_file=pub_key_file)
+  if to_authorized_keys:
+    base_data[constants.SSHS_SSH_AUTHORIZED_KEYS] = \
+      (constants.SSHS_ADD, keys_by_uuid_auth)
+
+  pot_mc_data = copy.deepcopy(base_data)
+  keys_by_uuid_pub = ssh.QueryPubKeyFile(
+      [node_info.uuid for node_info in node_list
+       if node_info.to_public_keys],
+      key_file=pub_key_file)
+  if to_public_keys:
+    pot_mc_data[constants.SSHS_SSH_PUBLIC_KEYS] = \
+      (constants.SSHS_REPLACE_OR_ADD, keys_by_uuid_pub)
+
+  all_nodes = ssconf_store.GetNodeList()
+  master_node = ssconf_store.GetMasterNode()
+  online_nodes = ssconf_store.GetOnlineNodeList()
+
+  added_node_names = ", ".join(info.name for info in node_list)
+  node_errors = []
+  for node in all_nodes:
+    if node == master_node:
+      logging.debug("Skipping master node '%s'.", master_node)
+      continue
+    if node not in online_nodes:
+      logging.debug("Skipping offline node '%s'.", node)
+      continue
+    if node in potential_master_candidates:
+      logging.debug("Updating SSH key files of node '%s'.", node)
+      try:
+        utils.RetryByNumberOfTimes(
+            constants.SSHS_MAX_RETRIES,
+            errors.SshUpdateError,
+            run_cmd_fn, cluster_name, node, pathutils.SSH_UPDATE,
+            ssh_port_map.get(node), pot_mc_data,
+            debug=False, verbose=False, use_cluster_key=False,
+            ask_key=False, strict_host_check=False)
+      except errors.SshUpdateError as last_exception:
+        error_msg = ("When adding the keys of nodes '%s', updating SSH key"
+                     " files of node '%s' failed after %s retries."
+                     " Not trying again. Last error was: %s." %
+                     (added_node_names, node, constants.SSHS_MAX_RETRIES,
+                      last_exception))
+        node_errors.append((node, error_msg))
+        # We only log the error and don't throw an exception, because
+        # one unreachable node shall not abort the entire procedure.
+        logging.error(error_msg)
+    elif to_authorized_keys:
+      run_cmd_fn(cluster_name, node, pathutils.SSH_UPDATE,
+                 ssh_port_map.get(node), base_data,
+                 debug=False, verbose=False, use_cluster_key=False,
+                 ask_key=False, strict_host_check=False)
+
+  return node_errors
+
+
 def RemoveNodeSshKey(node_uuid, node_name,
                      master_candidate_uuids,
                      potential_master_candidates,
index 393239a..bbf00db 100755 (executable)
@@ -30,6 +30,7 @@
 
 """Script for testing ganeti.backend"""
 
+import collections
 import copy
 import mock
 import os
@@ -1026,8 +1027,10 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
         self._ssh_file_manager.GetAllMasterCandidateUuids()
     self._all_nodes = self._ssh_file_manager.GetAllNodeNames()
 
-    self._ssconf_mock.GetNodeList.return_value = self._all_nodes
-    self._ssconf_mock.GetOnlineNodeList.return_value = self._all_nodes
+    self._ssconf_mock.GetNodeList.side_effect = \
+        self._ssh_file_manager.GetAllNodeNames
+    self._ssconf_mock.GetOnlineNodeList.side_effect = \
+        self._ssh_file_manager.GetAllNodeNames
 
   def _TearDownTestData(self):
     os.remove(self._pub_key_file)
@@ -1075,7 +1078,15 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
 
   def _GetNewMasterCandidate(self):
     """Returns the properties of a new master candidate node."""
-    return ("new_node_name", "new_node_uuid", "new_node_key", True, True, False)
+    return ("new_node_name", "new_node_uuid", "new_node_key",
+            True, True, False)
+
+  def _GetNewNumberedMasterCandidate(self, num):
+    """Returns the properties of a new master candidate numbered by C{num}."""
+    return ("new_node_name_%s" % num,
+            "new_node_uuid_%s" % num,
+            "new_node_key_%s" % num,
+            True, True, False)
 
   def testAddMasterCandidate(self):
     (new_node_name, new_node_uuid, new_node_key, is_master_candidate,
@@ -1100,6 +1111,59 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
         new_node_name)
     self._ssh_file_manager.AssertAllNodesHaveAuthorizedKey(new_node_key)
 
+  def _SetupNodeBulk(self, num_nodes, node_fn):
+    """Sets up the test data for a bulk of nodes.
+
+    @type num_nodes: integer
+    @param num_nodes: number of nodes
+    @type node_fn: function taking an integer as only argument
+    @param node_fn: function to generate data of one node
+    @return: tuple (list of C{SshAddNodeInfo}, dict mapping node name to key)
+
+    """
+    node_list = []
+    key_map = {}
+
+    for i in range(num_nodes):
+      (new_node_name, new_node_uuid, new_node_key, is_master_candidate,
+       is_potential_master_candidate, is_master) = \
+          node_fn(i)
+
+      self._AddNewNodeToTestData(
+          new_node_name, new_node_uuid, new_node_key,
+          is_potential_master_candidate, is_master_candidate,
+          is_master)
+
+      node_list.append(
+          backend.SshAddNodeInfo(
+              uuid=new_node_uuid,
+              name=new_node_name,
+              to_authorized_keys=is_master_candidate,
+              to_public_keys=is_potential_master_candidate,
+              get_public_keys=is_potential_master_candidate))
+
+      key_map[new_node_name] = new_node_key
+
+    return (node_list, key_map)
+
+  def testAddMasterCandidateBulk(self):
+    num_nodes = 3  # more than one node, so several keys are added at once
+    (node_list, key_map) = self._SetupNodeBulk(
+        num_nodes, self._GetNewNumberedMasterCandidate)
+
+    backend.AddNodeSshKeyBulk(node_list,
+                              self._potential_master_candidates,
+                              pub_key_file=self._pub_key_file,
+                              ssconf_store=self._ssconf_mock,
+                              noded_cert_file=self.noded_cert_file,
+                              run_cmd_fn=self._run_cmd_mock)
+
+    for node_info in node_list:  # check the outcome for each added node
+      self._ssh_file_manager.AssertPotentialMasterCandidatesOnlyHavePublicKey(
+          node_info.name)
+      self._ssh_file_manager.AssertAllNodesHaveAuthorizedKey(
+          key_map[node_info.name])
+
   def testAddPotentialMasterCandidate(self):
     new_node_name = "new_node_name"
     new_node_uuid = "new_node_uuid"
@@ -1309,7 +1373,8 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
         is_potential_master_candidate, is_master_candidate,
         is_master)
     self._online_nodes = self._GetReducedOnlineNodeList()
-    self._ssconf_mock.GetOnlineNodeList.return_value = self._online_nodes
+    self._ssconf_mock.GetOnlineNodeList.side_effect = \
+        lambda : self._online_nodes
 
     backend.AddNodeSshKey(new_node_uuid, new_node_name,
                           self._potential_master_candidates,
@@ -1333,7 +1398,8 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
     (node_name, node_info) = \
         self._ssh_file_manager.GetAllMasterCandidates()[0]
     self._online_nodes = self._GetReducedOnlineNodeList()
-    self._ssconf_mock.GetOnlineNodeList.return_value = self._online_nodes
+    self._ssconf_mock.GetOnlineNodeList.side_effect = \
+        lambda : self._online_nodes
 
     backend.RemoveNodeSshKey(node_info.uuid, node_name,
                              self._master_candidate_uuids,