DIRICHLET_MULTINOM

Overview

The DIRICHLET_MULTINOM function computes statistical properties of the Dirichlet-multinomial distribution, a compound probability distribution that arises when category probabilities are uncertain. Also known as the Dirichlet compound multinomial (DCM) or multivariate Pólya distribution, it models scenarios where observations follow a multinomial distribution with probabilities drawn from a Dirichlet distribution.

This distribution is constructed by first drawing a probability vector \mathbf{p} from a Dirichlet distribution with concentration parameters \boldsymbol{\alpha} = (\alpha_1, \ldots, \alpha_K), then drawing counts from a multinomial distribution with n trials and probability vector \mathbf{p}. The probability mass function is:

P(\mathbf{x} \mid n, \boldsymbol{\alpha}) = \frac{\Gamma(\alpha_0) \Gamma(n+1)}{\Gamma(n + \alpha_0)} \prod_{k=1}^{K} \frac{\Gamma(x_k + \alpha_k)}{\Gamma(\alpha_k) \Gamma(x_k + 1)}

where \alpha_0 = \sum_{k=1}^{K} \alpha_k is the sum of concentration parameters, and \mathbf{x} = (x_1, \ldots, x_K) represents counts in each of K categories with \sum x_k = n.
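As a cross-check, the formula can be evaluated term by term in log space with Python's `math.lgamma`. This is a minimal sketch, not the function documented on this page. The helper name `dm_logpmf` is hypothetical, and it takes n = \sum x_k so the sum constraint holds by construction. It reproduces the values of Examples 1 and 2. It also shows why a log-PMF is worth having: for a very unlikely outcome the PMF underflows to 0.0 in double precision, while its logarithm stays finite.

```python
import math

def dm_logpmf(x, alpha):
    """Log-PMF of the Dirichlet-multinomial, term by term from the formula above.

    n is taken as sum(x), so the constraint sum(x_k) = n holds by construction.
    """
    n = sum(x)
    a0 = sum(alpha)
    # log[ Gamma(alpha_0) Gamma(n + 1) / Gamma(n + alpha_0) ]
    out = math.lgamma(a0) + math.lgamma(n + 1) - math.lgamma(n + a0)
    # log of the product over categories k = 1..K
    for xk, ak in zip(x, alpha):
        out += math.lgamma(xk + ak) - math.lgamma(ak) - math.lgamma(xk + 1)
    return out

lp = dm_logpmf([2, 3, 5], [1, 1, 1])
print(round(lp, 4), round(math.exp(lp), 4))  # -4.1897 0.0152

# A very unlikely outcome: the PMF underflows to 0.0 in double precision,
# while the log-PMF is still a finite (large negative) number.
lp_extreme = dm_logpmf([20000, 0], [500, 500])
print(math.isfinite(lp_extreme), math.exp(lp_extreme))  # True 0.0
```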

The expected value for category i is E(X_i) = n \alpha_i / \alpha_0, and the variance is:

\text{Var}(X_i) = n \frac{\alpha_i}{\alpha_0} \left(1 - \frac{\alpha_i}{\alpha_0}\right) \frac{n + \alpha_0}{1 + \alpha_0}
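These moments have closed forms, so they can be computed without any library. The sketch below uses plain Python and the helper name `dm_moments` is hypothetical. It also includes the off-diagonal covariance -n (\alpha_i \alpha_j / \alpha_0^2)(n + \alpha_0)/(1 + \alpha_0). With the parameters of Examples 3 to 5 it reproduces their results.

```python
def dm_moments(alpha, n):
    """Closed-form mean, variance and covariance matrix of the Dirichlet-multinomial."""
    a0 = sum(alpha)
    p = [a / a0 for a in alpha]        # expected category proportions alpha_k / alpha_0
    factor = (n + a0) / (1 + a0)       # overdispersion factor
    mean = [n * pk for pk in p]
    # Cov(X_i, X_j) = n * factor * (delta_ij * p_i - p_i * p_j); the diagonal is Var(X_i)
    cov = [[n * factor * ((pi if i == j else 0.0) - pi * pj) for j, pj in enumerate(p)]
           for i, pi in enumerate(p)]
    var = [cov[i][i] for i in range(len(p))]
    return mean, var, cov

mean, var, _ = dm_moments([2, 3, 5], 10)
print([round(v, 4) for v in mean])  # [2.0, 3.0, 5.0]
print([round(v, 4) for v in var])   # [2.9091, 3.8182, 4.5455]

_, _, cov = dm_moments([2, 3, 5], 1)  # n = 1, as in Example 5
print([[round(v, 4) for v in row] for row in cov])
# [[0.16, -0.06, -0.1], [-0.06, 0.21, -0.15], [-0.1, -0.15, 0.25]]
```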

The distribution is overdispersed relative to the multinomial: the variance is inflated by a factor of (n + \alpha_0)/(1 + \alpha_0), which exceeds 1 whenever n > 1. This makes it suitable for modeling count data with extra variability, such as word frequencies in documents or allele counts in population genetics. The total concentration \alpha_0 controls the degree of overdispersion. Smaller values produce greater variability, while larger values make the distribution approach a standard multinomial with probabilities \alpha_k / \alpha_0.
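The overdispersion can be observed by simulating the two-stage construction described above. This sketch uses only Python's `random` module. It assumes the standard construction of a Dirichlet draw, which normalizes independent Gamma(\alpha_k, 1) variates. The helper names are hypothetical. The sketch compares the simulated variance of X_1 for \boldsymbol{\alpha} = (2, 3, 5), n = 10 with the multinomial variance n p_1 (1 - p_1) and with the inflated Dirichlet-multinomial variance. The simulated figure varies with the seed.

```python
import random

def sample_dm(alpha, n, rng):
    """Draw one count vector: p ~ Dirichlet(alpha), then counts ~ Multinomial(n, p)."""
    # Dirichlet draw: independent Gamma(alpha_k, 1) variates, normalized to sum to 1
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    p = [gk / total for gk in g]
    # Multinomial draw: n independent categorical trials with probabilities p
    counts = [0] * len(alpha)
    for k in rng.choices(range(len(alpha)), weights=p, k=n):
        counts[k] += 1
    return counts

def sample_variance(values):
    """Unbiased sample variance."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - 1)

rng = random.Random(0)
alpha, n = [2, 3, 5], 10
a0 = sum(alpha)
x1 = [sample_dm(alpha, n, rng)[0] for _ in range(20_000)]

p1 = alpha[0] / a0
multinomial_var = n * p1 * (1 - p1)              # variance without overdispersion
dm_var = multinomial_var * (n + a0) / (1 + a0)   # inflated variance from the formula
empirical_var = sample_variance(x1)
print(f"multinomial {multinomial_var:.4f}, DM formula {dm_var:.4f}, simulated {empirical_var:.4f}")
```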

This implementation uses the scipy.stats.dirichlet_multinomial distribution, which is available in SciPy 1.11 and later. It supports computing the PMF, log-PMF, mean, variance, and covariance matrix. For additional theoretical background, see the Wikipedia article on the Dirichlet-multinomial distribution.

This example function is provided as-is without any representation of accuracy.

Excel Usage

=DIRICHLET_MULTINOM(x, alpha, n, dm_method)
  • x (list[list], optional, default: null): 2D list of integer counts for each category. Each row must sum to the corresponding value of n. Required for pmf and logpmf methods.
  • alpha (list[list], optional, default: null): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution. Required for all methods.
  • n (list[list], optional, default: null): 2D list containing the number of trials for each distribution. Each row contains one positive integer. Required for all methods except cov, which uses n = 1 when n is omitted.
  • dm_method (str, optional, default: “pmf”): Computation method to use: “pmf”, “logpmf”, “mean”, “var”, or “cov”.

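Returns (list[list]): 2D list of results with one row per row of alpha (for cov, one K-by-K covariance block per row of alpha, stacked vertically), or an error message string.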
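In Excel, an array constant such as {1,1,1;2,3,5} uses commas to separate columns and semicolons to separate rows. Each row of alpha, together with the matching rows of x and n, describes one distribution. The helper below is hypothetical and not part of the add-in. It sketches how such a constant maps onto the 2D-list layout described above, assuming en-US list separators.

```python
def excel_array_to_2d(constant):
    """Hypothetical helper: turn an Excel array constant like "{1,1,1;2,3,5}"
    into a list of rows. Commas separate columns, semicolons separate rows
    (en-US separators are assumed)."""
    body = constant.strip().lstrip("{").rstrip("}")
    return [[float(cell) for cell in row.split(",")] for row in body.split(";")]

print(excel_array_to_2d("{2,3,5}"))        # [[2.0, 3.0, 5.0]]
print(excel_array_to_2d("{1,1,1;2,3,5}"))  # [[1.0, 1.0, 1.0], [2.0, 3.0, 5.0]]
print(excel_array_to_2d("{10;10}"))        # [[10.0], [10.0]]
```

For example, with alpha = {1,1,1;2,3,5} and n = {10;10}, the mean method returns two rows of three values, and cov returns two stacked 3-by-3 blocks.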

Examples

Example 1: Basic PMF calculation with uniform concentration

Inputs:

  • x: 2 3 5
  • alpha: 1 1 1
  • n: 10
  • dm_method: pmf

Excel formula:

=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "pmf")

Expected output:

Result
0.0152
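Example 1 has a closed form. When every \alpha_k = 1, each factor in the product equals 1, and the PMF reduces to \Gamma(K)\Gamma(n+1)/\Gamma(n+K) = 1/\binom{n+K-1}{K-1}. Every count vector that sums to n is therefore equally likely. For K = 3 and n = 10 there are 66 such vectors, so the probability is 1/66 \approx 0.0152. Here is a quick check with Python's `math` module:

```python
from math import comb, factorial

n, K = 10, 3
# Number of count vectors (x_1, ..., x_K) of non-negative integers summing to n
num_outcomes = comb(n + K - 1, K - 1)
# PMF with all alpha_k = 1: only Gamma(K) Gamma(n + 1) / Gamma(n + K) remains
pmf_uniform = factorial(K - 1) * factorial(n) / factorial(n + K - 1)
print(num_outcomes, round(pmf_uniform, 4))  # 66 0.0152
```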

Example 2: Log-PMF calculation for same distribution

Inputs:

  • x: 2 3 5
  • alpha: 1 1 1
  • n: 10
  • dm_method: logpmf

Excel formula:

=DIRICHLET_MULTINOM({2,3,5}, {1,1,1}, {10}, "logpmf")

Expected output:

Result
-4.1897

Example 3: Expected mean counts for weighted concentration

Inputs:

  • alpha: 2 3 5
  • n: 10
  • dm_method: mean

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "mean")

Expected output:

Result
2 3 5

Example 4: Variance for weighted concentration

Inputs:

  • alpha: 2 3 5
  • n: 10
  • dm_method: var

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, {10}, "var")

Expected output:

Result
2.9091 3.8182 4.5455

Example 5: Covariance matrix for three categories

Inputs:

  • alpha: 2 3 5
  • dm_method: cov

Excel formula:

=DIRICHLET_MULTINOM(, {2,3,5}, , "cov")

Expected output:

Result
0.16 -0.06 -0.1
-0.06 0.21 -0.15
-0.1 -0.15 0.25
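Example 5 omits n, so the function falls back to n = 1 for cov (see the n handling in the Python code). The full covariance has the variance formula from the overview on its diagonal. With n = 1 the overdispersion factor (n + \alpha_0)/(1 + \alpha_0) equals 1, so the matrix reduces to the single-trial form:

```latex
\text{Cov}(X_i, X_j) = n \, \frac{n + \alpha_0}{1 + \alpha_0} \left( \delta_{ij} \frac{\alpha_i}{\alpha_0} - \frac{\alpha_i \alpha_j}{\alpha_0^2} \right)
\qquad\Longrightarrow\qquad
\text{Cov}(X_i, X_j)\big|_{n=1} = \delta_{ij}\, p_i - p_i p_j, \quad p_i = \frac{\alpha_i}{\alpha_0}
```

With p = (0.2, 0.3, 0.5), this gives Var(X_1) = 0.2 - 0.04 = 0.16 and Cov(X_1, X_2) = -0.2 \times 0.3 = -0.06, matching the first row above.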

Python Code

from scipy.stats import dirichlet_multinomial as scipy_dirichlet_multinomial

def dirichlet_multinom(x=None, alpha=None, n=None, dm_method='pmf'):
    """
    Computes the probability mass function, log probability mass function, mean, variance, or covariance of the Dirichlet multinomial distribution.

    See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet_multinomial.html

    This example function is provided as-is without any representation of accuracy.

    Args:
        x (list[list], optional): 2D list of integer counts for each category; each row must sum to the corresponding n. Required for pmf and logpmf methods. Default is None.
        alpha (list[list], optional): 2D list of concentration parameters (positive floats). Each row represents parameters for one distribution. Required for all methods. Default is None.
        n (list[list], optional): 2D list containing the number of trials for each distribution. Each row contains one positive integer. Required for all methods except cov, which uses n=1 when omitted. Default is None.
        dm_method (str, optional): Computation method to use. Valid options: 'pmf', 'logpmf', 'mean', 'var', 'cov'. Default is 'pmf'.

    Returns:
        list[list]: 2D list of results (one row per row of alpha; for cov, stacked K-by-K blocks), or an error message string.
    """
    def to2d(val):
      if val is None:
        return None
      return [[val]] if not isinstance(val, list) else val

    def to_float_list(arr):
      if hasattr(arr, 'tolist'):
        arr = arr.tolist()
      if isinstance(arr, (float, int)):
        return [float(arr)]
      return [float(v) for v in arr]

    try:
      valid_methods = {'pmf', 'logpmf', 'mean', 'var', 'cov'}
      if dm_method not in valid_methods:
        return f"Error: Invalid method '{dm_method}'. Must be one of {sorted(valid_methods)}."

      if alpha is None:
        return "Error: Invalid input: alpha is required."
      alpha = to2d(alpha)
      if not isinstance(alpha, list) or not all(isinstance(row, list) and len(row) > 0 for row in alpha):
        return "Error: alpha must be a 2D list of positive floats."
      if len(alpha) < 1:
        return "Error: alpha must have at least one row."

      try:
        alpha = [[float(v) for v in row] for row in alpha]
      except (TypeError, ValueError):
        return "Error: alpha must be a 2D list of positive floats."
      if any(any(v <= 0 for v in row) for row in alpha):
        return "Error: alpha must be a 2D list of positive floats."

      # n is required for pmf/logpmf/mean/var; for cov it defaults to n=1 when omitted
      if n is None:
        if dm_method != 'cov':
          return "Error: Invalid input: n is required."
      else:
        n = to2d(n)
        if not isinstance(n, list) or len(n) != len(alpha):
          return "Error: n must be a 2D list with the same number of rows as alpha."
        for n_row in n:
          if not isinstance(n_row, list) or len(n_row) != 1:
            return "Error: Each row of n must contain exactly one integer."
        try:
          n = [[int(val[0])] for val in n]
        except (TypeError, ValueError):
          return "Error: n must contain integers."
        # SciPy's dirichlet_multinomial rejects n <= 0, so report that clearly here
        if any(val[0] < 1 for val in n):
          return "Error: n must contain positive integers."

      if dm_method in {'pmf', 'logpmf'}:
        if x is None:
          return "Error: Invalid input: x is required for pmf/logpmf."
        x = to2d(x)
        if not isinstance(x, list) or len(x) != len(alpha):
          return "Error: x must be a 2D list with the same number of rows as alpha."
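        for row, alpha_row in zip(x, alpha):
          # Compare each x row with its own alpha row (alpha rows may differ in length)
          if not isinstance(row, list) or len(row) != len(alpha_row):
            return "Error: Each row of x must have the same length as the corresponding alpha row."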
          try:
            if any(int(val) < 0 for val in row):
              return "Error: x must contain non-negative integers."
          except (TypeError, ValueError):
            return "Error: x must contain integers."

      results = []
      for i, alpha_row in enumerate(alpha):
        try:
          if dm_method == 'cov':
            n_val = 1 if n is None else n[i][0]
          else:
            n_val = n[i][0]

          if dm_method in {'pmf', 'logpmf'}:
            row_sum = sum(int(v) for v in x[i])
            if row_sum != n_val:
              return "Error: Invalid input: each row of x must sum to n."

          dist = scipy_dirichlet_multinomial(alpha=alpha_row, n=n_val)

          if dm_method == 'pmf':
            res = dist.pmf(x[i])
          elif dm_method == 'logpmf':
            res = dist.logpmf(x[i])
          elif dm_method == 'mean':
            res = dist.mean()
          elif dm_method == 'var':
            res = dist.var()
          elif dm_method == 'cov':
            res = dist.cov()

          if dm_method == 'cov':
            cov_matrix = res.tolist() if hasattr(res, 'tolist') else res
            for row in cov_matrix:
              results.append([float(val) for val in row])
          else:
            results.append(to_float_list(res))
        except Exception as e:
          return f"Error: computing {dm_method}: {e}"

      return results
    except Exception as e:
      return f"Error: {str(e)}"
