
Custom Aggregators

The 14 built-in aggregators cover most reporting needs, but eventually you'll want something that's not in the box: skewness for a tail-risk study, a portfolio Sharpe ratio, a weighted TVaR your regulator hasn't seen before. The aggregator layer is a plugin path. Register your own class, use it in a ScenarioRun exactly like a built-in, and it survives a YAML round-trip the same way the built-ins do.

You write five methods, one of which is canonical_form(). The framework handles batching, partitioning, merging across processes, and serialisation.


When to write one

Reach for a custom aggregator when:

  • The metric you need is mergeable across batches but isn't built in
  • You want it to round-trip through YAML governance like the built-ins
  • You're using it in more than one plan and want the metric to survive a code review

For a one-shot exploratory metric, the .of(pl.Expr) escape hatch on an existing aggregator is usually faster; see the modifiers section on the aggregators page.


The contract

A custom aggregator implements five hooks: four that manage the accumulator, plus canonical_form. Inherit from BaseAggregator and you get the column and within fields, plus the alias(), over(), and of() modifiers, for free.

from dataclasses import dataclass
from typing import Any
from gaspatchio_core.scenarios import BaseAggregator, scenario_aggregator


@scenario_aggregator("Skewness")
@dataclass(frozen=True)
class Skewness(BaseAggregator):
    """Skewness across scenarios. Welford-Chan parallel-merge."""

    def create_accumulator(self) -> dict[str, float]:
        return {"n": 0.0, "mean": 0.0, "m2": 0.0, "m3": 0.0}

    def add_input(self, state, value):
        v = float(value) if value is not None else 0.0
        n1 = state["n"] + 1.0
        delta = v - state["mean"]
        delta_n = delta / n1
        term1 = delta * delta_n * state["n"]
        new_mean = state["mean"] + delta_n
        new_m3 = state["m3"] + term1 * delta_n * (n1 - 2.0) - 3.0 * delta_n * state["m2"]
        new_m2 = state["m2"] + term1
        return {"n": n1, "mean": new_mean, "m2": new_m2, "m3": new_m3}

    def merge_accumulators(self, a, b):
        na, nb = a["n"], b["n"]
        if na == 0:
            return b
        if nb == 0:
            return a
        n = na + nb
        delta = b["mean"] - a["mean"]
        mean = (na * a["mean"] + nb * b["mean"]) / n
        m2 = a["m2"] + b["m2"] + delta * delta * na * nb / n
        m3 = (
            a["m3"] + b["m3"]
            + delta ** 3 * na * nb * (na - nb) / (n * n)
            + 3.0 * delta * (na * b["m2"] - nb * a["m2"]) / n
        )
        return {"n": n, "mean": mean, "m2": m2, "m3": m3}

    def extract_output(self, state):
        n, m2, m3 = state["n"], state["m2"], state["m3"]
        if n < 3 or m2 == 0.0:
            return float("nan")
        std = (m2 / n) ** 0.5
        return (m3 / n) / (std ** 3)

    def canonical_form(self) -> dict[str, Any]:
        return {"kind": "Skewness", "column": self.column, "within": self.within}

The five hooks, in the order the framework calls them:

| Hook               | Receives                      | Returns                   | Run when                               |
|--------------------|-------------------------------|---------------------------|----------------------------------------|
| create_accumulator | nothing                       | a fresh accumulator state | once per scenario / partition          |
| add_input          | state, one value per scenario | new state                 | once per scenario contribution         |
| merge_accumulators | two states                    | one merged state          | when batches merge                     |
| extract_output     | final state                   | a JSON-serialisable value | once at the end                        |
| canonical_form     | nothing                       | a recipe dict             | for SHA hashing and YAML serialisation |

The state can be any Python object — a dict, a tuple, a custom class. It travels in memory only, never to disk. What goes to disk is canonical_form() and extract_output().
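To make the call order concrete, here is a toy, framework-free driver. It is a minimal sketch, not the gaspatchio internals: `Mean`, `run_batched`, and the batch layout (one fresh accumulator per batch of scenarios, one add_input per scenario, a merge across batches, one extract at the end) are hypothetical illustrations of the contract in the table above.

```python
from functools import reduce


class Mean:
    """Toy aggregator implementing the four accumulator hooks."""

    def create_accumulator(self):
        return (0, 0.0)  # (count, running sum)

    def add_input(self, state, value):
        n, total = state
        return (n + 1, total + float(value))

    def merge_accumulators(self, a, b):
        return (a[0] + b[0], a[1] + b[1])

    def extract_output(self, state):
        n, total = state
        return total / n if n else float("nan")


def run_batched(agg, per_scenario_values, batch_size):
    """Hypothetical driver: fold each batch, then merge the batch states."""
    states = []
    for start in range(0, len(per_scenario_values), batch_size):
        state = agg.create_accumulator()  # assumed: once per batch
        for v in per_scenario_values[start:start + batch_size]:
            state = agg.add_input(state, v)  # once per scenario
        states.append(state)
    merged = reduce(agg.merge_accumulators, states, agg.create_accumulator())
    return agg.extract_output(merged)  # once at the end


values = [10.0, 20.0, 30.0, 40.0, 50.0]
print(run_batched(Mean(), values, batch_size=2))  # 30.0
```

The batch size changes how many states exist and how often merge_accumulators runs, but not the answer. That independence is exactly what the merge hook has to guarantee.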


Use it like a built-in

Once the @scenario_aggregator("Skewness") decorator has run, which happens when the module defining the class is imported, Skewness("loss").alias("skew") is just another aggregator.

import polars as pl
from gaspatchio_core.scenarios import Sum, for_each_scenario


def stressed(af, *, tables=None, drivers=None):
    sid = pl.col("scenario_id")
    factor = (
        pl.when(sid == "BASE").then(1.0)
          .when(sid == "MILD").then(1.5)
          .when(sid == "MEDIUM").then(2.0)
          .when(sid == "SEVERE").then(3.0)
          .otherwise(5.0)
    )
    return af.with_columns((af["premium"] * factor).alias("loss"))


result = for_each_scenario(
    policies(),  # your policy data (helper not shown here)
    scenarios=["BASE", "MILD", "MEDIUM", "SEVERE", "CATASTROPHIC"],
    model_fn=stressed,
    aggregations=(
        Sum("loss").alias("total"),
        Skewness("loss").alias("skew"),
    ),
)
print(result.aggregations["total"])  # 27500.0
print(result.aggregations["skew"])   # ≈ 0.7955

A right-skewed stress distribution — the catastrophic scenario dominates the tail.
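The printed skew can be cross-checked without the framework. With the default sum reduction, each scenario's value is the portfolio premium total times its stress factor, and skewness is scale-invariant, so the factors alone determine the answer. A direct two-pass computation using the same population definition as extract_output, (m3 / n) / (m2 / n) ** 1.5, is enough:

```python
def population_skewness(xs):
    """Two-pass population skewness: (m3 / n) / (m2 / n) ** 1.5."""
    n = len(xs)
    mean = sum(xs) / n
    m2 = sum((x - mean) ** 2 for x in xs)
    m3 = sum((x - mean) ** 3 for x in xs)
    return (m3 / n) / (m2 / n) ** 1.5


# Stress factors for BASE, MILD, MEDIUM, SEVERE, CATASTROPHIC
factors = [1.0, 1.5, 2.0, 3.0, 5.0]
print(round(population_skewness(factors), 4))  # 0.7955
```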


Sketch-backed custom aggregators

Custom metrics that need a tail-quantile or a CTE can reuse the same SignedSketch the built-in Quantile / CTE use. The merge is bit-exact across batches and processes.
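Why can a sketch merge be bit-exact? In a DDSketch-style sketch the state is a map from logarithmic bucket index to an integer count, and merging just adds counts, so merging two folded halves gives the same state as folding everything at once. The toy below is not SignedSketch, whose internals this page does not document. It handles positive values only and exists purely to illustrate the idea.

```python
import math
from collections import Counter


class LogBucketSketch:
    """Toy DDSketch-style sketch for positive values (illustration only)."""

    def __init__(self, relative_accuracy=0.01):
        self.alpha = relative_accuracy
        self.gamma = (1 + relative_accuracy) / (1 - relative_accuracy)
        self.counts = Counter()  # bucket index -> integer count

    def add(self, x):
        if x <= 0:
            raise ValueError("this toy sketch only handles positive values")
        # Bucket i covers roughly the interval (gamma**(i - 1), gamma**i]
        self.counts[math.ceil(math.log(x, self.gamma))] += 1

    @staticmethod
    def merge(a, b):
        out = LogBucketSketch(a.alpha)
        out.counts = a.counts + b.counts  # integer adds: exact, order-independent
        return out

    def quantile(self, q):
        rank = q * (sum(self.counts.values()) - 1)
        seen = 0
        for i in sorted(self.counts):
            seen += self.counts[i]
            if seen > rank:
                # Representative value within relative_accuracy of every value in bucket i
                return 2 * self.gamma ** i / (self.gamma + 1)
        raise ValueError("empty sketch")


def fold(values, relative_accuracy=0.01):
    sketch = LogBucketSketch(relative_accuracy)
    for v in values:
        sketch.add(v)
    return sketch


A = [1.0, 2.5, 3.7, 10.0]
B = [0.5, 42.0, 7.7]
merged = LogBucketSketch.merge(fold(A), fold(B))
print(merged.counts == fold(A + B).counts)  # True
```

Accuracy comes from bucket width (the quantile estimate is within relative_accuracy of the true value), while exactness of the merge comes from the state being integer counts.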

from gaspatchio_core.scenarios._sketch import SignedSketch


@scenario_aggregator("TVaR95")
@dataclass(frozen=True)
class TVaR95(BaseAggregator):
    """Tail Value-at-Risk at 95% — DDSketch-backed mergeable."""

    relative_accuracy: float = 1e-4

    def create_accumulator(self) -> SignedSketch:
        return SignedSketch(relative_accuracy=self.relative_accuracy)

    def add_input(self, state, value):
        state.add(float(value))
        return state

    def merge_accumulators(self, a, b):
        return SignedSketch.merge(a, b)

    def extract_output(self, state):
        return state.cte(level=0.05, direction="upper")

    def canonical_form(self) -> dict[str, Any]:
        return {
            "kind": "TVaR95",
            "column": self.column,
            "within": self.within,
            "relative_accuracy": self.relative_accuracy,
        }

Validating against the built-in CTE:

from gaspatchio_core.scenarios import CTE

result = for_each_scenario(
    policies(),
    scenarios=[f"S{i:03d}" for i in range(200)],
    model_fn=stochastic_model,  # your stochastic model_fn (not shown)
    aggregations=(
        TVaR95("loss").alias("tvar_custom"),
        CTE("loss", level=0.05, direction="upper").alias("tvar_builtin"),
    ),
)
print(result.aggregations["tvar_custom"])    # 8602.42
print(result.aggregations["tvar_builtin"])   # 8602.42
print(result.aggregations["tvar_custom"] == result.aggregations["tvar_builtin"])
# True

Bit-exact match. Both use the same sketch with the same relative_accuracy; both add the same values in the same order; the final CTE call is the same. The custom path is identical to the built-in for this case — useful as a sanity check before you write something more exotic.
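Matching the built-in checks the plumbing; comparing against an exact, unsketched tail mean on the same per-scenario values checks the number itself. This is a minimal sketch assuming that CTE at level 0.05, direction "upper" means the mean of the largest ceil(0.05 × n) values. That is one common convention, and the framework's boundary handling may differ, so expect close agreement rather than bit-exact equality.

```python
import math


def exact_upper_cte(values, level=0.05):
    """Mean of the largest ceil(level * n) values: one common upper-tail CTE convention."""
    if not values:
        raise ValueError("no values")
    # round() first so float noise such as 10.000000000000002 can't bump ceil() up by one
    k = max(1, math.ceil(round(level * len(values), 9)))
    tail = sorted(values, reverse=True)[:k]
    return sum(tail) / k


# 200 illustrative per-scenario losses: 1.0, 2.0, ..., 200.0
losses = [float(i) for i in range(1, 201)]
print(exact_upper_cte(losses))  # 195.5 (mean of 191.0 to 200.0)
```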


Test the merge

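The framework relies on one identity, where fold(xs) starts from create_accumulator() and calls add_input once per value, and ++ is list concatenation:

extract(fold(A ++ B)) == extract(merge(fold(A), fold(B)))

Every built-in aggregator satisfies it, scalar or partitioned, across every batch boundary. Your custom aggregator must satisfy it too, or its result will depend on how scenarios were batched. Pin it with a test in your suite: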

def fold(values, agg):
    state = agg.create_accumulator()
    for v in values:
        state = agg.add_input(state, v)
    return state


def test_skewness_merge_matches_concatenated_fold():
    agg = Skewness("x")
    A = [1.0, 2.0, 4.0, 7.0, 11.0]
    B = [16.0, 22.0, 29.0]

    extract_concat = agg.extract_output(fold(A + B, agg))
    extract_merge  = agg.extract_output(
        agg.merge_accumulators(fold(A, agg), fold(B, agg))
    )

    assert abs(extract_concat - extract_merge) < 1e-9

For sketch-backed aggregators, compare extracted values rather than internal state. The merge is bit-exact at the bucket level, but floating-point fields carried in the state can differ by 1 ULP depending on merge order, so a direct state comparison can fail even when the extracted outputs agree.
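A single split point can miss edge cases such as an empty side or a one-element batch. A minimal randomized variant, standard library only, checks every split point of many random datasets against the unsplit fold. To stay self-contained it re-implements the Skewness hooks as plain functions; in your own suite you would call the real Skewness("x") methods instead.

```python
import math
import random


def create():
    return {"n": 0.0, "mean": 0.0, "m2": 0.0, "m3": 0.0}


def add(s, v):
    # Same online update as Skewness.add_input
    n1 = s["n"] + 1.0
    delta = v - s["mean"]
    delta_n = delta / n1
    term1 = delta * delta_n * s["n"]
    m3 = s["m3"] + term1 * delta_n * (n1 - 2.0) - 3.0 * delta_n * s["m2"]
    return {"n": n1, "mean": s["mean"] + delta_n, "m2": s["m2"] + term1, "m3": m3}


def merge(a, b):
    # Same pairwise merge as Skewness.merge_accumulators
    na, nb = a["n"], b["n"]
    if na == 0:
        return b
    if nb == 0:
        return a
    n = na + nb
    delta = b["mean"] - a["mean"]
    return {
        "n": n,
        "mean": (na * a["mean"] + nb * b["mean"]) / n,
        "m2": a["m2"] + b["m2"] + delta * delta * na * nb / n,
        "m3": a["m3"] + b["m3"]
        + delta ** 3 * na * nb * (na - nb) / (n * n)
        + 3.0 * delta * (na * b["m2"] - nb * a["m2"]) / n,
    }


def extract(s):
    if s["n"] < 3 or s["m2"] == 0.0:
        return float("nan")
    return (s["m3"] / s["n"]) / (s["m2"] / s["n"]) ** 1.5


def fold(values):
    s = create()
    for v in values:
        s = add(s, v)
    return s


def check_every_split(values):
    expected = extract(fold(values))
    for cut in range(len(values) + 1):  # includes an empty left and an empty right side
        got = extract(merge(fold(values[:cut]), fold(values[cut:])))
        assert math.isclose(got, expected, rel_tol=1e-9, abs_tol=1e-9), (cut, got, expected)


rng = random.Random(0)  # fixed seed keeps the test reproducible
for _ in range(50):
    data = [rng.lognormvariate(0.0, 1.0) for _ in range(rng.randint(3, 40))]
    check_every_split(data)
print("ok")
```

If you already use a property-testing library, the same check_every_split body drops straight into a generated-input test.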


YAML round-trip

canonical_form() is what gets serialised to YAML and rehydrated on reload. Two rules:

  1. The kind field must match the name you registered with (@scenario_aggregator("Skewness")).
  2. Every other field except alias must match a constructor parameter name on your class.

The reload path is cls(**{k: v for k, v in recipe.items() if k not in {"kind", "alias"}}). If your canonical_form() emits relative_accuracy=1e-4, your constructor must accept relative_accuracy — which the @dataclass declaration handles automatically.
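Both rules can be seen in a framework-free sketch. REGISTRY, register, TVaR95Recipe, and from_recipe are hypothetical stand-ins for what @scenario_aggregator and the reload path do, and the within default of "sum" is assumed. The rehydration line is the reload expression quoted above.

```python
from dataclasses import dataclass
from typing import Any

REGISTRY: dict[str, type] = {}


def register(kind):
    """Hypothetical stand-in for @scenario_aggregator: map kind -> class."""
    def wrap(cls):
        REGISTRY[kind] = cls
        return cls
    return wrap


@register("TVaR95")
@dataclass(frozen=True)
class TVaR95Recipe:
    column: str
    within: str = "sum"  # assumed default, for illustration
    relative_accuracy: float = 1e-4

    def canonical_form(self) -> dict[str, Any]:
        return {
            "kind": "TVaR95",
            "column": self.column,
            "within": self.within,
            "relative_accuracy": self.relative_accuracy,
        }


def from_recipe(recipe: dict[str, Any]):
    cls = REGISTRY[recipe["kind"]]  # rule 1: KeyError if kind is unknown or unregistered
    # rule 2: TypeError if a field has no matching constructor parameter
    return cls(**{k: v for k, v in recipe.items() if k not in {"kind", "alias"}})


original = TVaR95Recipe("loss", relative_accuracy=1e-3)
recipe = original.canonical_form()
print(from_recipe({**recipe, "alias": "tvar"}) == original)  # True
```

Because the class is a frozen dataclass, equality compares field values, so a successful round-trip yields an instance equal to the original.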

Custom aggregators that use the .of(pl.Expr) escape hatch do not survive a YAML round-trip, because the polars expression isn't serialised into the recipe. If you reload such a plan, the framework raises a clear error rather than rebuilding with the wrong expression.


Cross-process governance

For an audit handoff to work — counterparty saves YAML, you reload in a fresh interpreter and reproduce — your aggregator module must be imported in the fresh process before ScenarioRun.from_yaml runs. Otherwise the registry doesn't have a Skewness class to look up.

The current pattern is explicit: ship the module path with the YAML and have the auditor import it manually. A future YAML plugins: key will let plans self-describe their plugin dependencies; until then, document the imports alongside the plan.
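Until then, one way to make the manual step harder to forget is a short preamble in the auditor's script that imports every listed plugin module before the plan is loaded. This is a sketch: import_plugins and a plugin list shipped next to the YAML are hypothetical. Only the requirement that the modules be imported before ScenarioRun.from_yaml comes from this page.

```python
import importlib
from types import ModuleType


def import_plugins(module_paths: list[str]) -> list[ModuleType]:
    """Import each plugin module so its decorators register the aggregator classes."""
    modules = []
    for path in module_paths:
        path = path.strip()
        if path and not path.startswith("#"):  # skip blank lines and comments
            modules.append(importlib.import_module(path))
    return modules


# e.g. lines read from a plugins.txt shipped next to the plan YAML (hypothetical file);
# stdlib modules stand in for your aggregator modules here
loaded = import_plugins(["json", "  ", "# comment", "statistics"])
print([m.__name__ for m in loaded])  # ['json', 'statistics']
# ...and only then load the plan with ScenarioRun.from_yaml
```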


Cross-join semantics

When you write add_input, your aggregator receives one value per scenario, already reduced inside the scenario according to the aggregator's within setting (sum by default).

def add_input(self, state, value):
    # `value` here is one number per scenario.
    # If you ran 100 scenarios across 1,000 policies, this is called 100 times
    # (or 100 × number-of-partitions, with .over).
    ...

What add_input does not receive is a row per policy. The per-scenario projection — a frame with one row per (policy_id, scenario_id) — is reduced inside the scenario by the within parameter (sum, mean, max, …) before it reaches your aggregator. This means:

  • You don't write a Python loop over rows. The within-reduction is a polars expression and runs at speed.
  • For an aggregator that genuinely needs per-policy data inside a single scenario (rare), use Skewness.of(pl.col("loss")) to override the within-reduction with a raw expression.

The framework cross-joins your ActuarialFrame with the batch's scenario_id column before model_fn runs, so every scenario sees every policy — that's where model_fn produces the columns your aggregator reads. The aggregator itself never sees individual policies.
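The data flow can be mimicked in plain Python. In this minimal sketch, a toy policy list is cross-joined with the scenario ids, a stand-in model_fn computes loss per (policy_id, scenario_id) row, the default sum reduces each scenario to one value, and only those per-scenario values would reach add_input. The policy data and factors are illustrative assumptions (the factors mirror the stressed example above).

```python
from collections import defaultdict
from itertools import product

policies = [{"policy_id": 1, "premium": 100.0}, {"policy_id": 2, "premium": 250.0}]
scenario_ids = ["BASE", "MILD", "SEVERE"]
factor = {"BASE": 1.0, "MILD": 1.5, "SEVERE": 3.0}

# 1. Cross-join: every scenario sees every policy (one row per policy_id, scenario_id)
rows = [
    {"policy_id": p["policy_id"], "scenario_id": sid, "premium": p["premium"]}
    for p, sid in product(policies, scenario_ids)
]

# 2. Stand-in model_fn: produce the column the aggregator will read
for row in rows:
    row["loss"] = row["premium"] * factor[row["scenario_id"]]

# 3. Within-reduction (sum by default): collapse each scenario to one value
per_scenario = defaultdict(float)
for row in rows:
    per_scenario[row["scenario_id"]] += row["loss"]

# 4. These are the values add_input sees: one per scenario, not one per row
calls = [per_scenario[sid] for sid in scenario_ids]

print(len(rows), len(calls))  # 6 3
print(calls)                  # [350.0, 525.0, 1050.0]
```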


requires_scenario_id — when you need the scenario name

The built-in ArgMin / ArgMax are special: they need to remember which scenario produced the extreme value, not just the value. A custom aggregator that needs the same thing opts in via a class attribute:

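from typing import Any, ClassVar


@scenario_aggregator("ArgWorstLossRatio")
@dataclass(frozen=True)
class ArgWorstLossRatio(BaseAggregator):
    """Scenario_id of the worst (highest) loss ratio observed."""

    requires_scenario_id: ClassVar[bool] = True

    def create_accumulator(self):
        return None  # no scenario seen yet

    def add_input(self, state, value):
        # `value` is now a (scenario_id, scalar) tuple instead of just a scalar
        sid, ratio = value
        return self.merge_accumulators(state, (sid, float(ratio)))

    def merge_accumulators(self, a, b):
        if a is None or b is None:
            return b if a is None else a
        # Highest ratio wins; ties go to the smaller scenario_id so the
        # result does not depend on how scenarios were batched.
        return min(a, b, key=lambda s: (-s[1], s[0]))

    def extract_output(self, state):
        return None if state is None else state[0]

    def canonical_form(self) -> dict[str, Any]:
        return {"kind": "ArgWorstLossRatio", "column": self.column, "within": self.within}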

With requires_scenario_id = True, the framework packs (scenario_id, value) into add_input instead of passing the bare value. Set it on classes that need to return a scenario identity rather than a scalar.
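Scenario-identity aggregators have one extra trap: ties. If two scenarios share the extreme value, a merge that simply keeps the left-hand state returns whichever scenario came first, so the answer can change with the batch layout and break the merge identity from the Test the merge section. A minimal sketch with (scenario_id, ratio) tuples; the tie-break rule (smaller scenario_id wins) is an assumption, and any deterministic total order works.

```python
from functools import reduce


def naive_merge(a, b):
    # Keeps the left-hand state on ties: result depends on merge order
    return b if b[1] > a[1] else a


def stable_merge(a, b):
    # Highest ratio wins; ties go to the smaller scenario_id
    return min(a, b, key=lambda s: (-s[1], s[0]))


states = [("S002", 1.25), ("S001", 1.25), ("S003", 0.90)]
backward = list(reversed(states))

print(reduce(naive_merge, states)[0], reduce(naive_merge, backward)[0])    # S002 S001
print(reduce(stable_merge, states)[0], reduce(stable_merge, backward)[0])  # S001 S001
```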


Where this lands in the bigger story

A custom aggregator participates in everything you've seen on the previous pages: the Sum / CTE mix on the aggregators page, the ScenarioRun plan layer, the audit chain, the YAML round-trip. There's no separate "custom path" — it's the same path the built-ins use.