
[Feature] Differentiable VMAS #80

Merged
matteobettini merged 7 commits into main from differentiable on Feb 6, 2024

Conversation

@matteobettini (Member) commented Feb 5, 2024

VMAS is now fully differentiable, and you can backpropagate through any of its scenarios.

To enable this, set grad_enabled=True at environment construction.

You can then do stuff like:

for step in steps:
    actions = []
    for agent in agents:
        action = ....
        action.requires_grad_(True)
        if step == 0:
            first_action = action
        actions.append(action)
    obs, rews, dones, info = env.step(actions)

loss = obs[-1].mean() + rews[-1].mean()
(grad,) = torch.autograd.grad(loss, first_action)  # grad() returns a tuple

This backpropagates a loss computed from observations and rewards through time, back to the input action at the first timestep.
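The snippet above elides the action computation and the VMAS environment itself, so here is a runnable sketch of the same backprop-through-time pattern using a hypothetical toy stand-in: toy_step, the goal at 1.0, and the constant 0.1 actions are all assumptions for illustration, not the VMAS API.

```python
import torch

# Hypothetical stand-in for one differentiable environment step (NOT the
# real VMAS API): the shared state advances by the sum of agent actions,
# the observation is the state itself, and the reward penalises squared
# distance to a goal at 1.0.
def toy_step(state, actions):
    new_state = state + torch.stack(actions).sum(0)
    obs = new_state
    rew = -(new_state - 1.0) ** 2
    return new_state, obs, rew

n_steps, n_agents = 3, 2
state = torch.zeros(1)
first_action = None
for step in range(n_steps):
    actions = []
    for _ in range(n_agents):
        action = torch.full((1,), 0.1)  # placeholder for a policy output
        action.requires_grad_(True)
        if first_action is None:       # remember the very first action
            first_action = action
        actions.append(action)
    state, obs, rew = toy_step(state, actions)

# Differentiate a loss on the final observation and reward with respect
# to the first action, through all three environment steps.
loss = obs.mean() + rew.mean()
(grad,) = torch.autograd.grad(loss, first_action)
print(grad)  # approximately tensor([1.8000])
```

With six actions of 0.1 the final state is 0.6, so the gradient is d(state)/d(action) = 1 flowing into d(loss)/d(state) = 1 - 2*(0.6 - 1) = 1.8, confirming the chain runs back through every step.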

@matteobettini matteobettini marked this pull request as ready for review February 6, 2024 12:26
@matteobettini matteobettini merged commit 7c453bf into main Feb 6, 2024
@matteobettini matteobettini deleted the differentiable branch February 6, 2024 16:00