API Reference
This page provides detailed API documentation for Spatial Reasoners.
High-Level API
Training Functions
sr.run_training()
Main function for running training with predefined configurations.
def run_training(
    config_name: str = "api_default",
    config_path: str = None,
    overrides: List[str] = None,
    enable_beartype: bool = False,
) -> Any
Parameters:
- config_name (str): Name of the configuration file to use. Default: "api_default"
- config_path (str): Path to the custom configuration directory. If None, only the embedded configs are used
- overrides (List[str]): List of configuration overrides in Hydra format
- enable_beartype (bool): Enable runtime type checking for better debugging
Example:
import spatialreasoners as sr
# Basic usage
sr.run_training()
# With overrides
sr.run_training(overrides=[
    "experiment=mnist_sudoku",
    "trainer.max_epochs=50",
])
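To make the override format more concrete, here is a minimal plain-Python sketch of how dotted `key=value` strings can map onto a nested configuration. The `apply_overrides` helper is hypothetical and is not part of Spatial Reasoners or Hydra. Real Hydra also treats entries such as `experiment=mnist_sudoku` as config-group selections, which this sketch does not model.

```python
from typing import Any, Dict, List


def _parse_value(raw: str) -> Any:
    """Keep integers as integers; leave everything else as a string."""
    try:
        return int(raw)
    except ValueError:
        return raw


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'dotted.key=value' overrides to a nested dict (illustration only)."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections on demand.
            node = node.setdefault(name, {})
        node[leaf] = _parse_value(raw_value)
    return config


cfg = apply_overrides(
    {"trainer": {"max_epochs": 10}},
    ["experiment=mnist_sudoku", "trainer.max_epochs=50"],
)
print(cfg)
```

The dotted path `trainer.max_epochs` addresses a nested key, which is why one flat list of strings is enough to adjust values deep in the configuration tree.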
@sr.config_main
Decorator for creating training scripts with automatic configuration management.
def config_main(
    config_path: str = "configs",
    config_name: str = "main"
)
Parameters:
- config_path (str): Path to the configuration directory
- config_name (str): Name of the main configuration file
Example:
import spatialreasoners as sr

@sr.config_main(config_path="configs", config_name="main")
def main(cfg):
    # Build the model, data, and trainer from the composed configuration
    lightning_module = sr.create_lightning_module(cfg)
    data_module = sr.create_data_module(cfg)
    trainer = sr.create_trainer(cfg)
    trainer.fit(lightning_module, datamodule=data_module)
Configuration Functions
sr.load_default_config()
Load the default embedded configuration.
def load_default_config() -> DictConfig
Returns:
- DictConfig: Default configuration object
sr.load_config_from_yaml()
Load configuration from YAML files with optional overrides.
def load_config_from_yaml(
    config_path: str = None,
    config_name: str = "api_default",
    overrides: List[str] = None
) -> DictConfig
Parameters:
- config_path (str): Path to the configuration directory
- config_name (str): Name of the configuration file
- overrides (List[str]): Configuration overrides
Component Factory Functions
sr.create_lightning_module()
Create a PyTorch Lightning module from configuration.
def create_lightning_module(cfg: DictConfig) -> LightningModule
sr.create_data_module()
Create a PyTorch Lightning data module from configuration.
def create_data_module(cfg: DictConfig) -> LightningDataModule
sr.create_trainer()
Create a PyTorch Lightning trainer from configuration.
def create_trainer(cfg: DictConfig) -> Trainer
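The configuration loader and the three factory functions can be chained without the decorator. The sketch below uses only the functions documented on this page. The wrapper name `train_from_yaml` is hypothetical, and the sketch assumes the loaded configuration is complete enough for all three factories.

```python
def train_from_yaml(config_path=None, overrides=None):
    """Load a config, build the components, and fit (illustrative wrapper)."""
    # Imported lazily so that defining this helper does not require the package.
    import spatialreasoners as sr

    cfg = sr.load_config_from_yaml(
        config_path=config_path,
        config_name="api_default",
        overrides=overrides,
    )
    lightning_module = sr.create_lightning_module(cfg)
    data_module = sr.create_data_module(cfg)
    trainer = sr.create_trainer(cfg)
    trainer.fit(lightning_module, datamodule=data_module)
    return trainer
```

A call such as `train_from_yaml(overrides=["trainer.max_epochs=50"])` would then play roughly the same role as `sr.run_training(overrides=[...])`, while leaving each intermediate object available for inspection.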
Registering components
Example:
from dataclasses import dataclass

import spatialreasoners as sr
from spatialreasoners.dataset import register_dataset, DatasetCfg
from spatialreasoners.type_extensions import UnstructuredExample

@dataclass
class MyDatasetCfg(DatasetCfg):
    data_path: str = "data/"
    # Other parameters

@register_dataset("my_dataset", MyDatasetCfg)
class MyDataset(sr.Dataset):
    def __init__(self, cfg: MyDatasetCfg):
        super().__init__()

    @property
    def _num_available(self) -> int:
        pass  # return the size of the full dataset

    def __getitem__(self, idx: int) -> UnstructuredExample:
        pass  # build and return an UnstructuredExample
Type Checking
Spatial Reasoners supports optional runtime type checking with beartype:
# Enable type checking
sr.run_training(enable_beartype=True)
# Or via CLI
python training.py --enable-beartype
This provides enhanced error messages and catches type mismatches early. All the built-in methods that operate on tensors are annotated with jaxtyping, which together with beartype lets us check for shape mismatches at runtime. We also encourage you to annotate your own methods with jaxtyping, as it can really speed up debugging!
Examples
See the Quick Tour and example project for complete usage examples.
For a minimalistic project template, check out our example_project in the GitHub repository.