# Copyright (c) 2017 The PyWavelets Developers # # See COPYING for license details. import inspect import numpy as np from collections.abc import Iterable from ._extensions._pywt import (Wavelet, ContinuousWavelet, DiscreteContinuousWavelet, Modes) def _as_wavelet(wavelet): """Convert wavelet name to a Wavelet object.""" if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): wavelet = DiscreteContinuousWavelet(wavelet) if isinstance(wavelet, ContinuousWavelet): raise ValueError( "A ContinuousWavelet object was provided, but only discrete " "Wavelet objects are supported by this function. A list of all " "supported discrete wavelets can be obtained by running:\n" "print(pywt.wavelist(kind='discrete'))") return wavelet def _wavelets_per_axis(wavelet, axes): """Initialize Wavelets for each axis to be transformed. Parameters ---------- wavelet : Wavelet or tuple of Wavelets If a single Wavelet is provided, it will used for all axes. Otherwise one Wavelet per axis must be provided. axes : list The tuple of axes to be transformed. Returns ------- wavelets : list of Wavelet objects A tuple of Wavelets equal in length to ``axes``. """ axes = tuple(axes) if isinstance(wavelet, (str, Wavelet)): # same wavelet on all axes wavelets = [_as_wavelet(wavelet), ] * len(axes) elif isinstance(wavelet, Iterable): # (potentially) unique wavelet per axis (e.g. for dual-tree DWT) if len(wavelet) == 1: wavelets = [_as_wavelet(wavelet[0]), ] * len(axes) else: if len(wavelet) != len(axes): raise ValueError(( "The number of wavelets must match the number of axes " "to be transformed.")) wavelets = [_as_wavelet(w) for w in wavelet] else: raise ValueError("wavelet must be a str, Wavelet or iterable") return wavelets def _modes_per_axis(modes, axes): """Initialize mode for each axis to be transformed. Parameters ---------- modes : str or tuple of strings If a single mode is provided, it will used for all axes. Otherwise one mode per axis must be provided. axes : tuple The tuple of axes to be transformed. Returns ------- modes : tuple of int A tuple of Modes equal in length to ``axes``. """ axes = tuple(axes) if isinstance(modes, (int, str)): # same wavelet on all axes modes = [Modes.from_object(modes), ] * len(axes) elif isinstance(modes, Iterable): if len(modes) == 1: modes = [Modes.from_object(modes[0]), ] * len(axes) else: # (potentially) unique wavelet per axis (e.g. for dual-tree DWT) if len(modes) != len(axes): raise ValueError(("The number of modes must match the number " "of axes to be transformed.")) modes = [Modes.from_object(mode) for mode in modes] else: raise ValueError("modes must be a str, Mode enum or iterable") return modes