# Source code for mpi_array.comms_test

"""
======================================
The :mod:`mpi_array.comms_test` Module
======================================

Module defining :mod:`mpi_array.comms` unit-tests.
Execute as::

   python -m mpi_array.comms_test

or::

   mpirun -n 4 python -m mpi_array.comms_test


Classes
=======

.. autosummary::
   :toctree: generated/
   :template: autosummary/inherits_TestCase_class.rst

   LocaleCommsTest - Tests for :obj:`mpi_array.comms.LocaleComms`.
   CartLocaleCommsTest - Tests for :obj:`mpi_array.comms.CartLocaleComms`.
   CreateDistributionTest - Tests for :func:`mpi_array.comms.create_distribution`.

"""
from __future__ import absolute_import

import mpi4py.MPI as _mpi
import numpy as _np  # noqa: E402,F401

from .license import license as _license, copyright as _copyright, version as _version
from . import unittest as _unittest
from . import logging as _logging  # noqa: E402,F401

from .comms import CartLocaleComms, LocaleComms
from .comms import create_single_locale_distribution, create_locale_comms, create_distribution
from .comms import check_distrib_type, DT_BLOCK, DT_SLAB, DT_CLONED, DT_SINGLE_LOCALE
from .comms import check_locale_type, LT_NODE, LT_PROCESS
from .utils import get_shared_mem_usage_percent_string
from .distribution import SingleLocaleDistribution as _SingleLocaleDistribution

__author__ = "Shane J. Latham"
__license__ = _license()
__copyright__ = _copyright()
__version__ = _version()


class LocaleCommsTest(_unittest.TestCase):
    """
    Tests for :obj:`mpi_array.comms.LocaleComms`.
    """

    def test_get_shared_mem_usage_percent_string(self):
        """
        Coverage for :func:`mpi_array.utils.get_shared_mem_usage_percent_string`,
        called with a file path that is not expected to exist.
        """
        p = \
            get_shared_mem_usage_percent_string(
                shm_file_name="/probably/does/not_exist/on_file/system"
            )
        self.assertEqual("unknown", p)

    def test_construct(self):
        """
        Test :meth:`mpi_array.comms.LocaleComms.__init__`.
        """
        # Explicit peer communicator.
        lc = LocaleComms(peer_comm=_mpi.COMM_WORLD)
        self.assertIsNotNone(lc.intra_locale_comm)
        self.assertGreaterEqual(lc.intra_locale_comm.size, 1)
        self.assertIsNotNone(lc.peer_comm)
        self.assertGreaterEqual(lc.peer_comm.size, 1)
        self.assertEqual(lc.num_locales, len(lc.peer_ranks_per_locale))

        ranks_per_locale = [
            tuple(lc.peer_ranks_per_locale[r]) for r in range(lc.num_locales)
        ]
        # Use the built-in sum: handing a generator to numpy.sum is deprecated.
        self.assertEqual(
            lc.peer_comm.size,
            sum(len(ranks) for ranks in ranks_per_locale)
        )
        # Taken together, the per-locale rank lists must contain every
        # peer_comm rank exactly once.
        prpl_ranks = sorted(rank for ranks in ranks_per_locale for rank in ranks)
        self.assertSequenceEqual(
            list(range(0, lc.peer_comm.size)),
            prpl_ranks
        )

        # Default construction (no communicators supplied).
        lc = LocaleComms()
        self.assertIsNotNone(lc.intra_locale_comm)
        self.assertGreaterEqual(lc.intra_locale_comm.size, 1)
        self.assertIsNotNone(lc.peer_comm)
        self.assertGreaterEqual(lc.peer_comm.size, 1)

        # The inter_locale_comm attribute can be assigned and read back.
        lc.inter_locale_comm = _mpi.COMM_NULL
        self.assertEqual(_mpi.COMM_NULL, lc.inter_locale_comm)
        lc.inter_locale_comm = None
        self.assertEqual(None, lc.inter_locale_comm)

    def test_construct_invalid_comms(self):
        """
        Test that :meth:`mpi_array.comms.LocaleComms.__init__` raises
        :obj:`ValueError` for the communicator combinations tried below.
        """
        # Only meaningful when COMM_WORLD and COMM_SELF differ in size,
        # i.e. when running with more than one process.
        if _mpi.COMM_WORLD.size != _mpi.COMM_SELF.size:
            self.assertRaises(
                ValueError,
                LocaleComms,
                _mpi.COMM_SELF,  # peer
                _mpi.COMM_SELF,  # intra
                _mpi.COMM_WORLD  # inter
            )

        lc = LocaleComms()
        if lc.intra_locale_comm.size > 1:
            self.assertRaises(
                ValueError,
                LocaleComms,
                lc.peer_comm,  # peer
                lc.intra_locale_comm,  # intra
                lc.peer_comm  # inter
            )

    def test_construct_no_shared(self):
        """
        Test :meth:`mpi_array.comms.LocaleComms.__init__` with
        :samp:`intra_locale_comm=mpi4py.MPI.COMM_SELF`.
        """
        lc = LocaleComms(intra_locale_comm=_mpi.COMM_SELF)
        self.assertEqual(_mpi.IDENT, _mpi.Comm.Compare(_mpi.COMM_WORLD, lc.peer_comm))
        self.assertEqual(1, lc.intra_locale_comm.size)
        # NOTE(review): this compares two module-level constants and does not
        # involve lc; possibly a communicator of lc was meant -- confirm.
        self.assertNotEqual(_mpi.COMM_WORLD, _mpi.COMM_NULL)
