What are Variables?

SRMs operate on Variables -- pieces of a sample that together fully define its state. Variables do not have to directly partition the data; they can also be some latent encoding of the sample. The only constraint is that the shape of the Variables remains constant throughout training. For example, an image could be split into patches, with each patch being one variable.
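
As a concrete sketch of the shape contract (image patching is just one illustrative choice, not a library API): a batch of 32x32 RGB images could be split into 16 non-overlapping 8x8 patches, giving num_variables = 16 and num_features = 8 * 8 * 3 = 192:

import torch

images = torch.randn(8, 3, 32, 32)                # Unstructured batch
patches = images.unfold(2, 8, 8).unfold(3, 8, 8)  # (8, 3, 4, 4, 8, 8)
# Flatten each patch into a feature vector: (batch, num_variables, num_features)
variables = patches.permute(0, 2, 3, 1, 4, 5).reshape(8, 16, 192)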

What do you most likely need?

A simple VariableMapper needs to implement four methods:

unstructured_tensor_to_variables -- maps a batch of Unstructured datapoints to the Variables format, a tensor of shape (batch_size, num_variables, num_features).

variables_tensor_to_unstructured -- does the inverse of the previous method: it transforms a Variables-format tensor back to the Unstructured format (the same as in your dataset). This method is used after inference, when you want to visualize the generated sample.

mask_unstructured_tensor_to_variables -- if your dataset contains masks, this is where you define how to map them to Variables. Note that each variable's mask is represented by a single floating-point value -- (for now) we don't support sub-variable masking.

mask_variables_tensor_to_unstructured -- takes a mask in the Variables format and transforms it back to the Unstructured format. This method can be used in your Evaluation, where you might want to visualize the region your SRM was inpainting. By default, the same method is also used to transform the per-variable noise levels into the Unstructured format.

It also needs to define num_variables and num_features -- note that these could be @property methods, e.g. depending on the config.
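
A minimal sketch of the @property variant (num_points is a hypothetical config field, and we assume the base class exposes the config as self.cfg -- check your VariableMapper base class if that differs):

from dataclasses import dataclass

import spatialreasoners as sr


@dataclass(frozen=True, kw_only=True)
class MyMapperCfg(sr.variable_mapper.VariableMapperCfg):
    num_points: int = 3  # hypothetical field, not part of the library


class MyMapper(sr.variable_mapper.VariableMapper[MyMapperCfg]):
    @property
    def num_variables(self) -> int:
        # Assumes the config is available as self.cfg
        return self.cfg.num_points

    @property
    def num_features(self) -> int:
        return 1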

Example: Spiral Variable Mapper

"""
Spiral Variable Mapper Implementation for SpatialReasoners
"""

from dataclasses import dataclass
import torch
from jaxtyping import Float
from torch import Tensor

import spatialreasoners as sr


@dataclass(frozen=True, kw_only=True)
class SpiralVariableMapperCfg(sr.variable_mapper.VariableMapperCfg):
    pass


@sr.variable_mapper.register_variable_mapper("spiral", SpiralVariableMapperCfg)
class SpiralVariableMapper(sr.variable_mapper.VariableMapper[SpiralVariableMapperCfg]):
    num_variables = 3  # x, y, r (color)
    num_features = 1

    def unstructured_tensor_to_variables(self, x: Float[Tensor, "batch x_y_c"]) -> Float[Tensor, "batch num_variables features"]:
        # Input second dimension is x, y, and color
        # Treat the current dimension as variables and just add a feature dimension
        return x.unsqueeze(-1)

    def variables_tensor_to_unstructured(self, x: Float[Tensor, "batch num_variables features"]) -> Float[Tensor, "batch 3"]:
        # Remove the feature dimension
        return x.squeeze(-1)

    def mask_variables_tensor_to_unstructured(self, mask: Float[Tensor, "batch num_variables"]) -> Float[Tensor, "batch 3"]:
        # Masks don't have a feature dimension, so return as is
        return mask

    def mask_unstructured_tensor_to_variables(self, mask: Float[Tensor, "batch 3"]) -> Float[Tensor, "batch num_variables"]:
        # Each point is a single variable, so the unstructured mask already matches the Variables layout
        return mask
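
A quick round-trip check of this mapper (a sketch: it assumes the mapper can be constructed directly from its config, which may differ from how the library wires it up internally):

mapper = SpiralVariableMapper(SpiralVariableMapperCfg())   # assumed constructor
batch = torch.randn(8, 3)                                  # Unstructured: (batch, x_y_c)
variables = mapper.unstructured_tensor_to_variables(batch)
assert variables.shape == (8, 3, 1)                        # (batch, num_variables, num_features)
assert torch.equal(mapper.variables_tensor_to_unstructured(variables), batch)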

What if I know the dependency structure of my data?

If you want to exploit the known causal structure of your variables when choosing the inference order, you need to define the dependency_matrix. You can do so by implementing the _calculate_dependency_matrix method in your VariableMapper. The matrix defines, for each variable (column), which variables (rows) it depends on. There are no acyclicity restrictions on the dependency matrix.

Below is an example implementation for the Spirals toy dataset:

def _calculate_dependency_matrix(self) -> Float[Tensor, "num_variables num_variables"]:
    # Columns are variables, rows are their dependencies. Each variable depends
    # on itself (the diagonal); color, the third variable, additionally depends
    # on both x and y.
    dependency_matrix = torch.tensor([
        [1, 0, 1],
        [0, 1, 1],
        [0, 0, 1],
    ]).float()

    return dependency_matrix
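
To read off a variable's parents, take the nonzero rows of its column. For the matrix above:

import torch

dependency_matrix = torch.tensor([
    [1, 0, 1],
    [0, 1, 1],
    [0, 0, 1],
]).float()

# Parents of the color variable (column 2): the nonzero rows of that column
parents = dependency_matrix[:, 2].nonzero().squeeze(-1)
print(parents)  # tensor([0, 1, 2]) -- color depends on x, y, and itself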