Extending Tensors to lie on a Riemannian manifold

Hi! I started implementing a recent paper that seemed exciting and useful for the community (Riemannian Adaptive Optimization Methods). While deciding on the API, I chose to subclass native Tensors as well as Parameters, which allowed a generic implementation of Riemannian optimizers. Since the development experience has been decent so far, I also thought it would be a good idea to ask the PyTorch developers whether I'm doing things the right way.

My question is: is this way of implementing such subclasses correct, and does it follow best practices? (here it is)
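For context, the core of the idea is roughly the following (a simplified sketch, not the exact code from the linked repo):

import torch


class ManifoldTensor(torch.Tensor):
    """A Tensor that also carries a reference to the manifold it lives on."""

    def __new__(cls, data, manifold=None, requires_grad=False):
        if not isinstance(data, torch.Tensor):
            data = torch.as_tensor(data)
        # the same trick torch.nn.Parameter uses to create a Tensor subclass
        instance = torch.Tensor._make_subclass(cls, data, requires_grad)
        instance.manifold = manifold
        return instance


class ManifoldParameter(ManifoldTensor, torch.nn.Parameter):
    """A Parameter constrained to a manifold; nn.Module registers it as usual."""

    def __new__(cls, data, manifold=None, requires_grad=True):
        return super().__new__(cls, data, manifold=manifold,
                               requires_grad=requires_grad)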

Some issues/benefits I ran into with this API (see the snippet after the list):

  • isinstance(r_parameter, torch.nn.Parameter) is True, which is a good thing: you can use such params in a module
  • r_tensor.clone() returns the plain Tensor type. I have not yet decided whether that is the expected behaviour or whether it should be changed to return the ManifoldTensor type.
  • r_tensor.data – same thing
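For illustration, this is roughly how those points show up (assuming the sketch above):

p = ManifoldParameter(torch.randn(3), manifold=None)
isinstance(p, torch.nn.Parameter)   # True, so nn.Module registers it
type(p.clone())                     # torch.Tensor here, not ManifoldParameter
type(p.data)                        # torch.Tensor as well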

I think the last two concerns are the only ones I have. Thank you!

This looks quite interesting. Thanks a lot for sharing!

As of right now, since there isn't a better way to add functionality to the Tensor class at runtime, I'd expect that this is probably fine.
I'd say that for clone and data you probably want to preserve your own Tensor class, instead of having it de-promoted back to the plain Tensor class.
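For clone, that could be as simple as re-wrapping the result in your subclass; roughly (just a sketch, untested):

class ManifoldTensor(torch.Tensor):
    # ... __new__ as in the sketch above ...

    def clone(self):
        # Tensor.clone() returns a plain Tensor, so re-wrap it to keep the subclass
        cloned = super().clone()
        return ManifoldTensor(cloned, manifold=self.manifold)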

We are working on adding runtime-registerable functions on the Tensor class, so that you can declare something like torch.my_fun(Tensor, Tensor) -> Tensor and have it dispatched to your own function. I'm not sure when that will land; hoping for 1.1.

I'd say that for clone and data you probably want to preserve your own Tensor class, instead of having it de-promoted back to the plain Tensor class.

I copied the behaviour of the torch.nn.Parameter class; it actually de-promotes to Tensor there as well (so that seems to be the expected behaviour).

Do you plan to add some kind of tensor conversion protocol? Say I have a complex data structure that is able to construct a Tensor representation. I would like to keep that complexity in a backend and provide users with just an API, so that they can seamlessly pass this data structure into torch functions as a regular tensor.

The developer-facing API for that should look like

def __to_tensor__(self, dtype=None, device=None, clone=False):
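For example, a data structure implementing that protocol might look like this (purely hypothetical; nothing in torch dispatches on __to_tensor__ today):

class CompressedVector:
    """Hypothetical backend structure: keeps data in a compact form and
    materializes a plain Tensor only when torch needs one."""

    def __init__(self, values):
        self._values = list(values)   # the "complex" backend representation
        self._cache = None            # lazily built Tensor

    def __to_tensor__(self, dtype=None, device=None, clone=False):
        if self._cache is None:
            self._cache = torch.tensor(self._values)
        t = self._cache.to(dtype=dtype or self._cache.dtype,
                           device=device or self._cache.device)
        return t.clone() if clone else t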

That’s not a bad idea, we haven’t thought about that tbh. If you open an issue, we’ll discuss the semantics / implementation / timeline.