# Source code for tests.test_solver

"""
This module contains unit tests for the :py:func:`src.solver.steady_state_transport_solver` function from the :py:mod:`src.solver` module. It validates the solver's numerical mode (``analytic=False``) against its analytical mode (``analytic=True``) for the steady-state advection-diffusion equation.

The tests ensure that the numerical solution matches the analytical solution within a specified tolerance and that the outputs have the correct shapes and types.

Functions:
----------
- test_steady_state_transport_solver: Tests the solver against an analytical solution.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from src.solver import steady_state_transport_solver


def _save_comparison_plot(numeric, analytic, extent, path):
    """
    Save side-by-side images of a numerical and an analytical 2-D field.

    Parameters:
        numeric: 2-D array shown in the left panel ("Numerical Solution").
        analytic: 2-D array shown in the right panel ("Analytical Solution").
        extent: ``[xmin, xmax, ymin, ymax]`` passed to ``plt.imshow``.
        path: Output image path; missing parent directories are created.
    """
    path = Path(path)
    # savefig does not create directories, so make sure the target exists.
    path.parent.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(12, 6))
    try:
        panels = (
            (1, "Numerical Solution", numeric),
            (2, "Analytical Solution", analytic),
        )
        for position, title, field in panels:
            plt.subplot(1, 2, position)
            plt.title(title)
            plt.imshow(field, origin="lower", extent=extent)
            plt.colorbar()
        fig.savefig(path)
    finally:
        # Close the figure so repeated test runs do not accumulate open figures.
        plt.close(fig)


def test_steady_state_transport_solver():
    """
    Tests the :py:func:`src.solver.steady_state_transport_solver` function by
    comparing its numerical mode (``analytic=False``) against its analytical
    mode (``analytic=True``) for the steady-state advection-diffusion equation
    with a single point source.

    The point source is placed at grid index ``(ny // 4, nx // 4)``, i.e. a
    quarter of the way along each axis, not at the centre of the domain.

    The test validates:
    - The shape and type of the solver's numerical outputs.
    - The numerical accuracy of the solver by comparing its results to the
      analytical solution within a specified tolerance.

    A side-by-side plot of both surface-concentration fields is written to
    ``plots/analytical_vs_numerical_solution.png`` (the directory is created
    if missing). The plot is written before the accuracy checks, so it is
    available for inspection even when a comparison fails.

    Raises:
        AssertionError: If any of the output shapes, types, or numerical
            accuracy checks fail.
    """
    # Define test inputs
    nx, ny = 50, 50  # Grid size
    x = np.linspace(-50, 50, nx)  # only used for the plot extent
    y = np.linspace(-50, 50, ny)  # only used for the plot extent
    z = np.linspace(0, 100, 1)  # Vertical grid: a single level, z = 0
    u = 1.0  # Constant wind speed
    K = 10.0  # Constant eddy diffusivity
    Q = 1.0  # Source strength
    # Per-level profiles (u, 0, K); the middle entry is presumably the
    # cross-wind component -- confirm against src.solver.
    profiles = (np.ones(len(z)) * u, np.zeros(len(z)), np.ones(len(z)) * K)
    domain = (100, 100)  # Domain size

    # Point source a quarter of the way along each grid axis (not the centre).
    srf_flx = np.zeros((ny, nx))
    srf_flx[ny // 4, nx // 4] = Q

    # Call the solver with analytic=True
    srf_conc_analytic, bg_conc_analytic, conc_analytic, flx_analytic = (
        steady_state_transport_solver(srf_flx, z, profiles, domain, analytic=True)
    )
    # Call the solver with analytic=False
    srf_conc_numeric, bg_conc_numeric, conc_numeric, flx_numeric = (
        steady_state_transport_solver(srf_flx, z, profiles, domain, analytic=False)
    )

    # Assertions to validate numeric output shapes
    assert srf_conc_numeric.shape == (ny, nx), "Surface concentration shape mismatch"
    assert isinstance(
        bg_conc_numeric, float
    ), "Background concentration should be a float"
    assert conc_numeric.shape == (ny, nx), "Concentration shape mismatch"
    assert flx_numeric.shape == (ny, nx), "Flux shape mismatch"

    # Save the comparison plot before the accuracy checks so it exists even
    # when one of them fails.
    _save_comparison_plot(
        srf_conc_numeric,
        srf_conc_analytic,
        extent=[x.min(), x.max(), y.min(), y.max()],
        path="plots/analytical_vs_numerical_solution.png",
    )

    # Compare the results
    assert np.allclose(
        srf_conc_analytic, srf_conc_numeric, atol=1e-2
    ), "Surface concentration mismatch between analytic and numeric"
    assert np.isclose(
        bg_conc_analytic, bg_conc_numeric, atol=1e-2
    ), "Background concentration mismatch between analytic and numeric"
    assert np.allclose(
        conc_analytic, conc_numeric, atol=1e-2
    ), "Concentration mismatch between analytic and numeric"
    assert np.allclose(
        flx_analytic, flx_numeric, atol=1e-2
    ), "Flux mismatch between analytic and numeric"

    # Pointwise maximum error. This is slightly stricter than np.allclose
    # above: strict inequality and no relative-tolerance term.
    error = np.abs(srf_conc_numeric - srf_conc_analytic)
    max_error = np.max(error)
    assert (
        max_error < 1e-2
    ), f"Numerical solution deviates too much from analytical solution (max error: {max_error})"
if __name__ == "__main__":
    # Optional: execute the test directly, without pytest, when debugging.
    test_steady_state_transport_solver()