import ast
import io

import onnx
from onnx import helper, numpy_helper

# Reduce ops whose `axes` may arrive as a second input and that this module
# folds back into an `axes` attribute when the target opset requires it.
_REDUCE_OPS_WITH_AXES_INPUT = {
    "ReduceL1",
    "ReduceL2",
    "ReduceLogSum",
    "ReduceLogSumExp",
    "ReduceMax",
    "ReduceMean",
    "ReduceMin",
    "ReduceProd",
    "ReduceSum",
    "ReduceSumSquare",
}

# First default-domain opset in which an op takes `axes` as an input instead
# of an attribute: ReduceSum switched in opset 13, the other reduce ops in 18.
# Below these versions only the attribute form is valid.
_AXES_INPUT_SINCE_OPSET = {"ReduceSum": 13}
_AXES_INPUT_SINCE_OPSET_DEFAULT = 18

# Domain names that denote the default ONNX operator set.
_DEFAULT_DOMAINS = ("", "ai.onnx")


def _metadata(model):
    """Return the model's ``metadata_props`` as a plain ``{key: value}`` dict."""
    return {prop.key: prop.value for prop in model.metadata_props}


def _input_size(model, default=(1280, 1280)):
    """Return ``(height, width)`` parsed from the ``imgsz`` metadata entry.

    ``imgsz`` is parsed with ``ast.literal_eval``. Accepted forms are a
    two-element list/tuple or a single int (used for both sides). Missing,
    malformed, or non-positive values fall back to *default*.
    """
    raw = _metadata(model).get("imgsz")
    if raw is None:
        return default
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return default

    # A bare int means a square input; bool is an int subclass, so exclude it.
    if isinstance(parsed, int) and not isinstance(parsed, bool):
        parsed = (parsed, parsed)
    if not isinstance(parsed, (list, tuple)) or len(parsed) != 2:
        return default
    try:
        h, w = int(parsed[0]), int(parsed[1])
    except (TypeError, ValueError, OverflowError):
        return default
    if h > 0 and w > 0:
        return h, w
    return default


def _constant_values(graph, names=None):
    """Map tensor names to constant values from initializers and Constant nodes.

    Args:
        graph: an ONNX ``GraphProto``.
        names: optional collection of tensor names. When given, only those
            entries are decoded, so large weight initializers that are not
            needed are never converted.

    Returns:
        ``{name: value}``. Values are numpy arrays for initializers and
        Constant ``value`` tensors, and an int / list of ints for Constant
        ``value_int`` / ``value_ints``.
    """
    wanted = None if names is None else set(names)

    def _is_wanted(name):
        return wanted is None or name in wanted

    values = {
        init.name: numpy_helper.to_array(init)
        for init in graph.initializer
        if _is_wanted(init.name)
    }
    for node in graph.node:
        if node.op_type != "Constant" or not node.output:
            continue
        if not _is_wanted(node.output[0]):
            continue
        for attr in node.attribute:
            if attr.name == "value":
                values[node.output[0]] = numpy_helper.to_array(attr.t)
                break
            if attr.name in ("value_int", "value_ints"):
                values[node.output[0]] = helper.get_attribute_value(attr)
                break
    return values


def _as_int_list(value):
    """Flatten an integer constant into a list of Python ints.

    Accepts a numpy array (any rank, including 0-d), a list/tuple, or a
    scalar. Returns None when *value* is None.
    """
    if value is None:
        return None
    if hasattr(value, "reshape"):
        return [int(v) for v in value.reshape(-1).tolist()]
    if isinstance(value, (list, tuple)):
        return [int(v) for v in value]
    return [int(value)]


def _attribute_value(node, name, default=None):
    """Return the decoded value of attribute *name* on *node*, or *default*."""
    for attr in node.attribute:
        if attr.name == name:
            return helper.get_attribute_value(attr)
    return default


def _default_opset_version(model):
    """Return the highest default-domain opset imported by *model*, or None."""
    versions = [
        opset.version
        for opset in model.opset_import
        if opset.domain in _DEFAULT_DOMAINS
    ]
    return max(versions) if versions else None