class CartLocaleCommsTest(_unittest.TestCase):
    """
    :obj:`unittest.TestCase` for :obj:`mpi_array.comms.CartLocaleComms`.
    """

    def test_construct_invalid_dims(self):
        """
        Constructor raises :obj:`ValueError` for each of the
        :samp:`ndims`/:samp:`dims` combinations tried below.
        """
        lc = None
        invalid_kwargs = (
            dict(),
            dict(ndims=None, dims=None),
            dict(dims=tuple(), ndims=1),
            dict(dims=tuple([0, 2]), ndims=1),
            dict(dims=tuple([1, 2]), ndims=3),
        )
        for kwargs in invalid_kwargs:
            with self.assertRaises(ValueError):
                lc = CartLocaleComms(**kwargs)

        # Every construction raised, so nothing was ever bound to lc.
        self.assertEqual(None, lc)

    def test_construct_invalid_cart_comm(self):
        """
        Constructor raises :obj:`ValueError` when given both
        :samp:`peer_comm=COMM_WORLD` and a 1-D cartesian communicator created
        from :samp:`COMM_WORLD` (only checked with more than one process).
        """
        world = _mpi.COMM_WORLD
        cart_comm = world.Create_cart(dims=(world.size,))
        if world.size <= 1:
            return
        with self.assertRaises(ValueError):
            CartLocaleComms(ndims=1, peer_comm=world, cart_comm=cart_comm)

    def test_construct_shared(self):
        """
        Construction from :samp:`ndims` or :samp:`dims` alone yields a
        :samp:`peer_comm` identical to :samp:`COMM_WORLD` and the expected
        :samp:`ndim`.
        """
        cases = (
            (dict(ndims=1), 1),
            (dict(ndims=4), 4),
            (dict(dims=(0,)), 1),
            (dict(dims=(0, 0)), 2),
            (dict(dims=(0, 0, 0)), 3),
        )
        for kwargs, expected_ndim in cases:
            lc = CartLocaleComms(**kwargs)
            self.assertEqual(
                _mpi.IDENT,
                _mpi.Comm.Compare(_mpi.COMM_WORLD, lc.peer_comm)
            )
            self.assertEqual(expected_ndim, lc.ndim)

    def test_construct_no_shared(self):
        """
        Construction with :samp:`intra_locale_comm=COMM_SELF` gives an
        intra-locale communicator of size one.
        """
        comms = CartLocaleComms(ndims=1, intra_locale_comm=_mpi.COMM_SELF)
        comparison = _mpi.Comm.Compare(_mpi.COMM_WORLD, comms.peer_comm)
        self.assertEqual(_mpi.IDENT, comparison)
        self.assertEqual(1, comms.intra_locale_comm.size)
        # NOTE(review): compares two module-level constants only; does not
        # involve the constructed object -- confirm intent.
        self.assertNotEqual(_mpi.COMM_WORLD, _mpi.COMM_NULL)

    def test_alloc_locale_buffer(self):
        """
        :samp:`alloc_locale_buffer` returns a buffer with the requested dtype
        and byte length, both with the default intra-locale communicator and
        with :samp:`COMM_SELF`.
        """
        num_elems = 100
        expected_dtype = _np.dtype("uint16")
        for kwargs in (dict(ndims=1), dict(ndims=1, intra_locale_comm=_mpi.COMM_SELF)):
            lc = CartLocaleComms(**kwargs)
            rma_window_buff = lc.alloc_locale_buffer(shape=(num_elems,), dtype="uint16")
            self.assertEqual(expected_dtype, rma_window_buff.dtype)
            self.assertEqual(expected_dtype.itemsize, rma_window_buff.itemsize)
            self.assertEqual(
                num_elems * rma_window_buff.dtype.itemsize,
                len(rma_window_buff.buffer)
            )
