Help developing a small shim to allow PyTorch models to be used in spaCy

I develop the spaCy NLP library. We have our own NN library, Thinc, to avoid dependencies (plus I started writing it before PyTorch was around :p), but for obvious reasons we’d like to let people use PyTorch models in spaCy as well.

The plan has been to write small shim classes that would wrap PyTorch (or other libraries’) models to have the same API as Thinc. You can find the wrapper class so far here: https://github.com/explosion/thinc/blob/master/thinc/extra/wrappers.py

Questions

  1. How do I resize an input layer? If neurons are added, the weights for the new neurons should be initialized to zero. If the new size is smaller, the trailing neurons should be truncated.

  2. How do I resize an output layer? Same as above: zero-initialize the weights for any new neurons, and truncate the trailing ones when shrinking.

  3. Thinc has a use_params() context-manager, which allows usage of weights passed in for the scope of a block. Is load_state_dict() the right thing there?

Current progress

The heart of the wrapper is Thinc’s begin_update() method. This takes a batch of inputs, and returns a tuple with a batch of outputs and a callback to complete the backward pass. This was pretty easy to do, but I wrote it a few months ago — hopefully it’s still current?


    def begin_update(self, x_data, drop=0.):
        '''Return the output of the wrapped PyTorch model for the given input,
        along with a callback to handle the backward pass.
        '''
        x_var = torch.autograd.Variable(torch.Tensor(x_data),
                                        requires_grad=True)
        # Make prediction
        y_var = self._model(x_var)
        def backward_pytorch(dy_data, sgd=None):
            dy_var = torch.autograd.Variable(torch.Tensor(dy_data))
            # Backprop the supplied gradient through the PyTorch graph
            torch.autograd.backward((y_var,), grad_variables=(dy_var,))
            dX = self.ops.asarray(x_var.grad.data)
            if sgd is not None:
                # Assumes the shim stores its optimizer as self._optimizer
                self._optimizer.step()
            return dX
        return self.ops.asarray(y_var.data), backward_pytorch

The main outstanding problem with the begin_update() wrapper above is that Thinc takes an argument drop, a float between 0 and 1, which is used to apply dropout to the outgoing activations. We shouldn’t need to worry about making this auto-differentiable: it should be fine to get a dropout mask and multiply the activations by it. Then we can multiply the incoming gradient by the same mask in the backward callback, since the mask will be stored in the enclosing scope.
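Roughly, I’m imagining something like this sketch (wrap_dropout is a hypothetical name; I’m assuming Thinc’s ops.get_dropout_mask(shape, drop) helper, which returns None when drop is 0):

    def wrap_dropout(self, y_data, backward, drop):
        # Mask the outgoing activations, and wrap the backward callback so
        # the incoming gradient gets multiplied by the same mask.
        mask = self.ops.get_dropout_mask(y_data.shape, drop)
        if mask is None:
            return y_data, backward
        def masked_backward(dy_data, sgd=None):
            return backward(dy_data * mask, sgd=sgd)
        return y_data * mask, masked_backward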

I’ve also drafted out the serialization methods. Thinc uses to_bytes()/from_bytes() and to_disk()/from_disk(). The architecture is not saved, just the parameters; we assume the architecture is reconstructed before you call from_bytes(). So this seems fairly straightforward.

Hi Matthew,

That’s really cool! I’m looking forward to this integration.

Questions

  1. What kind of layer do you want to enlarge/shrink? Is it only for linear modules, or do you also consider things like conv? In general you’ll probably need to fiddle with the weight tensor (just reassign the attr with the new value).

  2. As above

  3. I’m not sure what the exact purpose of this context manager is, so it’s hard for me to answer that. I’ll need more details.

Wrapper code

It depends on which version you’re targeting. This uses things that are outdated on master, but looks good as of 0.3.

About the dropout - yes, multiplying y_var and dy_data by the mask should be enough. Although I’d say it’s a bit of a surprising interface (what makes dropout so special?).

For serialization of modules you probably want to use state_dict() and load_state_dict(). Also torch.load and torch.save will be more efficient with tensors than pickle, but both ways work.
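For instance, something like this sketch should cover the bytes round-trip (assuming the wrapper keeps the module as self._model, as in your snippet):

    import io
    import torch

    def to_bytes(self):
        # Save only the parameters; the architecture is rebuilt separately.
        buf = io.BytesIO()
        torch.save(self._model.state_dict(), buf)
        return buf.getvalue()

    def from_bytes(self, data):
        self._model.load_state_dict(torch.load(io.BytesIO(data)))
        return self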


Hope this helps! Let us know if you have any other questions!

Thanks! Seems simple enough.

Mostly linear layers. The most important case is adding output neurons so we can add classes to a model, e.g. to add a new entity type.

I’ve just realized I can easily construct this using the to_bytes() and from_bytes() methods. So there’s no real problem here. For the record though, it looks like this:

    for epoch in epochs:
        train_epoch(model, optimizer, train_data)
        with model.use_params(optimizer.averages):
            print(model.evaluate(dev_data))

This is used for the “averaging trick”: you store an exponential moving average (EMA) of the parameters during training, and use that instead of the final weight values for inference. This helps avoid sharp minima, and in my experience it’s pretty much always good. The context manager is necessary because you want to evaluate each checkpoint with the weights you’ll actually use for inference, while resuming training from where you left off.
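On the PyTorch side I imagine it could look something like this sketch (assuming the averages dict is keyed compatibly with the module’s state_dict, which Thinc’s optimizer.averages isn’t without a renaming step):

    from contextlib import contextmanager

    @contextmanager
    def use_params(self, params):
        # Stash the current weights, swap in the averages, restore on exit.
        backup = {k: v.clone() for k, v in self._model.state_dict().items()}
        self._model.load_state_dict(params)
        try:
            yield
        finally:
            self._model.load_state_dict(backup)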

It does suck a bit. There are a few of these warts in Thinc at the moment.

The rationale was that I wanted the calling code to control the dropout. This makes it easy to decay or increase the dropout rate without walking over the model changing hyper-parameters. I also wanted layers to own how they perform dropout. I was also thinking about allowing variational inference, although I haven’t played with that.

Weight changes

For linear layers it should be simple. It’s something like:

    model.weight = nn.Parameter(F.pad(model.weight, ...))  # add classes
    model.weight = nn.Parameter(model.weight[...])         # remove classes

I’m not sure about the exact params, but they should be easy to figure out. IIRC the weight layout is out_features x in_features.
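To make that concrete, here’s a rough sketch using torch.cat, which does the same thing as the F.pad above (resize_output is just a name I’m making up, and the bias needs the same treatment as the weight):

    import torch
    import torch.nn as nn

    def resize_output(linear, new_out):
        old_out = linear.out_features
        if new_out > old_out:
            # weight is out_features x in_features: append zero rows,
            # so new classes start with zero weights.
            extra_w = torch.zeros(new_out - old_out, linear.in_features)
            linear.weight = nn.Parameter(torch.cat([linear.weight.data, extra_w], dim=0))
            extra_b = torch.zeros(new_out - old_out)
            linear.bias = nn.Parameter(torch.cat([linear.bias.data, extra_b]))
        else:
            # Truncate the trailing rows/entries.
            linear.weight = nn.Parameter(linear.weight.data[:new_out])
            linear.bias = nn.Parameter(linear.bias.data[:new_out])
        linear.out_features = new_out
        return linear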

use_params()

I see, that makes sense! The model. part was missing, so I was unsure how you were specifying where the params live :smile:

First test working :tada:. I actually had a significant bug in that begin_update() function: I wasn’t calling optim.zero_grad().
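For reference, the fix amounts to this, inside backward_pytorch (again assuming the optimizer lives on the shim as self._optimizer):

    if sgd is not None:
        self._optimizer.step()
        self._optimizer.zero_grad()  # this call was missing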