Skip to content

Make UNet2DConditionOutput pickle-able#3857

Merged
sayakpaul merged 6 commits into
huggingface:mainfrom
prathikr:prathikrao/unet-output-bug-fix
Jul 6, 2023
Merged

Make UNet2DConditionOutput pickle-able#3857
sayakpaul merged 6 commits into
huggingface:mainfrom
prathikr:prathikrao/unet-output-bug-fix

Conversation

@prathikr

@prathikr prathikr commented Jun 22, 2023

Copy link
Copy Markdown
Contributor

This PR addresses previous concerns that the output of the UNet's forward pass is not copy-able. The root cause appears to be because copy fails on collections.OrderedDict dataclass with required args. The solution presented sets a default value for sample such that is it no longer a required parameter of the output class while still erroring when missing since the default setting is None (link to similar solution for different model).

Reproduction Instructions:

from diffusers.utils import BaseOutput
from dataclasses import dataclass
import copy

@dataclass
class NetParams(BaseOutput):
    sample: torch.FloatTensor

m = NetParams(sample=torch.randn(1, 10))
n = copy.copy(m)

@prathikr prathikr changed the title add default to unet output to prevent it from being a required arg [WIP] add default to unet output to prevent it from being a required arg Jun 22, 2023
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 22, 2023

Copy link
Copy Markdown

The documentation is not available anymore as the PR was closed or merged.

@prathikr prathikr changed the title [WIP] add default to unet output to prevent it from being a required arg [WIP] Make UNet2DConditionOutput pickle-able Jun 22, 2023
@prathikr prathikr changed the title [WIP] Make UNet2DConditionOutput pickle-able Make UNet2DConditionOutput pickle-able Jun 22, 2023
@prathikr prathikr marked this pull request as ready for review June 22, 2023 23:15
@prathikr

Copy link
Copy Markdown
Contributor Author

@patrickvonplaten @anton-l can I please get a review on this?

@prathikr

Copy link
Copy Markdown
Contributor Author

@patrickvonplaten @anton-l any updates?

@patrickvonplaten

Copy link
Copy Markdown
Contributor

This change is ok for me! Could we add a test here that shows how we can now pickle the output?

@prathikr

prathikr commented Jun 28, 2023

Copy link
Copy Markdown
Contributor Author

@patrickvonplaten I gave adding a unit test a try. Let me know if I should change it or put it somewhere else.

@prathikr

Copy link
Copy Markdown
Contributor Author

@patrickvonplaten any updates?

@prathikr

prathikr commented Jul 3, 2023

Copy link
Copy Markdown
Contributor Author

@patrickvonplaten this is currently blocking ONNX Runtime integration with Diffusers. Can you please provide an update? Thank you.

@patrickvonplaten

Copy link
Copy Markdown
Contributor

Ok for me!

@patrickvonplaten

Copy link
Copy Markdown
Contributor

@sayakpaul @pcuenca could you maybe also quickly check?

Comment thread tests/models/test_models_unet_2d_condition.py Outdated

@sayakpaul sayakpaul left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change! PR looks great to me except for https://github.com/huggingface/diffusers/pull/3857/files#r1252260499.

@prathikr prathikr requested a review from sayakpaul July 5, 2023 18:08
@prathikr

prathikr commented Jul 6, 2023

Copy link
Copy Markdown
Contributor Author

@sayakpaul can you please review/merge? Thank you.

@sayakpaul sayakpaul left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating!

@sayakpaul sayakpaul merged commit de14261 into huggingface:main Jul 6, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add default to unet output to prevent it from being a required arg

* add unit test

* make style

* adjust unit test

* mark as fast test

* adjust assert statement in test

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add default to unet output to prevent it from being a required arg

* add unit test

* make style

* adjust unit test

* mark as fast test

* adjust assert statement in test

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants