
Config Module

This module defines the configuration classes and utilities for the semantic segmentation pipeline.

It includes an input type enumeration, model and visualization configuration classes, the main configuration class that encapsulates all settings for the segmentation process, and a utility for hashing the settings that affect the results.

InputType

Bases: Enum

Enumeration of supported input types for the segmentation pipeline. The members used in this module are DIRECTORY, SINGLE_IMAGE, and SINGLE_VIDEO.
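The enum's values and the private _determine_input_type helper called by Config.__post_init__ are not documented on this page. The following is a minimal sketch of how such detection could work, assuming a directory check followed by a lookup on the file extension. The string values, the extension sets, and the determine_input_type function are all hypothetical and are not the package's actual implementation.

```python
from enum import Enum
from pathlib import Path


class InputType(Enum):
    # Member names match the ones used by Config; the values are assumed.
    SINGLE_IMAGE = "single_image"
    SINGLE_VIDEO = "single_video"
    DIRECTORY = "directory"


# Hypothetical extension sets; the real pipeline may accept other formats.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv"}


def determine_input_type(path: Path) -> InputType:
    """Classify an input path as a directory, image, or video (sketch)."""
    if path.is_dir():
        return InputType.DIRECTORY
    suffix = path.suffix.lower()  # treat ".JPG" and ".jpg" alike
    if suffix in IMAGE_EXTENSIONS:
        return InputType.SINGLE_IMAGE
    if suffix in VIDEO_EXTENSIONS:
        return InputType.SINGLE_VIDEO
    raise ValueError(f"Unsupported input type: {path}")


print(determine_input_type(Path("street.JPG")))
```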

ModelConfig

ModelConfig(
    name,
    model_type=None,
    max_size=None,
    device=None,
    dataset=None,
    num_workers=8,
    pipe_batch=1,
)

Configuration class for the segmentation model.

ATTRIBUTE DESCRIPTION
name

The name or path of the pre-trained model to use.

TYPE: str

model_type

The type of the model (e.g., 'oneformer', 'mask2former').

TYPE: Optional[str]

max_size

The maximum size for input image resizing.

TYPE: Optional[int]

device

The device to use for processing (e.g., 'cpu', 'cuda').

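TYPE: Optional[str]

num_workers

The number of PyTorch data-loading worker processes. It is forced to 0 when device is 'mps'.

TYPE: int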

__post_init__

__post_init__()

Post-initialization method to set up the model type if not provided and to make num_workers compatible with the selected device.

Source code in cityseg/config.py
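def __post_init__(self):
    """
    Post-initialization method to set up the model type if not provided
    and to make num_workers compatible with the selected device.
    """
    self.auto_detect_model_type()
    # Treat an unset worker count as "no worker processes"
    if self.num_workers is None:
        self.num_workers = 0
    # Multiple PyTorch workers are treated as incompatible with MPS
    if self.device == "mps" and self.num_workers > 0:
        logger.warning(
            "MPS is not compatible with multiple workers in pytorch. Setting num_workers to 0."
        )
        self.num_workers = 0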

auto_detect_model_type

auto_detect_model_type()

Automatically detect the model type from the model name if not provided.

Source code in cityseg/config.py
def auto_detect_model_type(self):
    """
    Automatically detect the model type from the model name if not provided.
    """

    def auto_model_type(model_name: str) -> str:
        # Last path segment, up to the first hyphen
        return model_name.split("/")[-1].split("-")[0]

    detected = auto_model_type(self.name)
    if self.model_type is None:
        # str.split never raises IndexError, so check for an empty result instead
        if not detected:
            logger.warning(
                "Unable to auto-detect model type from the model name and none provided."
            )
            return
        self.model_type = detected
        logger.info(f"Auto-detected model type: {self.model_type}")
    elif self.model_type != detected:
        logger.warning(
            f"Model type does not match auto-detected model type. Using provided model type: {self.model_type}"
        )
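The detection rule keeps the last path segment of the model name and takes everything before its first hyphen. The snippet below applies the same expression to a few Hugging Face-style model identifiers, used here only as illustrative inputs, to show what it yields. The second name shows that identifiers using underscores instead of hyphens are not split.

```python
def auto_model_type(model_name: str) -> str:
    # Same expression as the nested helper in auto_detect_model_type
    return model_name.split("/")[-1].split("-")[0]


for name in [
    "facebook/mask2former-swin-large-cityscapes-semantic",
    "shi-labs/oneformer_cityscapes_swin_large",
    "microsoft/beit-base-finetuned-ade-640-640",
]:
    print(f"{name!r} -> {auto_model_type(name)!r}")
```

For the second name, passing model_type='oneformer' explicitly triggers the mismatch warning, and the provided value is the one that is used.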

VisualizationConfig

VisualizationConfig(alpha=0.5, colormap='default')

Configuration class for visualization settings.

ATTRIBUTE DESCRIPTION
alpha

The alpha value for blending the segmentation mask with the original image.

TYPE: float

colormap

The colormap to use for visualizing the segmentation mask.

TYPE: str
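This module does not show how alpha is applied. Assuming the conventional per-pixel blend, overlay = (1 - alpha) * image + alpha * colored_mask, a NumPy sketch looks like this. blend_overlay is a hypothetical helper and is not part of the package.

```python
import numpy as np


def blend_overlay(image: np.ndarray, colored_mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a colored segmentation mask over an RGB image (hypothetical helper).

    alpha=0 returns the image unchanged; alpha=1 returns only the mask colors.
    """
    if image.shape != colored_mask.shape:
        raise ValueError("image and colored_mask must have the same shape")
    blended = (1.0 - alpha) * image.astype(np.float32) + alpha * colored_mask.astype(np.float32)
    # Round and clip back into the valid 8-bit range
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


image = np.zeros((2, 2, 3), dtype=np.uint8)                 # black image
mask = np.full((2, 2, 3), (200, 100, 50), dtype=np.uint8)   # a single class color
print(blend_overlay(image, mask, alpha=0.5)[0, 0])
```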

Config

Config(
    input,
    output_dir,
    output_prefix,
    model,
    ignore_files=None,
    frame_step=1,
    batch_size=16,
    output_fps=None,
    save_raw_segmentation=True,
    save_colored_segmentation=False,
    save_overlay=True,
    analyze_results=True,
    visualization=VisualizationConfig(),
    force_reprocess=False,
    disable_tqdm=False,
)

Main configuration class for the segmentation pipeline.

This class encapsulates all settings required for the segmentation process, including input/output paths, model configuration, processing parameters, and visualization settings.

ATTRIBUTE DESCRIPTION
input

The input path (file or directory) for processing.

TYPE: Union[Path, str]

output_dir

The output directory for saving results.

TYPE: Optional[Path]

output_prefix

The prefix for output file names.

TYPE: Optional[str]

model

The model configuration.

TYPE: ModelConfig

ignore_files

Files to skip during processing. Defaults to an empty list.

TYPE: Optional[list]

frame_step

The frame step for video processing: every frame_step-th frame is processed, so 1 processes every frame.

TYPE: int

batch_size

The batch size for processing.

TYPE: int

output_fps

The output FPS for processed videos.

TYPE: Optional[float]

save_raw_segmentation

Whether to save raw segmentation maps.

TYPE: bool

save_colored_segmentation

Whether to save colored segmentation maps.

TYPE: bool

save_overlay

Whether to save overlay visualizations.

TYPE: bool

analyze_results

Whether to analyze the segmentation results after processing.

TYPE: bool

visualization

The visualization configuration.

TYPE: VisualizationConfig

input_type

The type of input (automatically determined).

TYPE: InputType

force_reprocess

Whether to force reprocessing of existing results.

TYPE: bool

disable_tqdm

Whether to disable the progress bar display.

TYPE: bool

__post_init__

__post_init__()

Post-initialization method that converts the input to a Path, verifies that it exists, determines the input type, and defaults ignore_files to an empty list.

RAISES DESCRIPTION
ValueError

If the input path does not exist.

Source code in cityseg/config.py
def __post_init__(self):
    """
    Post-initialization method to set up the input path and determine the input type.

    Raises:
        ValueError: If the input path does not exist.
    """
    self.input = Path(self.input)
    if not self.input.exists():
        raise ValueError(f"Input path does not exist: {self.input}")
    self.input_type = self._determine_input_type()
    self.ignore_files = self.ignore_files or []

generate_output_prefix

generate_output_prefix()

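Generate an output prefix of the form {name}_{model}_step{frame_step}, where {name} is the directory name for directory inputs or the part of the file stem before the first underscore for single files, and {model} is the last segment of the model name.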

RETURNS DESCRIPTION
str

The generated output prefix.

TYPE: str

Source code in cityseg/config.py
def generate_output_prefix(self) -> str:
    """
    Generate an output prefix based on the input file name and model configuration.

    Returns:
        str: The generated output prefix.
    """
    if self.input_type == InputType.DIRECTORY:
        name = self.input.name
    else:
        # Use only the first part of the filename
        name = self.input.stem.split("_")[0]

    model_name = self.model.name.split("/")[-1]
    base_name = f"{name}_{model_name}_step{self.frame_step}"

    return base_name
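The stand-alone function below mirrors the prefix logic so it can be tried without building a Config. output_prefix_for and the sample paths are hypothetical.

```python
from pathlib import Path


def output_prefix_for(input_path: Path, model_name: str, frame_step: int, is_directory: bool) -> str:
    # Directories keep their full name; files keep the stem up to the first "_"
    name = input_path.name if is_directory else input_path.stem.split("_")[0]
    # Drop the organization part of a Hugging Face-style model id
    short_model = model_name.split("/")[-1]
    return f"{name}_{short_model}_step{frame_step}"


print(output_prefix_for(
    Path("videos/street_2024_01.mp4"),
    "facebook/mask2former-swin-large-cityscapes-semantic",
    10,
    is_directory=False,
))
```

One consequence is that street_2024_01.mp4 and street_2024_02.mp4 produce the same prefix, so when both are processed into the same output directory with the same model and frame step, their outputs may overwrite each other unless output_prefix is set explicitly.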

get_output_path

get_output_path()

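Get the full output path based on the configuration.

The output directory defaults to an "output" folder next to the input, and a relative output_dir is resolved against the input's parent directory. For directory inputs, a {model}_step{frame_step} subdirectory is appended and that directory is returned. For single images, the result is a file named after the output prefix with the input's extension; for videos, it is an .mp4 file. The output directory is created if it does not exist.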

RETURNS DESCRIPTION
Path

The full output path.

TYPE: Path

Source code in cityseg/config.py
def get_output_path(self) -> Path:
    """
    Get the full output path based on the configuration.

    This method determines the appropriate output directory and file name
    based on the input type and configuration settings.

    Returns:
        Path: The full output path.
    """
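    if self.output_dir is None:
        self.output_dir = self.input.parent / "output"
    else:
        # Accept str as well as Path; relative paths are anchored at the input's parent
        self.output_dir = Path(self.output_dir)
        if not self.output_dir.is_absolute():
            self.output_dir = self.input.parent / self.output_dir

    self.output_dir = self.output_dir.resolve()

    if self.input_type == InputType.DIRECTORY:
        model_name = self.model.name.split("/")[-1]
        subdir_name = f"{model_name}_step{self.frame_step}"
        # Avoid nesting the subdirectory again if this method is called twice
        if self.output_dir.name != subdir_name:
            self.output_dir = self.output_dir / subdir_name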

    self.output_dir.mkdir(parents=True, exist_ok=True)

    if self.input_type == InputType.DIRECTORY:
        return self.output_dir

    prefix = self.output_prefix or self.generate_output_prefix()
    if self.input_type == InputType.SINGLE_IMAGE:
        return self.output_dir / f"{prefix}{self.input.suffix}"
    else:  # SINGLE_VIDEO
        return self.output_dir / f"{prefix}.mp4"
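The path rules can be checked without touching the file system. The pure function below reproduces them under the assumption that the caller already knows the input kind. sketch_output_path and its kind labels are hypothetical, and unlike get_output_path it neither creates directories nor updates any attributes.

```python
from pathlib import Path
from typing import Optional, Union


def sketch_output_path(
    input_path: Path,
    output_dir: Optional[Union[Path, str]],
    model_name: str,
    frame_step: int,
    kind: str,  # "directory", "image" or "video" (hypothetical labels)
    prefix: str,
) -> Path:
    if output_dir is None:
        # Default: an "output" folder next to the input
        out = input_path.parent / "output"
    else:
        out = Path(output_dir)
        if not out.is_absolute():
            # Relative paths are anchored at the input's parent, not the CWD
            out = input_path.parent / out
    out = out.resolve()

    if kind == "directory":
        # Directory inputs get a per-model, per-step subdirectory
        return out / f"{model_name.split('/')[-1]}_step{frame_step}"
    if kind == "image":
        return out / f"{prefix}{input_path.suffix}"
    return out / f"{prefix}.mp4"  # videos are always written as .mp4


print(sketch_output_path(Path("/data/clips/street_01.mov"), None, "org/model", 5, "video", "street_model_step5"))
```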

from_yaml

from_yaml(config_path)

Create a Config instance from a YAML file.

PARAMETER DESCRIPTION
config_path

Path to the YAML configuration file.

TYPE: Path

RETURNS DESCRIPTION
Config

An instance of the Config class.

TYPE: Config

Source code in cityseg/config.py
@classmethod
def from_yaml(cls, config_path: Path) -> "Config":
    """
    Create a Config instance from a YAML file.

    Args:
        config_path (Path): Path to the YAML configuration file.

    Returns:
        Config: An instance of the Config class.
    """
    with open(config_path, "r") as f:
        # An empty YAML file loads as None; treat it as an empty mapping
        config_dict = yaml.safe_load(f) or {}

    # Convert string paths back to Path objects (skipping absent or null values)
    if config_dict.get("input") is not None:
        config_dict["input"] = Path(config_dict["input"])
    if config_dict.get("output_dir") is not None:
        config_dict["output_dir"] = Path(config_dict["output_dir"])

    model_config = ModelConfig(**config_dict.get("model", {}))
    vis_config = VisualizationConfig(**config_dict.get("visualization", {}))

    return cls(
        input=config_dict["input"],
        output_dir=config_dict.get("output_dir"),
        output_prefix=config_dict.get("output_prefix"),
        model=model_config,
        ignore_files=config_dict.get("ignore_files", []),
        frame_step=config_dict.get("frame_step", 1),
        batch_size=config_dict.get("batch_size", 16),
        output_fps=config_dict.get("output_fps"),
        save_raw_segmentation=config_dict.get("save_raw_segmentation", True),
        save_colored_segmentation=config_dict.get(
            "save_colored_segmentation", False
        ),
        save_overlay=config_dict.get("save_overlay", True),
        analyze_results=config_dict.get("analyze_results", True),
        visualization=vis_config,
        force_reprocess=config_dict.get("force_reprocess", False),
        disable_tqdm=config_dict.get("disable_tqdm", False),
    )
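Because from_yaml reads input with a plain key lookup and ModelConfig requires name, a configuration file needs at least input and model.name; every other top-level key falls back to the default shown above. The file contents below are a hypothetical example with placeholder paths, parsed with PyYAML's yaml.safe_load (the same call from_yaml uses) to show the resulting dictionary.

```python
import yaml  # PyYAML, which from_yaml itself relies on

# Hypothetical file contents; the paths and values are placeholders.
EXAMPLE_CONFIG = """
input: data/street_video.mp4
output_dir: output
output_prefix: null
frame_step: 5
batch_size: 8
save_overlay: true
model:
  name: facebook/mask2former-swin-large-cityscapes-semantic
  device: cuda
visualization:
  alpha: 0.6
  colormap: default
"""

config_dict = yaml.safe_load(EXAMPLE_CONFIG)

# Keys missing from the file (e.g. save_raw_segmentation) are filled in by
# from_yaml with its defaults; here they are simply absent from the dict.
print(config_dict["model"])
print(config_dict.get("save_raw_segmentation", True))
```

Saved as, for example, config.yaml, this file would be loaded with Config.from_yaml(Path("config.yaml")). Config.__post_init__ raises ValueError unless the input path exists.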

to_dict

to_dict()

Convert the Config instance to a dictionary.

RETURNS DESCRIPTION
Dict[str, Any]

A dictionary representation of the Config instance.

TYPE: Dict[str, Any]

Source code in cityseg/config.py
def to_dict(self) -> Dict[str, Any]:
    """
    Convert the Config instance to a dictionary.

    Returns:
        Dict[str, Any]: A dictionary representation of the Config instance.
    """
    return {
        "input": str(self.input),
        "output_dir": str(self.output_dir) if self.output_dir else None,
        "output_prefix": self.output_prefix,
        "model": asdict(self.model),
        "ignore_files": self.ignore_files,
        "frame_step": self.frame_step,
        "batch_size": self.batch_size,
        "output_fps": self.output_fps,
        "save_raw_segmentation": self.save_raw_segmentation,
        "save_colored_segmentation": self.save_colored_segmentation,
        "save_overlay": self.save_overlay,
        "analyze_results": self.analyze_results,
        "visualization": asdict(self.visualization),
        "input_type": self.input_type.value,
        "force_reprocess": self.force_reprocess,
        "disable_tqdm": self.disable_tqdm,
    }

ConfigHasher

A utility class for generating hashes of relevant configuration settings.

get_relevant_config

get_relevant_config(config)

Extract the relevant configuration settings for hashing.

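This method filters out configuration settings that don't affect the analysis results or output format (such as paths, batch size, worker counts, and progress-bar display), keeping only those that would require reprocessing if changed: the model name, type, and maximum input size, the frame step, the save flags, and the visualization settings.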

PARAMETER DESCRIPTION
config

The full configuration object.

TYPE: Config

RETURNS DESCRIPTION
Dict[str, Any]

A dictionary of relevant configuration settings.

TYPE: Dict[str, Any]

Source code in cityseg/config.py
@staticmethod
def get_relevant_config(config: Config) -> Dict[str, Any]:
    """
    Extract the relevant configuration settings for hashing.

    This method filters out configuration settings that don't affect the
    analysis results or output format, focusing only on settings that would
    require reprocessing if changed.

    Args:
        config (Config): The full configuration object.

    Returns:
        Dict[str, Any]: A dictionary of relevant configuration settings.
    """
    return {
        "model": {
            "name": config.model.name,
            "model_type": config.model.model_type,
            "max_size": config.model.max_size,
        },
        "frame_step": config.frame_step,
        "save_raw_segmentation": config.save_raw_segmentation,
        "save_colored_segmentation": config.save_colored_segmentation,
        "save_overlay": config.save_overlay,
        # Both fields change how the segmentation masks are rendered
        "visualization": {
            "alpha": config.visualization.alpha,
            "colormap": config.visualization.colormap,
        },
    }

calculate_hash

calculate_hash(config)

Calculate a hash of the relevant configuration settings.

This method creates a deterministic hash of the configuration settings that affect the analysis results or output format.

PARAMETER DESCRIPTION
config

The full configuration object.

TYPE: Config

RETURNS DESCRIPTION
str

A hexadecimal string representing the hash of the relevant config.

TYPE: str

Source code in cityseg/config.py
@staticmethod
def calculate_hash(config: Config) -> str:
    """
    Calculate a hash of the relevant configuration settings.

    This method creates a deterministic hash of the configuration settings
    that affect the analysis results or output format.

    Args:
        config (Config): The full configuration object.

    Returns:
        str: A hexadecimal string representing the hash of the relevant config.
    """
    relevant_config = ConfigHasher.get_relevant_config(config)
    config_str = json.dumps(relevant_config, sort_keys=True)
    return hashlib.md5(config_str.encode()).hexdigest()
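Because json.dumps is called with sort_keys=True, which also sorts nested dictionaries, the hash depends only on the settings and not on the order in which keys were inserted. The stand-alone sketch below uses a hypothetical settings_hash helper and made-up settings to demonstrate this, and shows that changing any included setting yields a different digest.

```python
import hashlib
import json
from typing import Any, Dict


def settings_hash(settings: Dict[str, Any]) -> str:
    # Same recipe as ConfigHasher.calculate_hash: canonical JSON, then MD5
    config_str = json.dumps(settings, sort_keys=True)
    return hashlib.md5(config_str.encode()).hexdigest()


a = {"frame_step": 1, "model": {"name": "org/model", "max_size": None}}
b = {"model": {"max_size": None, "name": "org/model"}, "frame_step": 1}
c = {"frame_step": 2, "model": {"name": "org/model", "max_size": None}}

print(settings_hash(a) == settings_hash(b))  # same settings, different key order
print(settings_hash(a) == settings_hash(c))  # frame_step changed
```

MD5 serves here as a fingerprint for deciding whether existing results are stale, not as a security measure.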