arviz_base.NumPyroInferenceAdapter
- class arviz_base.NumPyroInferenceAdapter(inference_obj, model, model_args, model_kwargs, sample_shape)
Standardize methods across NumPyro inference objects for use with NumPyroConverter.
- __init__(inference_obj, model, model_args, model_kwargs, sample_shape)
Initialize the adapter with common attributes for NumPyro inference objects.
This base class constructor sets up the shared infrastructure needed by all NumPyro inference adapters (MCMC, SVI, etc.) to provide a unified interface for the NumPyroConverter.
- Parameters:
- inference_obj : Any
The NumPyro inference object to adapt (e.g., MCMC, SVI, or other inference types).
- model : callable
The NumPyro model function that was used for inference.
- model_args : tuple, optional
Positional arguments passed to the model during inference. If None, defaults to an empty tuple.
- model_kwargs : dict, optional
Keyword arguments passed to the model during inference. If None, defaults to an empty dict.
- sample_shape : tuple of int
Shape of the samples to be returned by get_samples(): (num_chains, num_draws) for MCMC and (num_samples,) for SVI.
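The parameter handling described above can be sketched with a small, self-contained stand-in class. `AdapterSketch` and the placeholder `model` function are hypothetical and are not the arviz_base implementation; the sketch only mirrors the documented behaviour: `model_args` and `model_kwargs` fall back to an empty tuple and an empty dict when given as None, and `sample_shape` holds the leading sample dimensions.

```python
from __future__ import annotations

from typing import Any, Callable


class AdapterSketch:
    """Hypothetical stand-in for the documented constructor of
    NumPyroInferenceAdapter (not the arviz_base implementation)."""

    def __init__(
        self,
        inference_obj: Any,
        model: Callable,
        model_args: tuple | None,
        model_kwargs: dict | None,
        sample_shape: tuple[int, ...],
    ) -> None:
        self.inference_obj = inference_obj
        self.model = model
        # Documented behaviour: None falls back to an empty tuple / dict.
        self.model_args = () if model_args is None else tuple(model_args)
        self.model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
        # Leading sample dimensions: (num_chains, num_draws) for MCMC,
        # (num_samples,) for SVI.
        self.sample_shape = tuple(sample_shape)


def model(x=None):
    """Hypothetical placeholder for a NumPyro model function (body omitted)."""


# MCMC-style adapter: 4 chains x 1000 draws, no extra model arguments.
mcmc_like = AdapterSketch(object(), model, None, None, (4, 1000))
print(mcmc_like.model_args, mcmc_like.model_kwargs, mcmc_like.sample_shape)
# () {} (4, 1000)

# SVI-style adapter: e.g. 500 posterior samples, with explicit model arguments.
svi_like = AdapterSketch(object(), model, (1.0,), {"x": 2}, (500,))
```

Normalizing None to empty containers in one place means every subclass can unpack `*self.model_args, **self.model_kwargs` without re-checking for None.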
Methods
- __init__(inference_obj, model, model_args, ...)
Initialize the adapter with common attributes for NumPyro inference objects.
- get_sample_stats(**kwargs)
Get sample stats from the inference object (e.g., divergences for MCMC).
- get_samples([seed])
Get posterior samples from the inference object.
Attributes
- sample_dims
Return the sample dimension names.
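To illustrate the contract a concrete adapter fulfils, here is a sketch of an MCMC-style subclass built on a hypothetical abstract stand-in. `AdapterInterfaceSketch`, `FakeMCMC`, `MCMCLikeAdapter`, the dict-of-arrays return layout and the dimension names `"chain"`/`"draw"` are all assumptions for illustration, not arviz_base's actual code; a real adapter would wrap a NumPyro inference object such as `numpyro.infer.MCMC`. The `seed` argument is unused here because MCMC draws already exist; it is presumably relevant for adapters that must draw fresh samples, such as SVI.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np


class AdapterInterfaceSketch(ABC):
    """Hypothetical abstract stand-in for the documented adapter interface."""

    def __init__(self, inference_obj, model, model_args, model_kwargs, sample_shape):
        self.inference_obj = inference_obj
        self.model = model
        self.model_args = () if model_args is None else tuple(model_args)
        self.model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
        self.sample_shape = tuple(sample_shape)

    @property
    @abstractmethod
    def sample_dims(self) -> list[str]:
        """Names of the leading sample dimensions."""

    @abstractmethod
    def get_samples(self, seed=None) -> dict[str, np.ndarray]:
        """Posterior samples keyed by variable name."""

    @abstractmethod
    def get_sample_stats(self, **kwargs) -> dict[str, np.ndarray]:
        """Sampler diagnostics keyed by statistic name."""


class FakeMCMC:
    """Hypothetical inference object holding draws shaped (chain, draw, ...)."""

    def __init__(self, draws, diverging):
        self.draws = draws
        self.diverging = diverging


class MCMCLikeAdapter(AdapterInterfaceSketch):
    @property
    def sample_dims(self):
        # Assumed names: one per entry of sample_shape (chains, then draws).
        return ["chain", "draw"]

    def get_samples(self, seed=None):
        # MCMC draws already exist, so the seed is ignored in this sketch.
        ndim = len(self.sample_shape)
        for name, values in self.inference_obj.draws.items():
            # Every variable must start with the declared sample dimensions.
            if values.shape[:ndim] != self.sample_shape:
                raise ValueError(
                    f"{name}: leading shape {values.shape[:ndim]} != {self.sample_shape}"
                )
        return self.inference_obj.draws

    def get_sample_stats(self, **kwargs):
        # Expose divergences, the MCMC example named in the method summary.
        return {"diverging": self.inference_obj.diverging}


rng = np.random.default_rng(0)
fake = FakeMCMC(
    draws={"mu": rng.normal(size=(2, 100)), "theta": rng.normal(size=(2, 100, 3))},
    diverging=np.zeros((2, 100), dtype=bool),
)
adapter = MCMCLikeAdapter(fake, None, None, None, (2, 100))
print(adapter.sample_dims)  # ['chain', 'draw']
print(adapter.get_samples()["theta"].shape)  # (2, 100, 3)
print(adapter.get_sample_stats()["diverging"].sum())  # 0
```

Checking the leading shape inside `get_samples` makes a mismatch between `sample_shape` and the stored draws fail early. An SVI-style subclass under the same assumptions would instead use `seed` to draw `sample_shape[0]` samples and report a single sample dimension.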