__all__ = ("_FactorialGenerator", "ZernikePolynomialGenerator")
import numbers
import numpy as np
class _FactorialGenerator:
"""
A class that generates factorials
and stores them in a dict to be referenced
as needed.
"""
def __init__(self):
self._values = {0: 1, 1: 1}
self._max_i = 1
def evaluate(self, num):
"""
Return the factorial of num
"""
if num < 0:
raise RuntimeError("Cannot handle negative factorial")
i_num = int(np.round(num))
if i_num in self._values:
return self._values[i_num]
val = self._values[self._max_i]
for ii in range(self._max_i, num):
val *= ii + 1
self._values[ii + 1] = val
self._max_i = num
return self._values[num]
class ZernikePolynomialGenerator:
    """
    A class to generate and evaluate the Zernike
    polynomials. Definitions of Zernike polynomials
    are taken from
    https://en.wikipedia.org/wiki/Zernike_polynomials

    The radial coefficients for each (n, m) pair are computed lazily on
    first use and cached on the instance.
    """

    def __init__(self):
        self._factorial = _FactorialGenerator()
        # (n, m) -> array of radial-polynomial coefficients
        self._coeffs = {}
        # (n, m) -> array of the powers of r matching self._coeffs
        self._powers = {}

    def _validate_nm(self, n, m):
        """
        Make sure that n, m are a valid pair of indices for
        a Zernike polynomial.

        n is the radial order
        m is the angular order (must be non-negative here;
        public callers pass abs(m))

        Any integral type is accepted: Python int or any numpy
        integer type (np.int32, np.int64, ...).

        Returns (n, m) converted to Python ints.
        Raises RuntimeError if either index is not an integer, is
        negative, or if n < m.
        """
        # numpy registers np.integer with numbers.Integral, so this
        # accepts every numpy integer width, not only np.int64.
        if not isinstance(n, numbers.Integral):
            raise RuntimeError("Zernike polynomial n must be int")
        if not isinstance(m, numbers.Integral):
            raise RuntimeError("Zernike polynomial m must be int")
        if n < 0:
            raise RuntimeError("Radial Zernike n cannot be negative")
        if m < 0:
            raise RuntimeError("Radial Zernike m cannot be negative")
        if n < m:
            raise RuntimeError("Radial Zernike n must be >= m")
        return (int(n), int(m))

    def _make_polynomial(self, n, m):
        """
        Build the radial part of the n, m Zernike polynomial and cache
        it in self._coeffs[(n, m)] and self._powers[(n, m)].

        n is the radial order
        m is the angular order

        The cached radial part of the Zernike polynomial is
        sum([coeffs[ii]*power(r, powers[ii])
             for ii in range(len(coeffs))])

        Nothing is returned. Only meaningful when n - m is even;
        _evaluate_radial short-circuits the odd case (where the radial
        part is identically zero) before calling this.
        """
        n, m = self._validate_nm(n, m)
        # coefficients taken from
        # https://en.wikipedia.org/wiki/Zernike_polynomials
        n_coeffs = 1 + (n - m) // 2
        local_coeffs = np.zeros(n_coeffs, dtype=float)
        local_powers = np.zeros(n_coeffs, dtype=float)
        for k in range(n_coeffs):
            sgn = 1.0 if k % 2 == 0 else -1.0
            num_fac = self._factorial.evaluate(n - k)
            k_fac = self._factorial.evaluate(k)
            d1_fac = self._factorial.evaluate((n + m) // 2 - k)
            d2_fac = self._factorial.evaluate((n - m) // 2 - k)
            # k + ((n+m)/2 - k) + ((n-m)/2 - k) == n - k, so this ratio is
            # a multinomial coefficient: an exact integer. Dividing in
            # integer arithmetic avoids rounding the huge factorials.
            magnitude = num_fac // (k_fac * d1_fac * d2_fac)
            local_coeffs[k] = sgn * magnitude
            local_powers[k] = n - 2 * k
        self._coeffs[(n, m)] = local_coeffs
        self._powers[(n, m)] = local_powers

    def _evaluate_radial_number(self, r, nm_tuple):
        """
        Evaluate the radial part of a Zernike polynomial.

        r is a scalar value
        nm_tuple is a tuple of the form (radial order, angular order)
        denoting the polynomial to evaluate; it must already be cached

        Return the value of the radial part of the polynomial at r
        (returns np.nan if r > 1.0)
        """
        if r > 1.0:
            return np.nan
        r_term = np.power(r, self._powers[nm_tuple])
        return (self._coeffs[nm_tuple] * r_term).sum()

    def _evaluate_radial_array(self, r, nm_tuple):
        """
        Evaluate the radial part of a Zernike polynomial.

        r is an array-like of radial values (any shape)
        nm_tuple is a tuple of the form (radial order, angular order)
        denoting the polynomial to evaluate; it must already be cached

        Return an array with the same shape as r holding the values of
        the radial part of the polynomial (np.nan where r > 1.0)
        """
        r = np.asarray(r, dtype=float)
        if r.size == 0:
            return np.zeros(r.shape, dtype=float)
        # r_power[..., ii] == r ** powers[ii]. All powers are non-negative
        # integers, so r == 0 needs no special casing (0.0 ** 0.0 == 1.0)
        # and the result keeps the full shape of r.
        r_power = np.power.outer(r, self._powers[nm_tuple])
        results = np.dot(r_power, self._coeffs[nm_tuple])
        # The unit disk is closed: r == 1.0 is a valid point, consistent
        # with _evaluate_radial_number.
        return np.where(r <= 1.0, results, np.nan)

    def _evaluate_radial(self, r, n, m):
        """
        Evaluate the radial part of a Zernike polynomial

        r is a radial value or an array of radial values
        n is the radial order of the polynomial
        m is the angular order of the polynomial

        Return the value(s) of the radial part of the polynomial at r
        (returns np.nan if r > 1.0, except when n - m is odd, in which
        case the radial part is identically zero and zeros are returned)
        """
        nm_tuple = self._validate_nm(n, m)
        is_array = not isinstance(r, numbers.Number)
        if (nm_tuple[0] - nm_tuple[1]) % 2 == 1:
            if is_array:
                return np.zeros(np.shape(r), dtype=float)
            return 0.0
        if nm_tuple not in self._coeffs:
            self._make_polynomial(nm_tuple[0], nm_tuple[1])
        if is_array:
            return self._evaluate_radial_array(r, nm_tuple)
        return self._evaluate_radial_number(r, nm_tuple)

    def evaluate(self, r, phi, n, m):
        """
        Evaluate a Zernike polynomial in polar coordinates

        r is the radial coordinate (a scalar or an array)
        phi is the angular coordinate in radians (a scalar or an array)
        n is the radial order of the polynomial
        m is the angular order of the polynomial

        Return the value(s) of the polynomial at r, phi
        (returns np.nan if r > 1.0)

        For m >= 0 this is R_n^m(r) * cos(m * phi); for m < 0 it is
        R_n^|m|(r) * sin(m * phi) == -R_n^|m|(r) * sin(|m| * phi).
        NOTE(review): the m < 0 sign is opposite to the Wikipedia
        convention Z_n^{-m} = R_n^m sin(m phi) (m > 0); kept as-is
        for backward compatibility.
        """
        radial_part = self._evaluate_radial(r, n, np.abs(m))
        if m >= 0:
            return radial_part * np.cos(m * phi)
        return radial_part * np.sin(m * phi)

    def norm(self, n, m):
        """
        Return the normalization of the n, m Zernike
        polynomial, i.e. eps_m * pi / (2n + 2) with eps_m = 2 for
        m == 0 and 1 otherwise

        n is the radial order
        m is the angular order
        """
        n_val, abs_m = self._validate_nm(n, np.abs(m))
        eps = 2.0 if abs_m == 0 else 1.0
        return eps * np.pi / (n_val * 2 + 2)

    def evaluate_xy(self, x, y, n, m):
        """
        Evaluate a Zernike polynomial at a point in
        Cartesian space.

        x and y are the Cartesian coordinates (either scalars
        or arrays of any matching/broadcastable shape)
        n is the radial order of the polynomial
        m is the angular order of the polynomial

        Return the value(s) of the polynomial at x, y
        (returns np.nan if sqrt(x**2+y**2) > 1.0)
        """
        # hypot/arctan2 are well defined at the origin (arctan2(0, 0) == 0)
        # and, unlike arccos(x / r), keep full precision for angles near
        # 0 and pi, so no warning suppression is needed.
        r = np.hypot(x, y)
        phi = np.arctan2(y, x)
        return self.evaluate(r, phi, n, m)