arviz_base.NumPyroInferenceAdapter


class arviz_base.NumPyroInferenceAdapter(inference_obj, model, model_args, model_kwargs, sample_shape)

Standardize methods across NumPyro inference objects for use with NumPyroConverter.

__init__(inference_obj, model, model_args, model_kwargs, sample_shape)

Initialize the adapter with common attributes for NumPyro inference objects.

This base class constructor sets up the shared infrastructure needed by all NumPyro inference adapters (MCMC, SVI, etc.) to provide a unified interface for the NumPyroConverter.

Parameters:
inference_obj : Any

The NumPyro inference object to adapt (e.g., MCMC, SVI, or other inference types).

model : callable

The NumPyro model function that was used for inference.

model_args : tuple, optional

Positional arguments passed to the model during inference. If None, defaults to an empty tuple.

model_kwargs : dict, optional

Keyword arguments passed to the model during inference. If None, defaults to an empty dict.

sample_shape : tuple of int

Shape of the leading (sample) dimensions of the values returned by get_samples(). For MCMC: (num_chains, num_draws). For SVI: (num_samples,).
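The defaulting rules for model_args and model_kwargs, and the two sample_shape conventions, can be illustrated with a minimal sketch. The class AdapterBaseSketch and the function toy_model below are hypothetical stand-ins that only mirror the documented parameters; they are not the arviz_base implementation, and the chain, draw, and sample counts are arbitrary.

```python
from typing import Any, Callable, Optional


class AdapterBaseSketch:
    """Hypothetical stand-in mirroring the documented constructor (not arviz_base code)."""

    def __init__(
        self,
        inference_obj: Any,
        model: Callable,
        model_args: Optional[tuple],
        model_kwargs: Optional[dict],
        sample_shape: tuple,
    ):
        self.inference_obj = inference_obj
        self.model = model
        # None falls back to an empty tuple / empty dict, as documented above.
        self.model_args = tuple(model_args) if model_args is not None else ()
        self.model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        # Store the leading sample dimensions as a tuple of ints.
        self.sample_shape = tuple(int(n) for n in sample_shape)


def toy_model(x, scale=1.0):
    """Hypothetical placeholder for a NumPyro model function."""
    return x * scale


# MCMC-style: 4 chains x 1000 draws; no model arguments were used.
mcmc_like = AdapterBaseSketch(object(), toy_model, None, None, (4, 1000))
# SVI-style: 1000 draws from the fitted guide, with model arguments.
svi_like = AdapterBaseSketch(object(), toy_model, (2.0,), {"scale": 0.5}, (1000,))
```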

Methods

__init__(inference_obj, model, model_args, ...)

Initialize the adapter with common attributes for NumPyro inference objects.

get_sample_stats(**kwargs)

Get sample stats from the inference object (e.g., divergences for MCMC).

get_samples([seed])

Get posterior samples from the inference object.
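Because every adapter exposes the same get_samples and get_sample_stats methods, a consumer such as NumPyroConverter can treat different inference results alike. The sketch below illustrates that duck-typed contract with a hypothetical ToyMCMCAdapter and a hypothetical describe helper. The return types (dicts of nested lists whose leading dimensions match sample_shape) and the "diverging" key are assumptions chosen for illustration, not the library's documented output.

```python
import random
from typing import Any, Dict, List, Optional


class ToyMCMCAdapter:
    """Hypothetical adapter exposing the documented method names (not arviz_base code)."""

    def __init__(self, num_chains: int, num_draws: int):
        self.sample_shape = (num_chains, num_draws)

    def get_samples(self, seed: Optional[int] = None) -> Dict[str, List[List[float]]]:
        # Assumption: variable name -> values with leading dims equal to sample_shape.
        rng = random.Random(seed)  # a fixed seed makes the draws reproducible
        chains, draws = self.sample_shape
        return {"mu": [[rng.gauss(0.0, 1.0) for _ in range(draws)] for _ in range(chains)]}

    def get_sample_stats(self, **kwargs: Any) -> Dict[str, List[List[bool]]]:
        # Assumption: per-draw diagnostics, e.g. a divergence flag for MCMC.
        chains, draws = self.sample_shape
        return {"diverging": [[False] * draws for _ in range(chains)]}


def describe(adapter) -> Dict[str, tuple]:
    """Hypothetical consumer that relies only on the shared adapter interface."""
    samples = adapter.get_samples(seed=0)
    stats = adapter.get_sample_stats()
    # Report the (chain, draw) extent of every returned variable.
    return {name: (len(vals), len(vals[0])) for name, vals in {**samples, **stats}.items()}


print(describe(ToyMCMCAdapter(2, 5)))
```

Any other adapter with the same two methods, for example one wrapping SVI output, could be passed to describe unchanged; only the shape of its values would differ.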

Attributes

sample_dims

Return the sample dimension names.
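sample_dims names the leading dimensions described by sample_shape. The following is a minimal sketch, assuming the names ("chain", "draw") for MCMC-shaped output and ("sample",) for SVI-shaped output. SampleDimsSketch is hypothetical, and the names actually returned by arviz_base may differ.

```python
from typing import Dict, Tuple


class SampleDimsSketch:
    """Hypothetical illustration of a sample_dims property (names are assumptions)."""

    def __init__(self, sample_shape: Tuple[int, ...]):
        self.sample_shape = tuple(sample_shape)

    @property
    def sample_dims(self) -> Tuple[str, ...]:
        # Assumed naming: ArviZ's usual ("chain", "draw") for 2-D MCMC output,
        # and a single "sample" dimension for 1-D SVI output.
        if len(self.sample_shape) == 2:
            return ("chain", "draw")
        if len(self.sample_shape) == 1:
            return ("sample",)
        raise ValueError(f"unsupported sample_shape: {self.sample_shape}")


def dim_sizes(adapter: SampleDimsSketch) -> Dict[str, int]:
    """Pair each sample dimension name with its size."""
    return dict(zip(adapter.sample_dims, adapter.sample_shape))


mcmc = SampleDimsSketch((4, 1000))
svi = SampleDimsSketch((1000,))
print(dim_sizes(mcmc), dim_sizes(svi))
```

Pairing sample_dims with sample_shape, as dim_sizes does, gives the labelled dimension sizes a converter would need when attaching coordinates to arrays.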