Refactor RGBD Pixel Kernel#161
Conversation
b0cbc5e to
e677b82
Compare
horizon-blue
left a comment
There was a problem hiding this comment.
These changes look good to me. I left some inline comments for some potential improvements, but I'm also okay with just merging this PR as-is so we can move fast :). Since we have the commit history anyways, we can also revert and make changes in the future if needed.
| total_log_prob = 0.0 | ||
|
|
||
| is_depth_non_return = observed_rgbd[3] == 0.0 | ||
|
|
||
| # Is visible | ||
| total_visible_log_prob = 0.0 | ||
| # color term | ||
| total_visible_log_prob += self.inlier_color_distribution.logpdf( | ||
| observed_rgbd[:3], latent_rgbd[:3], color_scale | ||
| ) | ||
| # depth term | ||
| total_visible_log_prob += jnp.where( | ||
| is_depth_non_return, | ||
| jnp.log(depth_nonreturn_prob), | ||
| jnp.log(1 - depth_nonreturn_prob) | ||
| + self.inlier_depth_distribution.logpdf( | ||
| observed_rgbd[3], latent_rgbd[3], depth_scale | ||
| ), | ||
| ) | ||
|
|
||
| # Is not visible | ||
| total_not_visible_log_prob = 0.0 | ||
| # color term | ||
| outlier_color_log_prob = self.outlier_color_distribution.logpdf( | ||
| observed_rgbd[:3], latent_rgbd[:3], color_scale | ||
| ) | ||
| outlier_depth_log_prob = self.outlier_depth_distribution.logpdf( | ||
| observed_rgbd[3], latent_rgbd[3], depth_scale | ||
| ) | ||
|
|
||
| total_not_visible_log_prob += outlier_color_log_prob | ||
| # depth term | ||
| total_not_visible_log_prob += jnp.where( | ||
| is_depth_non_return, | ||
| jnp.log(depth_nonreturn_prob), | ||
| jnp.log(1 - depth_nonreturn_prob) + outlier_depth_log_prob, | ||
| ) | ||
|
|
||
| total_log_prob += jnp.logaddexp( | ||
| jnp.log(visibility_prob) + total_visible_log_prob, | ||
| jnp.log(1 - visibility_prob) + total_not_visible_log_prob, | ||
| ) | ||
| depth_logpdf = self.depth_kernel.logpdf( | ||
| observed_rgbd[3], | ||
| latent_rgbd[3], | ||
| depth_scale, | ||
| visibility_prob, | ||
| depth_nonreturn_prob, | ||
| return jnp.where( | ||
| jnp.any(is_unexplained(latent_rgbd)), | ||
| outlier_color_log_prob + outlier_depth_log_prob, | ||
| total_log_prob, |
There was a problem hiding this comment.
Niice. I think we might be able to implement some of these mixture terms using the mixture distribution class, so that we can get both logpdf and sample for free, but we can do that in a future PR.
(actually if you don't mind, I can also do some of the clean ups after you merge this)
| from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( | ||
| FullPixelDepthDistribution, | ||
| ) | ||
| # import b3d.chisight.gen3d.inference_moves as im |
There was a problem hiding this comment.
Do you mind marking these tests as skipping (and maybe also pytest.importorskip) instead of commenting them out for now? Just so that we have a reminder to come back and clean these up in the future
|
@nishadgothoskar I made the inference changes we discussed, and the I noticed that the log q scores for one of the proposals are causing it to break (and they had been before as well), so I will work on debugging this this evening. But this should hopefully unblock you to continue developing your tests! If it would be helpful to discuss testing strategy or anything else this evening, in light of this changed inference code, please let me know! |
No description provided.