Make mock SSH file manager deal with lists
authorHelga Velroyen <helgav@google.com>
Thu, 19 Nov 2015 15:13:17 +0000 (16:13 +0100)
committerHelga Velroyen <helgav@google.com>
Thu, 17 Dec 2015 08:12:42 +0000 (09:12 +0100)
There was a subtle bug in the unit test of backend.py
which was masking another subtle bug in the test framework
in testutils_ssh.py.

As relict from some previous refactoring, the ssh.py
functions assume that there can be more than one public
key per node. The testutils so far assume there is only
one key per node and due to a bug, this cancelled out
nicely and was not found so far.

As we actually only have one key per node, the elegant
thing to do would be to adapt ssh.py rather than the
testutils, but that will break the interface of the
ssh_update.py tool. Since we would rather not do that
in a stable, branch, this patch adapts the testutils.
The adaption of the ssh.py will be done in a newer
branch then.

Additionally, this patch also sprinkles assertions
everywhere to ensure finding these kind of type messups
sooner.

Signed-off-by: Helga Velroyen <helgav@google.com>
Reviewed-by: Lisa Velden <velden@google.com>

test/py/ganeti.backend_unittest.py
test/py/testutils_ssh.py

index 43d2dde..35ea9f4 100755 (executable)
@@ -1031,6 +1031,8 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
         self._ssh_file_manager.GetAllNodeNames
     self._ssconf_mock.GetOnlineNodeList.side_effect = \
         self._ssh_file_manager.GetAllNodeNames
+    self._ssconf_mock.GetMasterNode.side_effect = \
+        self._ssh_file_manager.GetMasterNodeName
 
   def _TearDownTestData(self):
     os.remove(self._pub_key_file)
@@ -1579,8 +1581,9 @@ class TestAddRemoveGenerateNodeSshKey(testutils.GanetiTestCase):
         noded_cert_file=self.noded_cert_file,
         run_cmd_fn=self._run_cmd_mock)
 
+    master_node = self._ssh_file_manager.GetMasterNodeName()
     for node in self._all_nodes:
-      if node == new_node_name:
+      if node in [new_node_name, master_node]:
         self.assertTrue(self._ssh_file_manager.NodeHasAuthorizedKey(
           node, new_node_key))
       else:
index dd14303..e4d76c1 100644 (file)
@@ -77,6 +77,9 @@ class FakeSshFileManager(object):
     # 'RunCommand' has already carried out.
     self._retries = {}
 