class CreateDistributionTest(_unittest.TestCase):
    """
    Tests for :func:`mpi_array.comms.create_distribution`.
    """

    def test_check_distrib_type(self):
        """
        Test :func:`mpi_array.comms.check_distrib_type`: returns :samp:`None`
        for each distribution-type constant and raises :obj:`ValueError`
        for an unrecognised string.
        """
        for distrib_type in (DT_SLAB, DT_BLOCK, DT_CLONED, DT_SINGLE_LOCALE):
            self.assertIsNone(check_distrib_type(distrib_type))
        self.assertRaises(ValueError, check_distrib_type, "not_a_valid_distrib_type")

    def test_check_locale_type(self):
        """
        Test :func:`mpi_array.comms.check_locale_type`: returns :samp:`None`
        for each locale-type constant and raises :obj:`ValueError`
        for an unrecognised string.
        """
        for locale_type in (LT_PROCESS, LT_NODE):
            self.assertIsNone(check_locale_type(locale_type))
        self.assertRaises(ValueError, check_locale_type, "not_a_valid_locale_type")

    def test_create_locale_comms_invalid_args(self):
        """
        Test that :func:`mpi_array.comms.create_locale_comms` raises exception
        for invalid arguments.
        """
        # Only checked when running with more than one process.
        if _mpi.COMM_WORLD.size > 1:
            self.assertRaises(
                ValueError,
                create_locale_comms,
                locale_type=LT_PROCESS,
                peer_comm=_mpi.COMM_WORLD,
                intra_locale_comm=_mpi.COMM_WORLD
            )

    def check_is_single_locale_distribution(self, distrib):
        """
        Asserts for checking that the :samp:`{distrib}` :obj:`Distribution`
        is single-locale.

        Checks that the first locale extent and the globale extent both start
        at the origin and stop at the global shape. The origin is derived from
        the global shape, so the check works for any number of dimensions.

        :param distrib: The distribution object to check.
        """
        self.assertIsInstance(distrib, _SingleLocaleDistribution)
        gshape = tuple(distrib.globale_extent.shape_n)
        origin = (0,) * len(gshape)

        first_locale_extent = distrib.locale_extents[0]
        self.assertSequenceEqual(gshape, tuple(first_locale_extent.shape))
        self.assertSequenceEqual(origin, tuple(first_locale_extent.start_n))
        self.assertSequenceEqual(gshape, tuple(first_locale_extent.stop_n))

        self.assertSequenceEqual(origin, tuple(distrib.globale_extent.start_n))
        self.assertSequenceEqual(gshape, tuple(distrib.globale_extent.stop_n))

    def test_create_single_locale_distribution(self):
        """
        Tests for :func:`mpi_array.comms.create_single_locale_distribution`.
        """
        candd = \
            create_single_locale_distribution(
                shape=(20, 31, 17, 4),
                locale_type=LT_PROCESS,
                peer_comm=_mpi.COMM_WORLD
            )
        distrib = candd.distribution
        self.check_is_single_locale_distribution(distrib)

    def _check_trailing_axes_span_shape(self, locale_extent, shape):
        """
        Assert that, on every axis except axis 0, :samp:`{locale_extent}`
        starts at zero and has shape and stop equal to :samp:`{shape}`.
        """
        trailing_shape = tuple(shape)[1:]
        trailing_origin = (0,) * len(trailing_shape)
        self.assertSequenceEqual(trailing_shape, tuple(locale_extent.shape)[1:])
        self.assertSequenceEqual(trailing_origin, tuple(locale_extent.start_n)[1:])
        self.assertSequenceEqual(trailing_shape, tuple(locale_extent.stop_n)[1:])

    def test_create_distribution_slab(self):
        """
        Tests for :func:`mpi_array.comms.create_distribution`
        with :samp:`distrib_type=DT_SLAB`.
        """
        shape = (20, 31, 17, 4)
        origin = (0,) * len(shape)
        candd = \
            create_distribution(
                shape=shape,
                locale_type=LT_PROCESS,
                distrib_type=DT_SLAB,
                peer_comm=_mpi.COMM_WORLD
            )
        distrib = candd.distribution

        # Axis 0 is left out of the per-locale comparison.
        # NOTE(review): presumably because the slab distribution splits the
        # array along axis 0 -- confirm against create_distribution.
        self._check_trailing_axes_span_shape(distrib.locale_extents[0], shape)
        self.assertEqual(candd.locale_comms.num_locales, distrib.num_locales)
        # range(1, n) is empty for n <= 1, so no extra guard is needed.
        for i in range(1, distrib.num_locales):
            self._check_trailing_axes_span_shape(distrib.locale_extents[i], shape)

        self.assertSequenceEqual(shape, tuple(distrib.globale_extent.shape))
        self.assertSequenceEqual(origin, tuple(distrib.globale_extent.start_n))
        self.assertSequenceEqual(shape, tuple(distrib.globale_extent.stop_n))

    def test_create_distribution_single_locale(self):
        """
        Tests for :func:`mpi_array.comms.create_distribution`
        with :samp:`distrib_type=DT_SINGLE_LOCALE`.
        """
        candd = \
            create_distribution(
                shape=(20, 31, 17, 4),
                locale_type=LT_PROCESS,
                distrib_type=DT_SINGLE_LOCALE,
                peer_comm=_mpi.COMM_WORLD
            )
        distrib = candd.distribution
        self.assertEqual(candd.locale_comms.num_locales, distrib.num_locales)
        self.check_is_single_locale_distribution(distrib)
# Hand this module's name to the package's ``unittest.main`` wrapper.
# NOTE(review): presumably the tests only run when the name is "__main__"
# (i.e. ``python -m mpi_array.comms_test``) -- confirm against
# mpi_array.unittest.main.
_unittest.main(__name__)

# Export every module-level name that does not start with an underscore.
__all__ = [s for s in dir() if not s.startswith('_')]