You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					231 lines
				
				6.0 KiB
			
		
		
			
		
	
	
					231 lines
				
				6.0 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								import math
							 | 
						||
| 
								 | 
							
								import textwrap
							 | 
						||
| 
								 | 
							
								import sys
							 | 
						||
| 
								 | 
							
								import pytest
							 | 
						||
| 
								 | 
							
								import threading
							 | 
						||
| 
								 | 
							
								import traceback
							 | 
						||
| 
								 | 
							
								import time
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import numpy as np
							 | 
						||
| 
								 | 
							
								from numpy.testing import IS_PYPY
							 | 
						||
| 
								 | 
							
								from . import util
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TestF77Callback(util.F2PyTest):
							 | 
						||
| 
								 | 
							
								    sources = [util.getpath("tests", "src", "callback", "foo.f")]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @pytest.mark.parametrize("name", "t,t2".split(","))
							 | 
						||
| 
								 | 
							
								    def test_all(self, name):
							 | 
						||
| 
								 | 
							
								        self.check_function(name)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @pytest.mark.xfail(IS_PYPY,
							 | 
						||
| 
								 | 
							
								                       reason="PyPy cannot modify tp_doc after PyType_Ready")
							 | 
						||
| 
								 | 
							
								    def test_docstring(self):
							 | 
						||
| 
								 | 
							
								        expected = textwrap.dedent("""\
							 | 
						||
| 
								 | 
							
								        a = t(fun,[fun_extra_args])
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Wrapper for ``t``.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Parameters
							 | 
						||
| 
								 | 
							
								        ----------
							 | 
						||
| 
								 | 
							
								        fun : call-back function
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Other Parameters
							 | 
						||
| 
								 | 
							
								        ----------------
							 | 
						||
| 
								 | 
							
								        fun_extra_args : input tuple, optional
							 | 
						||
| 
								 | 
							
								            Default: ()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Returns
							 | 
						||
| 
								 | 
							
								        -------
							 | 
						||
| 
								 | 
							
								        a : int
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        Notes
							 | 
						||
| 
								 | 
							
								        -----
							 | 
						||
| 
								 | 
							
								        Call-back functions::
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            def fun(): return a
							 | 
						||
| 
								 | 
							
								            Return objects:
							 | 
						||
| 
								 | 
							
								                a : int
							 | 
						||
| 
								 | 
							
								        """)
							 | 
						||
| 
								 | 
							
								        assert self.module.t.__doc__ == expected
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def check_function(self, name):
							 | 
						||
| 
								 | 
							
								        t = getattr(self.module, name)
							 | 
						||
| 
								 | 
							
								        r = t(lambda: 4)
							 | 
						||
| 
								 | 
							
								        assert r == 4
							 | 
						||
| 
								 | 
							
								        r = t(lambda a: 5, fun_extra_args=(6, ))
							 | 
						||
| 
								 | 
							
								        assert r == 5
							 | 
						||
| 
								 | 
							
								        r = t(lambda a: a, fun_extra_args=(6, ))
							 | 
						||
| 
								 | 
							
								        assert r == 6
							 | 
						||
| 
								 | 
							
								        r = t(lambda a: 5 + a, fun_extra_args=(7, ))
							 | 
						||
| 
								 | 
							
								        assert r == 12
							 | 
						||
| 
								 | 
							
								        r = t(lambda a: math.degrees(a), fun_extra_args=(math.pi, ))
							 | 
						||
| 
								 | 
							
								        assert r == 180
							 | 
						||
| 
								 | 
							
								        r = t(math.degrees, fun_extra_args=(math.pi, ))
							 | 
						||
| 
								 | 
							
								        assert r == 180
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        r = t(self.module.func, fun_extra_args=(6, ))
							 | 
						||
| 
								 | 
							
								        assert r == 17
							 | 
						||
| 
								 | 
							
								        r = t(self.module.func0)
							 | 
						||
| 
								 | 
							
								        assert r == 11
							 | 
						||
