mirror of
https://github.com/azaion/gps-denied-onboard.git
synced 2026-04-22 11:36:37 +00:00
Initial commit
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
import math
|
||||
from typing import Dict
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
class IRobustKernels(ABC):
|
||||
@abstractmethod
|
||||
def huber_loss(self, error: float, threshold: float) -> float: pass
|
||||
@abstractmethod
|
||||
def cauchy_loss(self, error: float, k: float) -> float: pass
|
||||
@abstractmethod
|
||||
def compute_weight(self, error: float, kernel_type: str, params: Dict[str, float]) -> float: pass
|
||||
|
||||
class RobustKernels(IRobustKernels):
|
||||
"""H03: Huber/Cauchy loss functions for outlier rejection in optimization."""
|
||||
def huber_loss(self, error: float, threshold: float) -> float:
|
||||
abs_err = abs(error)
|
||||
if abs_err <= threshold:
|
||||
return 0.5 * (error ** 2)
|
||||
return threshold * (abs_err - 0.5 * threshold)
|
||||
|
||||
def cauchy_loss(self, error: float, k: float) -> float:
|
||||
return (k ** 2 / 2.0) * math.log(1.0 + (error / k) ** 2)
|
||||
|
||||
def compute_weight(self, error: float, kernel_type: str, params: Dict[str, float]) -> float:
|
||||
abs_err = abs(error)
|
||||
if abs_err < 1e-8: return 1.0
|
||||
|
||||
if kernel_type.lower() == "huber":
|
||||
threshold = params.get("threshold", 1.0)
|
||||
return 1.0 if abs_err <= threshold else threshold / abs_err
|
||||
elif kernel_type.lower() == "cauchy":
|
||||
k = params.get("k", 1.0)
|
||||
return 1.0 / (1.0 + (error / k) ** 2)
|
||||
return 1.0
|
||||
Reference in New Issue
Block a user