mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 10:56:37 +00:00
34 lines
1.3 KiB
Python
34 lines
1.3 KiB
Python
import math
|
|
from typing import Dict
|
|
from abc import ABC, abstractmethod
|
|
|
|
class IRobustKernels(ABC):
    """Abstract interface for robust loss kernels (e.g. Huber, Cauchy).

    Implementations provide loss values for a scalar residual and a per-residual
    weight selected by kernel name.
    """

    @abstractmethod
    def huber_loss(self, error: float, threshold: float) -> float:
        """Return the Huber loss of residual ``error`` for the given ``threshold``."""

    @abstractmethod
    def cauchy_loss(self, error: float, k: float) -> float:
        """Return the Cauchy loss of residual ``error`` for scale parameter ``k``."""

    @abstractmethod
    def compute_weight(
        self,
        error: float,
        kernel_type: str,
        params: Dict[str, float],
    ) -> float:
        """Return the weight of residual ``error`` under kernel ``kernel_type``.

        ``params`` carries the kernel's tuning values; which keys are read is
        left to the implementation.
        """
|
|
|
|
class RobustKernels(IRobustKernels):
    """H03: Huber/Cauchy loss functions for outlier rejection in optimization.

    ``huber_loss`` and ``cauchy_loss`` return the robust loss of a scalar
    residual. ``compute_weight`` returns the matching per-residual weight,
    i.e. the loss derivative divided by the residual (``rho'(e) / e``).

    Squares are computed with ``x * x`` rather than ``x ** 2`` throughout.
    Python's float ``**`` raises ``OverflowError`` when a finite result
    overflows, whereas multiplication yields ``inf``, which gives the correct
    limiting loss (``inf``) and weight (``0.0``) for huge residuals.
    """

    def huber_loss(self, error: float, threshold: float) -> float:
        """Return the Huber loss of ``error``.

        The loss is quadratic (``0.5 * e**2``) for ``|e| <= threshold`` and
        linear (``threshold * (|e| - 0.5 * threshold)``) beyond it. Both
        pieces have equal value and slope at ``|e| == threshold``.

        Args:
            error: Residual value.
            threshold: Transition point between the quadratic and linear
                regions. Expected to be non-negative; a negative value makes
                the linear branch produce a meaningless, possibly negative,
                loss.

        Returns:
            The Huber loss (non-negative for ``threshold >= 0``).
        """
        abs_err = abs(error)
        if abs_err <= threshold:
            return 0.5 * (error * error)
        return threshold * (abs_err - 0.5 * threshold)

    def cauchy_loss(self, error: float, k: float) -> float:
        """Return the Cauchy loss ``(k**2 / 2) * ln(1 + (error / k)**2)``.

        Args:
            error: Residual value.
            k: Scale parameter; only its magnitude matters. ``k == 0`` raises
                ``ZeroDivisionError``.

        Returns:
            The non-negative Cauchy loss; ``inf`` if ``(error / k)**2``
            overflows.
        """
        ratio = error / k
        # log1p keeps precision for tiny ratios: log(1.0 + x) would round
        # 1.0 + x back to 1.0 and return exactly 0.
        return (k * k / 2.0) * math.log1p(ratio * ratio)

    def compute_weight(self, error: float, kernel_type: str, params: Dict[str, float]) -> float:
        """Return the robust weight ``rho'(e) / e`` for residual ``error``.

        Args:
            error: Residual value.
            kernel_type: ``"huber"`` or ``"cauchy"``, case-insensitive. Any
                other value is not recognised and yields weight 1.0, i.e. no
                down-weighting.
            params: Kernel parameters. Huber reads ``"threshold"`` and Cauchy
                reads ``"k"``; each defaults to 1.0 when absent.

        Returns:
            1.0 when ``|error| < 1e-8``. Otherwise the kernel weight:
            ``1`` or ``threshold / |error|`` for Huber, and
            ``1 / (1 + (error / k)**2)`` for Cauchy (``0.0`` if that square
            overflows).
        """
        abs_err = abs(error)
        # Tiny residuals get full weight, the e -> 0 limit of both kernels.
        # NOTE: the cutoff is absolute, not scaled by threshold/k.
        if abs_err < 1e-8:
            return 1.0

        kernel = kernel_type.lower()
        if kernel == "huber":
            threshold = params.get("threshold", 1.0)
            return 1.0 if abs_err <= threshold else threshold / abs_err
        if kernel == "cauchy":
            k = params.get("k", 1.0)
            ratio = error / k
            return 1.0 / (1.0 + ratio * ratio)

        # Unrecognised kernel: fall back to plain least squares.
        return 1.0