| 
								 | 
							
								        r = t(self.module.func0._cpointer)
							 | 
						||
| 
								 | 
							
								        assert r == 11
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        class A:
							 | 
						||
| 
								 | 
							
								            def __call__(self):
							 | 
						||
| 
								 | 
							
								                return 7
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            def mth(self):
							 | 
						||
| 
								 | 
							
								                return 9
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        a = A()
							 | 
						||
| 
								 | 
							
								        r = t(a)
							 | 
						||
| 
								 | 
							
								        assert r == 7
							 | 
						||
| 
								 | 
							
								        r = t(a.mth)
							 | 
						||
| 
								 | 
							
								        assert r == 9
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @pytest.mark.skipif(sys.platform == 'win32',
							 | 
						||
| 
								 | 
							
								                        reason='Fails with MinGW64 Gfortran (Issue #9673)')
							 | 
						||
| 
								 | 
							
								    def test_string_callback(self):
							 | 
						||
| 
								 | 
							
								        def callback(code):
							 | 
						||
| 
								 | 
							
								            if code == "r":
							 | 
						||
| 
								 | 
							
								                return 0
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                return 1
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        f = getattr(self.module, "string_callback")
							 | 
						||
| 
								 | 
							
								        r = f(callback)
							 | 
						||
| 
								 | 
							
								        assert r == 0
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @pytest.mark.skipif(sys.platform == 'win32',
							 | 
						||
| 
								 | 
							
								                        reason='Fails with MinGW64 Gfortran (Issue #9673)')
							 | 
						||
| 
								 | 
							
								    def test_string_callback_array(self):
							 | 
						||
| 
								 | 
							
								        # See gh-10027
							 | 
						||
| 
								 | 
							
								        cu1 = np.zeros((1, ), "S8")
							 | 
						||
| 
								 | 
							
								        cu2 = np.zeros((1, 8), "c")
							 | 
						||
| 
								 | 
							
								        cu3 = np.array([""], "S8")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        def callback(cu, lencu):
							 | 
						||
| 
								 | 
							
								            if cu.shape != (lencu,):
							 | 
						||
| 
								 | 
							
								                return 1
							 | 
						||
| 
								 | 
							
								            if cu.dtype != "S8":
							 | 
						||
| 
								 | 
							
								                return 2
							 | 
						||
| 
								 | 
							
								            if not np.all(cu == b""):
							 | 
						||
| 
								 | 
							
								                return 3
							 | 
						||
| 
								 | 
							
								            return 0
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        f = getattr(self.module, "string_callback_array")
							 | 
						||
| 
								 | 
							
								        for cu in [cu1, cu2, cu3]:
							 | 
						||
| 
								 | 
							
								            res = f(callback, cu, cu.size)
							 | 
						||
| 
								 | 
							
								            assert res == 0
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def test_threadsafety(self):
							 | 
						||
| 
								 | 
							
								        # Segfaults if the callback handling is not threadsafe
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        errors = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        def cb():
							 | 
						||
| 
								 | 
							
								            # Sleep here to make it more likely for another thread
							 | 
						||
| 
								 | 
							
								            # to call their callback at the same time.
							 | 
						||
| 
								 | 
							
								            time.sleep(1e-3)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            # Check reentrancy
							 | 
						||
| 
								 | 
							
								            r = self.module.t(lambda: 123)
							 | 
						||
| 
								 | 
							
								            assert r == 123
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            return 42
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        def runner(name):
							 | 
						||
| 
								 | 
							
								            try:
							 | 
						||
| 
								 | 
							
								                for j in range(50):
							 | 
						||
| 
								 | 
							
								                    r = self.module.t(cb)
							 | 
						||
| 
								 | 
							
								                    assert r == 42
							 | 
						||
| 
								 | 
							
								                    self.check_function(name)
							 | 
						||
| 
								 | 
							
								            except Exception:
							 | 
						||
