Changed to prepare onnx
ci/woodpecker/manual/e2e-smoke-jetson Pipeline was successful

This commit is contained in:
Roman Meshko
2026-04-26 23:14:08 +03:00
parent 4ec9633902
commit 73c9d57827
5 changed files with 158 additions and 5 deletions
+111
View File
@@ -0,0 +1,111 @@
import ast
import io
import onnx
from onnx import helper, numpy_helper
_REDUCE_OPS_WITH_AXES_INPUT = {
"ReduceL1",
"ReduceL2",
"ReduceLogSum",
"ReduceLogSumExp",
"ReduceMax",
"ReduceMean",
"ReduceMin",
"ReduceProd",
"ReduceSum",
"ReduceSumSquare",
}
def _metadata(model):
return {p.key: p.value for p in model.metadata_props}
def _input_size(model):
try:
imgsz = _metadata(model).get("imgsz")
parsed = ast.literal_eval(imgsz)
if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
h, w = int(parsed[0]), int(parsed[1])
if h > 0 and w > 0:
return h, w
except Exception:
pass
return 1280, 1280
def _constant_values(graph):
values = {init.name: numpy_helper.to_array(init) for init in graph.initializer}
for node in graph.node:
if node.op_type != "Constant" or not node.output:
continue
for attr in node.attribute:
if attr.name == "value":
values[node.output[0]] = numpy_helper.to_array(attr.t)
break
return values
def _as_int_list(value):
if value is None:
return None
if getattr(value, "shape", ()) == ():
return [int(value)]
return [int(v) for v in value.reshape(-1).tolist()]
def _set_static_input_shape(model, batch=1):
h, w = _input_size(model)
for graph_input in model.graph.input:
tensor_type = graph_input.type.tensor_type
if tensor_type.elem_type != onnx.TensorProto.FLOAT:
continue
dims = tensor_type.shape.dim
if len(dims) != 4:
continue
for dim, value in zip(dims, (batch, 3, h, w)):
dim.dim_value = value
return True
return False
def _rewrite_reduce_axes_inputs(model):
constants = _constant_values(model.graph)
changed = False
for node in model.graph.node:
if node.op_type not in _REDUCE_OPS_WITH_AXES_INPUT or len(node.input) < 2:
continue
axes = _as_int_list(constants.get(node.input[1]))
if axes is None:
continue
kept_attrs = [attr for attr in node.attribute if attr.name != "axes"]
del node.attribute[:]
node.attribute.extend(kept_attrs)
node.attribute.extend([helper.make_attribute("axes", axes)])
del node.input[1:]
changed = True
return changed
def _cap_default_opset(model, max_opset=17):
for opset in model.opset_import:
if opset.domain in ("", "ai.onnx") and opset.version > max_opset:
opset.version = max_opset
return True
return False
def prepare_for_tensorrt(model_bytes):
model = onnx.load_model_from_string(model_bytes)
changed = False
changed = _set_static_input_shape(model) or changed
changed = _rewrite_reduce_axes_inputs(model) or changed
changed = _cap_default_opset(model) or changed
if not changed:
return model_bytes
buffer = io.BytesIO()
onnx.save_model(model, buffer)
return buffer.getvalue()