API Reference
This page provides detailed API documentation for Spatial Reasoners.
High-Level API
Training Functions
sr.run_training()
Main function for running training with predefined configurations.
def run_training(
config_name: str = "api_default",
config_path: Optional[str] = None,
overrides: Optional[List[str]] = None,
enable_beartype: bool = False,
) -> Any
Parameters:
- config_name (str): Name of the configuration file to use. Default: "api_default".
- config_path (str, optional): Path to a custom configuration directory. If None (the default), only the embedded configs are used.
- overrides (List[str], optional): List of configuration overrides in Hydra format. Default: None.
- enable_beartype (bool): Enable runtime type checking with beartype for easier debugging. Default: False.
Example:
import spatialreasoners as sr
# Basic usage
sr.run_training()
# With overrides
sr.run_training(overrides=[
"experiment=mnist_sudoku",
"trainer.max_epochs=50"
])
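Overrides are plain strings, so they are easy to generate in code, for example when sweeping one parameter across several runs. The sketch below uses a hypothetical helper, make_overrides (not part of Spatial Reasoners), which turns keyword arguments into key=value strings. A double underscore in an argument name stands in for the dot of a nested key.

```python
from typing import Any, List, Optional


def make_overrides(experiment: Optional[str] = None, **settings: Any) -> List[str]:
    """Build Hydra-style override strings (hypothetical helper, not part of the library).

    Nested keys use a double underscore: trainer__max_epochs=50 -> "trainer.max_epochs=50".
    """
    overrides: List[str] = []
    if experiment is not None:
        # Choose the experiment config, as in "experiment=mnist_sudoku"
        overrides.append(f"experiment={experiment}")
    for key, value in settings.items():
        dotted_key = key.replace("__", ".")
        if isinstance(value, bool):
            # Write booleans in their lowercase YAML spelling (true/false)
            value = str(value).lower()
        overrides.append(f"{dotted_key}={value}")
    return overrides


if __name__ == "__main__":
    # Sweep over epoch budgets; each list could be passed to sr.run_training(overrides=...)
    for epochs in (10, 50):
        print(make_overrides("mnist_sudoku", trainer__max_epochs=epochs))
```

Each returned list can be passed directly as the overrides argument of sr.run_training().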
@sr.config_main
Decorator for creating training scripts with automatic configuration management.
def config_main(
config_path: str = "configs",
config_name: str = "main"
)
Parameters:
- config_path (str): Path to the configuration directory. Default: "configs".
- config_name (str): Name of the main configuration file. Default: "main".
Example:
import spatialreasoners as sr
@sr.config_main(config_path="configs", config_name="main")
def main(cfg):
    # cfg is the configuration loaded by the decorator
    lightning_module = sr.create_lightning_module(cfg)
    data_module = sr.create_data_module(cfg)
    trainer = sr.create_trainer(cfg)
    trainer.fit(lightning_module, datamodule=data_module)
if __name__ == "__main__":
    main()
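To make "automatic configuration management" concrete, here is a minimal sketch of a decorator factory that takes the same two parameters as sr.config_main. It is not the library's implementation. The name config_main_sketch is hypothetical, and the sketch reads a JSON file instead of YAML so that it runs with the standard library alone.

```python
import functools
import json
from pathlib import Path
from typing import Any, Callable, Dict


def config_main_sketch(config_path: str = "configs", config_name: str = "main"):
    """Hypothetical decorator factory: load <config_path>/<config_name>.json and pass it on."""

    def decorator(func: Callable[[Dict[str, Any]], Any]) -> Callable[[], Any]:
        @functools.wraps(func)
        def wrapper() -> Any:
            # Resolve and parse the config file each time the entry point is called
            config_file = Path(config_path) / f"{config_name}.json"
            cfg = json.loads(config_file.read_text())
            # The decorated function receives the loaded config as its only argument
            return func(cfg)

        return wrapper

    return decorator
```

In the sketch, the decorated function is called without arguments and receives the loaded config as cfg.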
Configuration Functions
sr.load_default_config()
Load the default embedded configuration.
def load_default_config() -> DictConfig
Returns:
- DictConfig: Default configuration object
sr.load_config_from_yaml()
Load configuration from YAML files with optional overrides.
def load_config_from_yaml(
config_path: Optional[str] = None,
config_name: str = "api_default",
overrides: Optional[List[str]] = None
) -> DictConfig
Parameters:
- config_path (str, optional): Path to the configuration directory. Default: None.
- config_name (str): Name of the configuration file. Default: "api_default".
- overrides (List[str], optional): Configuration overrides. Default: None.
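Here is a rough mental model of what a simple override does: it walks the dotted key path and replaces one value in the nested config. The hypothetical helper apply_overrides below works on plain dicts and handles only that assignment case. Hydra's actual override grammar also covers selecting configs from a config group, the + and ~ prefixes for adding and deleting keys, lists, and quoting.

```python
import ast
import copy
from typing import Any, Dict, List


def _parse_value(text: str) -> Any:
    """Turn an override value into a bool, None, number, etc. when possible; else keep the string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    try:
        # Numbers, quoted strings, and lists such as [1,2]
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # Bare words such as mnist_sudoku or paths such as data/ stay strings
        return text


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of cfg with simple "dotted.key=value" overrides applied (hypothetical helper)."""
    result = copy.deepcopy(cfg)
    for override in overrides:
        key_path, sep, raw_value = override.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {override!r}")
        *parents, leaf = key_path.split(".")
        node = result
        for key in parents:
            node = node[key]  # KeyError if a section does not exist
        if leaf not in node:
            # Roughly like Hydra's default: refuse to add keys that are not already in the config
            raise KeyError(f"unknown config key: {key_path}")
        node[leaf] = _parse_value(raw_value)
    return result
```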
Component Factory Functions
sr.create_lightning_module()
Create a PyTorch Lightning module from configuration.
def create_lightning_module(cfg: DictConfig) -> LightningModule
sr.create_data_module()
Create a PyTorch Lightning data module from configuration.
def create_data_module(cfg: DictConfig) -> LightningDataModule
sr.create_trainer()
Create a PyTorch Lightning trainer from configuration.
def create_trainer(cfg: DictConfig) -> Trainer
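Registering components
Custom components, such as datasets, are registered with a decorator that associates a name and a config class with the implementation.
Example: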
from dataclasses import dataclass
import spatialreasoners as sr
from spatialreasoners.dataset import register_dataset, DatasetCfg
from spatialreasoners.type_extensions import UnstructuredExample
@dataclass
class MyDatasetCfg(DatasetCfg):
    data_path: str = "data/"
    # ... other dataset parameters
@register_dataset("my_dataset", MyDatasetCfg)
class MyDataset(sr.Dataset):
    def __init__(self, cfg: MyDatasetCfg):
        super().__init__()
        # load or index your data here, e.g. from cfg.data_path
    @property
    def _num_available(self) -> int:
        pass  # return the full size of the dataset
    def __getitem__(self, idx: int) -> UnstructuredExample:
        pass  # build and return the UnstructuredExample at index idx
Type Checking
Spatial Reasoners supports optional runtime type checking with beartype:
# Enable type checking
sr.run_training(enable_beartype=True)
# Or via CLI
python training.py --enable-beartype
This provides more informative error messages and catches type mismatches early. All built-in methods that operate on tensors are annotated with jaxtyping, which, together with beartype, lets us check for shape mismatches at runtime. We also encourage you to annotate your own methods with jaxtyping, as it can really speed up debugging!
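beartype checks annotations each time a decorated function is called. The toy decorator below, check_types (hypothetical, stdlib only), shows the basic idea for plain class annotations: it binds the call's arguments to the signature and compares each value, and the return value, against its annotation. It is a teaching sketch and not a replacement for beartype, which also handles generics and, together with jaxtyping, tensor shapes and dtypes.

```python
import functools
import inspect
from typing import Any, Callable, get_type_hints


def _is_plain_class(annotation: Any) -> bool:
    # True for int, float, list, ...; False for List[int], list[int], Optional[str], None, etc.
    return isinstance(annotation, type) and not hasattr(annotation, "__origin__")


def check_types(func: Callable[..., Any]) -> Callable[..., Any]:
    """Toy runtime type checker for arguments and return values with plain-class annotations."""
    signature = inspect.signature(func)
    hints = get_type_hints(func)

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            param = signature.parameters[name]
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue  # *args / **kwargs are not handled by this toy
            expected = hints.get(name)
            # Unannotated parameters and generic annotations are skipped
            if _is_plain_class(expected) and not isinstance(value, expected):
                raise TypeError(
                    f"{func.__name__}(): {name!r} should be {expected.__name__}, "
                    f"got {type(value).__name__}"
                )
        result = func(*bound.args, **bound.kwargs)
        expected_return = hints.get("return")
        if _is_plain_class(expected_return) and not isinstance(result, expected_return):
            raise TypeError(
                f"{func.__name__}() should return {expected_return.__name__}, "
                f"got {type(result).__name__}"
            )
        return result

    return wrapper


@check_types
def scale(values: list, factor: float) -> list:
    """Multiply every element of values by factor."""
    return [v * factor for v in values]
```

Calling scale([1.0], "2") raises a TypeError that names the factor parameter, rather than failing later inside the function body.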
Examples
See the Quick Tour and example project for complete usage examples.
For a minimalistic project template, check out our example_project in the GitHub repository.