Metadata-Version: 2.1
Name: onnx-tool
Version: 0.1.2
Summary: A tool for ONNX model's shape inference and MACs counting.
Home-page: https://github.com/ThanatosShinji/onnx-tool
Author: Luo Yu
Author-email: luoyu888888@gmail.com
License: MIT
Classifier: Programming Language :: Python :: 3
Description-Content-Type: text/markdown
License-File: LICENSE

# onnx-tool
A tool for shape inference and MACs counting on ONNX models.

* Shape inference
<p align="center">
  <img src="data/shape_inference.jpg">
</p>

---
* MACs counting for each node
<p align="center">
  <img src="data/macs_counting.jpg">
</p>
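
As a rough illustration of what a per-node MACs count means, the sketch below applies the conventional formula for a 2-D convolution: every output element costs `(in_channels / groups) * kernel_h * kernel_w` multiply-accumulates. The helper names `conv_output_hw` and `conv_macs` are hypothetical and are not part of onnx-tool; the library's own accounting may differ in details such as bias handling.

```python
from math import prod


def conv_output_hw(in_hw, kernel_hw, stride=(1, 1), pad=(0, 0), dilation=(1, 1)):
    """Spatial output size of a 2-D convolution (floor mode, symmetric padding)."""
    return tuple(
        (i + 2 * p - d * (k - 1) - 1) // s + 1
        for i, k, s, p, d in zip(in_hw, kernel_hw, stride, pad, dilation)
    )


def conv_macs(in_shape, out_channels, kernel_hw, stride=(1, 1), pad=(0, 0), groups=1):
    """Return (output shape, MACs) for an NCHW convolution, ignoring the bias add."""
    n, c_in, h, w = in_shape
    out_h, out_w = conv_output_hw((h, w), kernel_hw, stride, pad)
    out_shape = (n, out_channels, out_h, out_w)
    # Each output element needs one MAC per input channel in its group per kernel tap.
    macs_per_output = (c_in // groups) * prod(kernel_hw)
    return out_shape, prod(out_shape) * macs_per_output


# A 7x7 convolution with stride 2 and padding 3 on a 224x224 RGB image.
out_shape, macs = conv_macs((1, 3, 224, 224), 64, (7, 7), stride=(2, 2), pad=(3, 3))
print(out_shape, macs)
```

Summing such per-node figures over the whole graph gives a total in the same spirit as the `MACs(M)` column reported in the results table.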

## How to install
    
`pip install onnx-tool`

OR

`pip install --upgrade git+https://github.com/ThanatosShinji/onnx-tool.git`
    
## How to use 
* Basic usage 
    ```python
    import onnx
    from onnx_tool.node_profilers import graph_profile,print_node_map
    model = onnx.load('resnet50.onnx')
    macs, params, node_map = graph_profile(model.graph, None) #shape inference included
    print_node_map(node_map)
    onnx.save_model(model,'resnet50_shapes.onnx') #save model with inferred shapes
    ```    

* Dynamic input shapes and dynamic resize scales (`downsample_ratio`)
    ```python
    import numpy
    import onnx
    from onnx_tool.node_profilers import graph_profile,print_node_map,create_ndarray_f32
    model = onnx.load('rvm_mobilenetv3_fp32.onnx')
    #give every dynamic input a concrete tensor; 'downsample_ratio' carries the resize scale
    inputs = {'src': create_ndarray_f32((1, 3, 1080, 1920)),
              'r1i': create_ndarray_f32((1, 16, 135, 240)),
              'r2i': create_ndarray_f32((1, 20, 68, 120)),
              'r3i': create_ndarray_f32((1, 40, 34, 60)),
              'r4i': create_ndarray_f32((1, 64, 17, 30)),
              'downsample_ratio': numpy.array((0.25,), dtype=numpy.float32)}
    macs, params, node_map = graph_profile(model.graph, inputs) #shape inference included
    print_node_map(node_map,'rvm_nodemap.txt') #save node map to file
    onnx.save_model(model,'rvm_mobilenetv3_fp32_shapes.onnx')
    ```    

* Define your custom op's node profiler.
    ```python
    from typing import List
    import numpy
    from onnx_tool.node_profilers import graph_profile,NODEPROFILER_REGISTRY
  
    @NODEPROFILER_REGISTRY.register()
    class YourOp():
        def __init__(self,nodeproto):
            #parse your attributes here
            pass

        def infer_shape(self,intensors:List[numpy.ndarray]):
            #calculate output shapes here
            outtensors = [] #placeholder: fill with one ndarray per output
            #return a list of ndarray
            return outtensors

        def profile(self,intensors:List[numpy.ndarray],outtensors:List[numpy.ndarray]):
            #do macs and params accumulations here
            macs, params = 0, 0 #placeholder values
            return macs,params
  
    macs, params, node_map = graph_profile(yourmodel.graph, None)
    ```
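
To make the template above more concrete, here is a minimal sketch of the arithmetic that `infer_shape` and `profile` might carry out for a fully connected custom op. It is an illustration only: the class `FullyConnectedSketch` is hypothetical, it is not registered with `NODEPROFILER_REGISTRY`, and it works on plain shape tuples instead of `numpy.ndarray` objects so that it runs without any dependencies. In a real profiler the attributes would be parsed from the node proto and the shapes read from the input tensors.

```python
from math import prod


class FullyConnectedSketch:
    """Hypothetical op computing y = x @ W.T + b, with W of shape (out, in)."""

    def __init__(self, in_features, out_features):
        # A real node profiler would parse these from the node's attributes.
        self.in_features = in_features
        self.out_features = out_features

    def infer_shape(self, in_shapes):
        # The first input is the activation with shape (batch, in_features).
        batch, in_features = in_shapes[0]
        if in_features != self.in_features:
            raise ValueError("input width does not match the weight matrix")
        return [(batch, self.out_features)]

    def profile(self, in_shapes, out_shapes):
        # One MAC per weight column for every output element.
        macs = prod(out_shapes[0]) * self.in_features
        # Weight matrix plus one bias value per output feature.
        params = self.out_features * self.in_features + self.out_features
        return macs, params


op = FullyConnectedSketch(in_features=2048, out_features=1000)
out_shapes = op.infer_shape([(1, 2048)])
macs, params = op.profile([(1, 2048)], out_shapes)
print(out_shapes, macs, params)
```

The two methods mirror the split in the template: shapes are resolved first, and the cost is then derived from the resolved input and output shapes.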

## Known Issues
* The Loop op is not supported
* Shared weight tensors are counted more than once
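
The shared-weight issue can be illustrated with a small sketch. It assumes a simplified view of a graph in which each node lists the names of the weight tensors it reads; the data and the helpers `count_params_per_node` and `count_params_unique` are hypothetical and do not reflect onnx-tool's internals. Counting per node adds a tied tensor once for every node that uses it, whereas tracking the names already seen counts it once.

```python
from math import prod

# Hypothetical weight tensors (name -> shape) and the nodes that consume them.
weights = {"embedding": (1000, 64), "fc1": (64, 64)}
node_weights = [
    ("Gather_0", ["embedding"]),
    ("MatMul_1", ["fc1"]),
    ("MatMul_2", ["embedding"]),  # tied output projection reuses the embedding
]


def count_params_per_node(node_weights, weights):
    """Sum parameters node by node; shared tensors are counted repeatedly."""
    return sum(prod(weights[name]) for _, names in node_weights for name in names)


def count_params_unique(node_weights, weights):
    """Sum parameters while counting each named tensor only once."""
    seen = set()
    total = 0
    for _, names in node_weights:
        for name in names:
            if name not in seen:
                seen.add(name)
                total += prod(weights[name])
    return total


per_node_total = count_params_per_node(node_weights, weights)
unique_total = count_params_unique(node_weights, weights)
print(per_node_total, unique_total)
```

For models with tied weights, the difference between the two totals indicates by how much a per-node parameter count can overstate the real size.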

## Results of [ONNX Model Zoo](https://github.com/onnx/models) and SOTA models
Some models have dynamic input shapes, so their MACs vary with the input shapes. The input shapes used for these results are written to data/public/config.py.
<p align="center">
<table>
<tr>

<td>

Model | Params(M) | MACs(M)
---|---|---
[MobileNet v2-1.0-fp32](https://github.com/onnx/models/blob/main/vision/classification/mobilenet) | 3.3 | 300
[ResNet50_fp32](https://github.com/onnx/models/tree/main/vision/classification/resnet) | 25 | 3868
[SqueezeNet 1.0](https://github.com/onnx/models/tree/main/vision/classification/squeezenet) | 1.23 | 351
[VGG 19](https://github.com/onnx/models/tree/main/vision/classification/vgg) | 143.66 | 19643
[AlexNet](https://github.com/onnx/models/tree/main/vision/classification/alexnet) | 60.96 | 665
[GoogleNet](https://github.com/onnx/models/tree/main/vision/classification/inception_and_googlenet/googlenet) | 6.99 | 1606
[googlenet_age_adience](https://github.com/onnx/models/tree/main/vision/body_analysis/age_gender) | 5.98 | 1605
[LResNet100E-IR](https://github.com/onnx/models/tree/main/vision/body_analysis/arcface) | 65.22 | 12102
[BERT-Squad](https://github.com/onnx/models/tree/main/text/machine_comprehension/bert-squad) | 113.61 | 22767
[BiDAF](https://github.com/onnx/models/tree/main/text/machine_comprehension/bidirectional_attention_flow) | 18.08 | 9.87
[EfficientNet-Lite4](https://github.com/onnx/models/tree/main/vision/classification/efficientnet-lite4) | 12.96 | 1361
[Emotion FERPlus](https://github.com/onnx/models/tree/main/vision/body_analysis/emotion_ferplus) | 12.95 | 877
[Mask R-CNN R-50-FPN-fp32](https://github.com/onnx/models/tree/main/vision/object_detection_segmentation/mask-rcnn) | 46.77 | 92077
</td>

<td>

Model | Params(M) | MACs(M)
---|---|---
[rvm_mobilenetv3_fp32.onnx](https://github.com/PeterL1n/RobustVideoMatting) | 3.73 | 4289
[yolov4](https://github.com/onnx/models/tree/main/vision/object_detection_segmentation/yolov4) | 64.33 | 33019
[ConvNeXt-L](https://github.com/facebookresearch/ConvNeXt) | 229.79 | 34872
[edgenext_small](https://github.com/mmaaz60/EdgeNeXt) | 5.58 | 1357
[SSD](https://github.com/onnx/models/tree/main/vision/object_detection_segmentation/ssd) | 19.98 | 216598
[RealESRGAN_x4plus.pth](https://github.com/xinntao/Real-ESRGAN) | 16.69 | 73551
[ShuffleNet-v2-fp32](https://github.com/onnx/models/tree/main/vision/classification/shufflenet) | 2.29 | 146
[GPT-2](https://github.com/onnx/models/tree/main/text/machine_comprehension/gpt-2) | 137.02 | 1103
[T5-encoder](https://github.com/onnx/models/tree/main/text/machine_comprehension/t5) | 109.62 | 686
[T5-decoder-with-lm-head](https://github.com/onnx/models/tree/main/text/machine_comprehension/t5) | 162.62 | 1113
[RoBERTa-BASE](https://github.com/onnx/models/tree/main/text/machine_comprehension/roberta) | 124.64 | 688
[Faster R-CNN R-50-FPN-fp32](https://github.com/onnx/models/blob/main/vision/object_detection_segmentation/faster-rcnn) | 44.10 | 46018
[FCN ResNet-50](https://github.com/onnx/models/tree/main/vision/object_detection_segmentation/fcn) | 35.29 | 37056


</td>
</tr>
</table>
</p>