| 
								 | 
							
								                errors.append(traceback.format_exc())
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        threads = [
							 | 
						||
| 
								 | 
							
								            threading.Thread(target=runner, args=(arg, ))
							 | 
						||
| 
								 | 
							
								            for arg in ("t", "t2") for n in range(20)
							 | 
						||
| 
								 | 
							
								        ]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for t in threads:
							 | 
						||
| 
								 | 
							
								            t.start()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for t in threads:
							 | 
						||
| 
								 | 
							
								            t.join()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        errors = "\n\n".join(errors)
							 | 
						||
| 
								 | 
							
								        if errors:
							 | 
						||
| 
								 | 
							
								            raise AssertionError(errors)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def test_hidden_callback(self):
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            self.module.hidden_callback(2)
							 | 
						||
| 
								 | 
							
								        except Exception as msg:
							 | 
						||
| 
								 | 
							
								            assert str(msg).startswith("Callback global_f not defined")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            self.module.hidden_callback2(2)
							 | 
						||
| 
								 | 
							
								        except Exception as msg:
							 | 
						||
| 
								 | 
							
								            assert str(msg).startswith("cb: Callback global_f not defined")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.module.global_f = lambda x: x + 1
							 | 
						||
| 
								 | 
							
								        r = self.module.hidden_callback(2)
							 | 
						||
| 
								 | 
							
								        assert r == 3
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.module.global_f = lambda x: x + 2
							 | 
						||
| 
								 | 
							
								        r = self.module.hidden_callback(2)
							 | 
						||
| 
								 | 
							
								        assert r == 4
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        del self.module.global_f
							 | 
						||
| 
								 | 
							
								        try:
							 | 
						||
| 
								 | 
							
								            self.module.hidden_callback(2)
							 | 
						||
| 
								 | 
							
								        except Exception as msg:
							 | 
						||
| 
								 | 
							
								            assert str(msg).startswith("Callback global_f not defined")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        self.module.global_f = lambda x=0: x + 3
							 | 
						||
| 
								 | 
							
								        r = self.module.hidden_callback(2)
							 | 
						||
| 
								 | 
							
								        assert r == 5
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # reproducer of gh18341
							 | 
						||
| 
								 | 
							
								        r = self.module.hidden_callback2(2)
							 | 
						||
| 
								 | 
							
								        assert r == 3
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TestF77CallbackPythonTLS(TestF77Callback):
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    Callback tests using Python thread-local storage instead of
							 | 
						||
| 
								 | 
							
								    compiler-provided
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    options = ["-DF2PY_USE_PYTHON_TLS"]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TestF90Callback(util.F2PyTest):
							 | 
						||
| 
								 | 
							
								    sources = [util.getpath("tests", "src", "callback", "gh17797.f90")]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def test_gh17797(self):
							 | 
						||
| 
								 | 
							
								        def incr(x):
							 | 
						||
| 
								 | 
							
								            return x + 123
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        y = np.array([1, 2, 3], dtype=np.int64)
							 | 
						||
| 
								 | 
							
								        r = self.module.gh17797(incr, y)
							 | 
						||
| 
								 | 
							
								        assert r == 123 + 1 + 2 + 3
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class TestGH18335(util.F2PyTest):
							 | 
						||
| 
								 | 
							
								    """The reproduction of the reported issue requires specific input that
							 | 
						||
| 
								 | 
							
								    extensions may break the issue conditions, so the reproducer is
							 | 
						||
| 
								 | 
							
								    implemented as a separate test class. Do not extend this test with
							 | 
						||
| 
								 | 
							
								    other tests!
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    sources = [util.getpath("tests", "src", "callback", "gh18335.f90")]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def test_gh18335(self):
							 | 
						||
| 
								 | 
							
								        def foo(x):
							 | 
						||
| 
								 | 
							
								            x[0] += 1
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        r = self.module.gh18335(foo)
							 | 
						||
| 
								 | 
							
								        assert r == 123 + 1
							 |