mirror of
https://github.com/azaion/detections.git
synced 2026-06-23 09:51:08 +00:00
This commit is contained in:
@@ -0,0 +1,111 @@
|
||||
import ast
|
||||
import io
|
||||
|
||||
import onnx
|
||||
from onnx import helper, numpy_helper
|
||||
|
||||
|
||||
_REDUCE_OPS_WITH_AXES_INPUT = {
|
||||
"ReduceL1",
|
||||
"ReduceL2",
|
||||
"ReduceLogSum",
|
||||
"ReduceLogSumExp",
|
||||
"ReduceMax",
|
||||
"ReduceMean",
|
||||
"ReduceMin",
|
||||
"ReduceProd",
|
||||
"ReduceSum",
|
||||
"ReduceSumSquare",
|
||||
}
|
||||
|
||||
|
||||
def _metadata(model):
|
||||
return {p.key: p.value for p in model.metadata_props}
|
||||
|
||||
|
||||
def _input_size(model):
|
||||
try:
|
||||
imgsz = _metadata(model).get("imgsz")
|
||||
parsed = ast.literal_eval(imgsz)
|
||||
if isinstance(parsed, (list, tuple)) and len(parsed) == 2:
|
||||
h, w = int(parsed[0]), int(parsed[1])
|
||||
if h > 0 and w > 0:
|
||||
return h, w
|
||||
except Exception:
|
||||
pass
|
||||
return 1280, 1280
|
||||
|
||||
|
||||
def _constant_values(graph):
|
||||
values = {init.name: numpy_helper.to_array(init) for init in graph.initializer}
|
||||
for node in graph.node:
|
||||
if node.op_type != "Constant" or not node.output:
|
||||
continue
|
||||
for attr in node.attribute:
|
||||
if attr.name == "value":
|
||||
values[node.output[0]] = numpy_helper.to_array(attr.t)
|
||||
break
|
||||
return values
|
||||
|
||||
|
||||
def _as_int_list(value):
|
||||
if value is None:
|
||||
return None
|
||||
if getattr(value, "shape", ()) == ():
|
||||
return [int(value)]
|
||||
return [int(v) for v in value.reshape(-1).tolist()]
|
||||
|
||||
|
||||
def _set_static_input_shape(model, batch=1):
|
||||
h, w = _input_size(model)
|
||||
for graph_input in model.graph.input:
|
||||
tensor_type = graph_input.type.tensor_type
|
||||
if tensor_type.elem_type != onnx.TensorProto.FLOAT:
|
||||
continue
|
||||
dims = tensor_type.shape.dim
|
||||
if len(dims) != 4:
|
||||
continue
|
||||
for dim, value in zip(dims, (batch, 3, h, w)):
|
||||
dim.dim_value = value
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _rewrite_reduce_axes_inputs(model):
|
||||
constants = _constant_values(model.graph)
|
||||
changed = False
|
||||
for node in model.graph.node:
|
||||
if node.op_type not in _REDUCE_OPS_WITH_AXES_INPUT or len(node.input) < 2:
|
||||
continue
|
||||
axes = _as_int_list(constants.get(node.input[1]))
|
||||
if axes is None:
|
||||
continue
|
||||
kept_attrs = [attr for attr in node.attribute if attr.name != "axes"]
|
||||
del node.attribute[:]
|
||||
node.attribute.extend(kept_attrs)
|
||||
node.attribute.extend([helper.make_attribute("axes", axes)])
|
||||
del node.input[1:]
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
|
||||
def _cap_default_opset(model, max_opset=17):
|
||||
for opset in model.opset_import:
|
||||
if opset.domain in ("", "ai.onnx") and opset.version > max_opset:
|
||||
opset.version = max_opset
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def prepare_for_tensorrt(model_bytes):
|
||||
model = onnx.load_model_from_string(model_bytes)
|
||||
changed = False
|
||||
changed = _set_static_input_shape(model) or changed
|
||||
changed = _rewrite_reduce_axes_inputs(model) or changed
|
||||
changed = _cap_default_opset(model) or changed
|
||||
if not changed:
|
||||
return model_bytes
|
||||
|
||||
buffer = io.BytesIO()
|
||||
onnx.save_model(model, buffer)
|
||||
return buffer.getvalue()
|
||||
Reference in New Issue
Block a user