def _set_static_input_shape(model, batch=1):
    """Pin the dynamic dims of the first 4-D float graph input.

    The target shape is ``(batch, 3, h, w)``, with ``h, w`` taken from
    :func:`_input_size`. Dims that already carry a positive static value are
    left untouched, so an exported fixed shape is never overwritten.

    Returns:
        True if any dim was modified, otherwise False.
    """
    h, w = _input_size(model)
    # Older IR versions also list initializers (e.g. 4-D conv weights) in
    # graph.input; those are not the image input.
    initializer_names = {init.name for init in model.graph.initializer}
    for graph_input in model.graph.input:
        if graph_input.name in initializer_names:
            continue
        tensor_type = graph_input.type.tensor_type
        if tensor_type.elem_type != onnx.TensorProto.FLOAT:
            continue
        dims = tensor_type.shape.dim
        if len(dims) != 4:
            continue

        changed = False
        for dim, value in zip(dims, (batch, 3, h, w)):
            if dim.dim_value > 0:
                continue
            # dim_value and dim_param share a oneof, so this also clears any
            # symbolic dimension name.
            dim.dim_value = value
            changed = True
        return changed
    return False


def _rewrite_reduce_axes_inputs(model, target_opset=17):
    """Fold constant ``axes`` inputs of reduce ops into an ``axes`` attribute.

    A node is touched only when its op requires the attribute form at
    *target_opset*: ReduceSum below opset 13, the other reduce ops below 18.
    A ReduceSum with an ``axes`` input, valid at opsets 13-17, is therefore
    kept. ``noop_with_empty_axes`` is removed because the attribute-form
    schemas do not define it.

    Nodes are left unchanged when:
      * their axes are not a graph constant, or
      * their axes are empty and ``noop_with_empty_axes`` is set.

    Only the top-level graph is scanned; subgraphs (e.g. If/Loop bodies) are
    not visited.

    Returns:
        True if any node was rewritten.
    """
    candidates = [
        node
        for node in model.graph.node
        if node.op_type in _REDUCE_OPS_WITH_AXES_INPUT
        and len(node.input) >= 2
        and target_opset
        < _AXES_INPUT_SINCE_OPSET.get(node.op_type, _AXES_INPUT_SINCE_OPSET_DEFAULT)
    ]
    if not candidates:
        return False

    constants = _constant_values(
        model.graph, names={node.input[1] for node in candidates if node.input[1]}
    )
    changed = False
    for node in candidates:
        axes_name = node.input[1]
        if axes_name:
            axes = _as_int_list(constants.get(axes_name))
            if axes is None:
                # Axes are computed at runtime and cannot be folded statically.
                continue
        else:
            # An empty name means the optional axes input was omitted.
            axes = []

        if not axes and _attribute_value(node, "noop_with_empty_axes", 0):
            # Empty axes with noop set means identity, which the attribute
            # form cannot express.
            continue

        kept_attrs = [
            attr
            for attr in node.attribute
            if attr.name not in ("axes", "noop_with_empty_axes")
        ]
        del node.attribute[:]
        node.attribute.extend(kept_attrs)
        # Omitting the attribute means "reduce over all dims", which matches
        # empty axes; never emit an empty `axes` attribute.
        if axes:
            node.attribute.extend([helper.make_attribute("axes", axes)])
        del node.input[1:]
        changed = True
    return changed


def _cap_default_opset(model, max_opset=17):
    """Lower every default-domain opset import above *max_opset* to *max_opset*.

    NOTE(review): only the version number is lowered. Apart from the
    reduce-op axes handled by _rewrite_reduce_axes_inputs, ops whose schema
    changed after *max_opset* are not downgraded. Confirm that the exported
    models do not rely on such ops.

    Returns:
        True if any opset import was modified.
    """
    changed = False
    for opset in model.opset_import:
        if opset.domain in _DEFAULT_DOMAINS and opset.version > max_opset:
            opset.version = max_opset
            changed = True
    return changed


def prepare_for_tensorrt(model_bytes, max_opset=17):
    """Return serialized ONNX bytes adjusted for a TensorRT build.

    The steps are:
      1. Pin the dynamic dims of the image input.
      2. Fold constant reduce-op ``axes`` inputs into attributes where the
         target opset requires that form.
      3. Cap the default-domain opset at *max_opset*.

    Args:
        model_bytes: a serialized ``ModelProto``.
        max_opset: highest default-domain opset to keep (default 17).

    Returns:
        The original *model_bytes* object when nothing changed; otherwise the
        re-serialized model.
    """
    model = onnx.load_model_from_string(model_bytes)
    # The reduce rewrite must target the opset the model will have after the
    # cap, so it is computed before any mutation.
    current_opset = _default_opset_version(model)
    if current_opset is None:
        target_opset = max_opset
    else:
        target_opset = min(current_opset, max_opset)

    changed = _set_static_input_shape(model)
    changed = _rewrite_reduce_axes_inputs(model, target_opset) or changed
    changed = _cap_default_opset(model, max_opset) or changed
    if not changed:
        return model_bytes

    buffer = io.BytesIO()
    onnx.save_model(model, buffer)
    return buffer.getvalue()