
Pass losses as callables when building detection models #5325

Open
@Quintulius

Description


🚀 The feature

Some models already accept normalization strategies as callables (for example, mobilenet_backbone accepts a norm_layer argument), but loss functions are hardcoded (for example, F.cross_entropy in the Faster R-CNN roi_heads).

Following what has been done for normalization layers, loss functions could be passed as callables to the module constructors. If the current hardcoded losses are kept as defaults, this shouldn't break backward compatibility. Reduction strategies (e.g. mean vs. sum) still need to be handled properly.
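A minimal sketch of the proposed pattern, written in plain Python so it runs without torch. RoIHeadsSketch, classification_loss, cross_entropy and focal_loss are hypothetical names chosen for illustration, not torchvision API; cross_entropy only mimics the reduction keyword of F.cross_entropy. The default argument preserves the old behaviour, and the head (not the callable) chooses the reduction, which is one possible way to address the reduction question.

```python
import functools
import math
from typing import Callable, List, Sequence


def _reduce(losses: List[float], reduction: str):
    """Apply a torch-style reduction ("none", "sum" or "mean") to per-sample losses."""
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


def _log_softmax_at(row: Sequence[float], target: int) -> float:
    """Numerically stable log-probability of `target` under a softmax over `row`."""
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
    return row[target] - log_sum_exp


def cross_entropy(logits: Sequence[Sequence[float]], targets: Sequence[int], reduction: str = "mean"):
    """Pure-Python stand-in for F.cross_entropy (softmax + negative log-likelihood)."""
    losses = [-_log_softmax_at(row, t) for row, t in zip(logits, targets)]
    return _reduce(losses, reduction)


def focal_loss(logits: Sequence[Sequence[float]], targets: Sequence[int],
               reduction: str = "mean", gamma: float = 2.0):
    """Multi-class focal loss: -(1 - p_t)^gamma * log(p_t); gamma=0 gives cross-entropy."""
    losses = []
    for row, t in zip(logits, targets):
        log_p = _log_softmax_at(row, t)
        losses.append(-((1.0 - math.exp(log_p)) ** gamma) * log_p)
    return _reduce(losses, reduction)


class RoIHeadsSketch:
    """Hypothetical detection head; NOT torchvision's RoIHeads."""

    def __init__(self, classification_loss: Callable = cross_entropy, reduction: str = "mean"):
        # Defaulting to the previously hardcoded loss keeps existing call sites working.
        self.classification_loss = classification_loss
        self.reduction = reduction

    def compute_losses(self, class_logits, labels):
        # The head decides the reduction and passes it explicitly to the callable.
        loss = self.classification_loss(class_logits, labels, reduction=self.reduction)
        return {"loss_classifier": loss}


# Usage: swap the classification loss without patching any module.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
labels = [0, 2]
default_head = RoIHeadsSketch()
focal_head = RoIHeadsSketch(classification_loss=functools.partial(focal_loss, gamma=2.0))
print(default_head.compute_losses(logits, labels))
print(focal_head.compute_losses(logits, labels))
```

Keeping reduction as a keyword on the callable, as the torch.nn.functional losses do, also lets a head request per-element losses ("none") when it needs to weight or normalize them itself.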

Motivation, pitch

Currently, trying different loss functions requires monkey-patching the model code. Accepting the losses in the model constructors would provide a much cleaner way to experiment with the models.

If there is interest, I can open a first PR modifying the Faster R-CNN models.

Alternatives

No response

Additional context

No response

cc @datumbox