+    self._AssertTypePublicKeys()
+    self._AssertTypeAuthorizedKeys()
+
   _NodeInfo = namedtuple(
       "NodeInfo",
       ["uuid",
@@ -110,6 +113,9 @@ class FakeSshFileManager(object):
 
       self._all_node_data[name] = self._NodeInfo(uuid, key, pot_mc, mc, master)
 
+    self._AssertTypePublicKeys()
+    self._AssertTypeAuthorizedKeys()
+
   def _FillPublicKeyOfOneNode(self, receiving_node_name):
     node_info = self._all_node_data[receiving_node_name]
     # Nodes which are not potential master candidates receive no keys
@@ -117,7 +123,7 @@ class FakeSshFileManager(object):
       return
     for node_info in self._all_node_data.values():
       if node_info.is_potential_master_candidate:
-        self._public_keys[receiving_node_name][node_info.uuid] = node_info.key
+        self._public_keys[receiving_node_name][node_info.uuid] = [node_info.key]
 
   def _FillAuthorizedKeyOfOneNode(self, receiving_node_name):
     for node_info in self._all_node_data.values():
@@ -142,6 +148,8 @@ class FakeSshFileManager(object):
       self._FillPublicKeyOfOneNode(node)
       self._FillAuthorizedKeyOfOneNode(node)
     self._SetMasterNodeName()
+    self._AssertTypePublicKeys()
+    self._AssertTypeAuthorizedKeys()
 
   def SetMaxRetries(self, node_name, retries):
     """Set the number of unsuccessful retries of 'RunCommand' per node.
@@ -205,7 +213,7 @@ class FakeSshFileManager(object):
 
     This is necessary when testing to add new nodes to the cluster. Otherwise
     this new node's state would not be evaluated properly with the assertion
-    fuctions.
+    functions.
 
     @type name: string
     @param name: name of the new node
@@ -228,6 +236,8 @@ class FakeSshFileManager(object):
       self._authorized_keys[name].add(key)
     if name not in self._public_keys:
       self._public_keys[name] = {}
+    self._AssertTypePublicKeys()
+    self._AssertTypeAuthorizedKeys()
 
   def NodeHasPublicKey(self, file_node_name, key_node_uuid, key):
     """Checks whether a node has another node's public key.
@@ -355,6 +365,37 @@ class FakeSshFileManager(object):
     self.AssertNodeSetOnlyHasPublicKey(
         potential_master_candidates, query_node_uuid, query_node_key)
 
+  def _AssertTypePublicKeys(self):
+    """Asserts that the public key dictionary has the right types.
+
+    This is helpful as an invariant that shall not be violated during the
+    tests due to type errors.
+
+    """
+    assert isinstance(self._public_keys, dict)
+    for node_file, pub_keys in self._public_keys.items():
+      assert isinstance(node_file, str)
+      assert isinstance(pub_keys, dict)
+      for node_key, keys in pub_keys.items():
+        assert isinstance(node_key, str)
+        assert isinstance(keys, list)
+        for key in keys:
+          assert isinstance(key, str)
+
+  def _AssertTypeAuthorizedKeys(self):
+    """Asserts that the authorized keys dictionary has the right types.
+
+    This is useful to check as an invariant that is not supposed to be violated
+    during the tests.
+
+    """
+    assert isinstance(self._authorized_keys, dict)
+    for node_file, auth_keys in self._authorized_keys.items():
+      assert isinstance(node_file, str)
+      assert isinstance(auth_keys, set)
+      for key in auth_keys:
+        assert isinstance(key, str)
+
   # Disabling a pylint warning about unused parameters. Those need
   # to be here to properly mock the real methods.
   # pylint: disable=W0613
@@ -392,52 +433,77 @@ class FakeSshFileManager(object):
   def _EnsureAuthKeyFile(self, file_node_name):
     if file_node_name not in self._authorized_keys:
       self._authorized_keys[file_node_name] = set()
+    self._AssertTypePublicKeys()
+    self._AssertTypeAuthorizedKeys()
 
-  def _AddAuthorizedKey(self, file_node_name, ssh_key):
+  def _AddAuthorizedKeys(self, file_node_name, ssh_keys):
+    """Mocks adding the given keys to the authorized_keys file."""
+    assert isinstance(ssh_keys, list)
     self._EnsureAuthKeyFile(file_node_name)
-    self._authorized_keys[file_node_name].add(ssh_key)
+    for key in ssh_keys:
+      self._authorized_keys[file_node_name].add(key)
+    self._AssertTypePublicKeys()
+    self._AssertTypeAuthorizedKeys()
+
+  def _RemoveAuthorizedKeys(self, file_node_name, keys):
+    """Mocks removing the keys from authorized_keys on the given node.
 
-  def _RemoveAuthorizedKey(self, file_node_name, key):
+    @param keys: list of ssh keys
+    @type keys: list of strings
+
+    """
     self._EnsureAuthKeyFile(file_node_name)
     self._authorized_keys[file_node_name] = \
-        set([k for k in self._authorized_keys[file_node_name] if k != key])
+        set([k for k in self._authorized_keys[file_node_name] if k not in keys])
+    self._AssertTypeAuthorizedKeys()
 
   def _HandleAuthorizedKeys(self, instructions, node):
     (action, authorized_keys) = instructions
-    ssh_keys = authorized_keys.values()
+    ssh_key_sets = authorized_keys.values()
     if action == constants.SSHS_ADD:
-      for ssh_key in ssh_keys:
-        self._AddAuthorizedKey(node, ssh_key)
+      for ssh_keys in ssh_key_sets:
+        self._AddAuthorizedKeys(node, ssh_keys)
     elif action == constants.SSHS_REMOVE:
-      for ssh_key in ssh_keys:
-        self._RemoveAuthorizedKey(node, ssh_key)
+      for ssh_keys in ssh_key_sets:
+        self._RemoveAuthorizedKeys(node, ssh_keys)
     else:
       raise Exception("Unsupported action: %s" % action)
+    self._AssertTypeAuthorizedKeys()
 
   def _EnsurePublicKeyFile(self, file_node_name):
     if file_node_name not in self._public_keys:
       self._public_keys[file_node_name] = {}
+    self._AssertTypePublicKeys()
 
   def _ClearPublicKeys(self, file_node_name):
     self._public_keys[file_node_name] = {}
+    self._AssertTypePublicKeys()
 
   def _OverridePublicKeys(self, ssh_keys, file_node_name):
+    assert isinstance(ssh_keys, dict)
     self._ClearPublicKeys(file_node_name)
     for key_node_uuid, node_keys in ssh_keys.items():
+      assert isinstance(node_keys, list)
       if key_node_uuid in self._public_keys[file_node_name]:
         raise Exception("Duplicate node in ssh_update data.")
       self._public_keys[file_node_name][key_node_uuid] = node_keys
+    self._AssertTypePublicKeys()
 
   def _ReplaceOrAddPublicKeys(self, public_keys, file_node_name):
+    assert isinstance(public_keys, dict)
     self._EnsurePublicKeyFile(file_node_name)
     for key_node_uuid, keys in public_keys.items():
+      assert isinstance(keys, list)
       self._public_keys[file_node_name][key_node_uuid] = keys
+    self._AssertTypePublicKeys()
 
   def _RemovePublicKeys(self, public_keys, file_node_name):
+    assert isinstance(public_keys, dict)
     self._EnsurePublicKeyFile(file_node_name)
     for key_node_uuid, _ in public_keys.items():
       if key_node_uuid in self._public_keys[file_node_name]:
         self._public_keys[file_node_name][key_node_uuid] = []
+    self._AssertTypePublicKeys()
 
   def _HandlePublicKeys(self, instructions, node):
     (action, public_keys) = instructions
@@ -453,6 +519,7 @@ class FakeSshFileManager(object):
       self._ClearPublicKeys(node)
     else:
       raise Exception("Unsupported action: %s." % action)
+    self._AssertTypePublicKeys()
 
   # pylint: disable=W0613
   def AddAuthorizedKeys(self, file_obj, keys):
@@ -464,9 +531,10 @@ class FakeSshFileManager(object):
     @see: C{ssh.AddAuthorizedKeys}
 
     """
+    assert isinstance(keys, list)
     assert self._master_node_name
-    for key in keys:
-      self._AddAuthorizedKey(self._master_node_name, key)
+    self._AddAuthorizedKeys(self._master_node_name, keys)
+    self._AssertTypeAuthorizedKeys()
 
   def RemoveAuthorizedKeys(self, file_name, keys):
     """Emulates ssh.RemoveAuthorizeKeys on the master node.
@@ -477,9 +545,10 @@ class FakeSshFileManager(object):
     @see: C{ssh.RemoveAuthorizedKeys}
 
     """
+    assert isinstance(keys, list)
     assert self._master_node_name
-    for key in keys:
-      self._RemoveAuthorizedKey(self._master_node_name, key)
+    self._RemoveAuthorizedKeys(self._master_node_name, keys)
+    self._AssertTypeAuthorizedKeys()
 
   def AddPublicKey(self, new_uuid, new_key, **kwargs):
     """Emulates ssh.AddPublicKey on the master node.
@@ -491,8 +560,10 @@ class FakeSshFileManager(object):
 
     """
     assert self._master_node_name
-    key_dict = {new_uuid: new_key}
+    assert isinstance(new_key, str)
+    key_dict = {new_uuid: [new_key]}
     self._ReplaceOrAddPublicKeys(key_dict, self._master_node_name)
+    self._AssertTypePublicKeys()
 
   def RemovePublicKey(self, target_uuid, **kwargs):
     """Emulates ssh.RemovePublicKey on the master node.
@@ -506,6 +577,7 @@ class FakeSshFileManager(object):
     assert self._master_node_name
     key_dict = {target_uuid: []}
     self._RemovePublicKeys(key_dict, self._master_node_name)
+    self._AssertTypePublicKeys()
 
   def QueryPubKeyFile(self, target_uuids, **kwargs):
     """Emulates ssh.QueryPubKeyFile on the master node.
@@ -516,6 +588,7 @@ class FakeSshFileManager(object):
     @see: C{ssh.QueryPubKey}
 
     """
+
     assert self._master_node_name
     all_keys = target_uuids is None
     if all_keys:
@@ -524,10 +597,11 @@ class FakeSshFileManager(object):
     if isinstance(target_uuids, str):
       target_uuids = [target_uuids]
     result_dict = {}
-    for key_node_uuid, key in \
+    for key_node_uuid, keys in \
         self._public_keys[self._master_node_name].items():
       if key_node_uuid in target_uuids:
-        result_dict[key_node_uuid] = key
+        result_dict[key_node_uuid] = keys
+    self._AssertTypePublicKeys()
     return result_dict
 
   def ReplaceNameByUuid(self, node_uuid, node_name, **kwargs):
@@ -539,9 +613,12 @@ class FakeSshFileManager(object):
     @see: C{ssh.ReplacenameByUuid}
 
     """
+    assert isinstance(node_uuid, str)
+    assert isinstance(node_name, str)
     assert self._master_node_name
     if node_name in self._public_keys[self._master_node_name]:
       self._public_keys[self._master_node_name][node_uuid] = \
         self._public_keys[self._master_node_name][node_name][:]
       del self._public_keys[self._master_node_name][node_name]
+    self._AssertTypePublicKeys()
   # pylint: enable=W0613