Conversation
See #3753 for why Zero3 won't be supported in this implementation.
EricMichaelSmith left a comment
Seems reasonable - minor comments
parlai/core/params.py
Outdated
)
grp.add_argument(
    '--ddp-backend',
    choices=['ddp', 'zero2', 'zero3'],
Hmm should we even give 'zero3' as an option for the time being? (Don't really care either way)
parlai/utils/fsdp.py
Outdated
def should_sync_gradnorm(opt):
    """
    Indicates whether fp16 optimizer wrappers should cumulate over workers.
Nit: "accumulate"?
parlai/core/torch_agent.py
Outdated
    For models or optimizers that shard parameters, this ensures we sync.
    """
    if self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3'):
Nit: should we pull in DEFAULT_DDP_BACKEND here?
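For illustration, a minimal sketch of that suggestion, assuming a DEFAULT_DDP_BACKEND constant as named in this thread (the helper name here is hypothetical, not the PR's code):

DEFAULT_DDP_BACKEND = 'ddp'


def uses_sharded_backend(opt):
    # Hypothetical helper: True when the configured backend shards gradients/optimizer state.
    return opt.get('ddp_backend', DEFAULT_DDP_BACKEND) in ('zero2', 'zero3')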
parlai/core/torch_generator_agent.py
Outdated
if (
    shared is None
    and is_distributed()
    and opt.get('ddp_backend', 'ddp') == 'ddp'
(same here about maybe using DEFAULT_DDP_BACKEND instead)
klshuster left a comment
really really cool. lots of nits though (and a few real questions 😄 )
if hasattr(self, 'model'):  # save model params
    if hasattr(self.model, 'module'):
    # did we wrap in a DistributedDataParallel
    if hasattr(self.model, 'module') and not is_fsdp(self.model):
nit: could make this a helper function too? like should_sync_gradnorm (not necessary of course)
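A sketch of such a helper, for illustration only (the names are hypothetical, and is_fsdp below is a stand-in for the PR's real utility):

def is_fsdp(module):
    # Stand-in for the PR's utility: detect a FullyShardedDataParallel wrapper.
    return module.__class__.__name__ == 'FullyShardedDataParallel'


def should_unwrap_model(model):
    # Only reach into .module for a vanilla DistributedDataParallel wrapper;
    # an FSDP-wrapped model handles its own state_dict gathering.
    return hasattr(model, 'module') and not is_fsdp(model)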
parlai/core/torch_generator_agent.py
Outdated
self.model = self.build_model()
with fsdp_utils.maybe_fsdp_wrap(opt):
    self.model = fsdp_utils.fsdp_wrap(self.build_model())
if self.fp16 and not fsdp_utils.should_use_fsdp(opt):
remember that bug with the instability stuff? is this not re-introducing it?
(because we moved the model.half() call?)
Okay I think this needs to use my utility should_delay_halving. Forgot this.
We haven't really moved the moment of halving. The operations between these two points don't do much, and the original code path should be about the same.
- We now halve on CPU instead of GPU, and then transfer. That's probably a small speedup in initialization, with maybe some small numerical differences.
- We apply model parallelism after halving. Probably a small speedup at initialization.
- We synchronize parameters after halving. Again, a small initialization speedup.
The catch is that FSDP expects the model pre-halved if we're doing safe optimization, and post-halved if we're doing memory-efficient. (Similar to the optimizer wrappers, it looks at the parameter types to decide what type the gradients will be.)
This is the desired pattern (a sketch of the corresponding check follows this list):
- If we're in Safe and using DDP, we SHOULD still halve, just as before
- If we're in MemEff and using DDP, we SHOULD still halve, just as before
- If we're in Safe and Zero2, we should NOT halve here
- If we're in MemEff and Zero2, we SHOULD halve here.
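A minimal sketch of how should_delay_halving could encode that pattern, matching the return statement quoted from parlai/utils/fsdp.py further down (assuming fp16_impl takes the values 'safe' and 'mem_efficient'):

def should_delay_halving(opt):
    # Delay the explicit model.half() only for Safe fp16 on a sharded backend,
    # where FSDP manages mixed precision itself; in every other case halve as before.
    return (
        opt.get('fp16', False)
        and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
        and opt.get('fp16_impl') == 'safe'
    )


# The four cases listed above:
assert not should_delay_halving({'fp16': True, 'fp16_impl': 'safe', 'ddp_backend': 'ddp'})
assert not should_delay_halving({'fp16': True, 'fp16_impl': 'mem_efficient', 'ddp_backend': 'ddp'})
assert should_delay_halving({'fp16': True, 'fp16_impl': 'safe', 'ddp_backend': 'zero2'})
assert not should_delay_halving({'fp16': True, 'fp16_impl': 'mem_efficient', 'ddp_backend': 'zero2'})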
def launch_and_train(opt, port):
def launch_and_train(opt, port=None):
will we ever specify a port here?
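For context on the OS-assigned free ports mentioned in the patch description, here is one illustrative way (not necessarily the PR's implementation) to obtain a port when none is passed in:

import socket


def find_free_port():
    # Ask the OS for a free port by binding to port 0, then read back the assignment.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]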
self.best_valid = new_valid
self.impatience = 0
if opt.get('model_file') and is_primary_worker():
if opt.get('model_file'):
just making sure I understand - we can get rid of this check because it's handled in save_model right?
We need to be able to do save_on_nonprimary_worker, actually.
if max_norm > 0:
    clip_coef = max_norm / (grad_norm + 1e-6)
    for p in params:
        p.grad.detach().mul_(clip_coef)
Don't want grads of grads! (This is in the original pytorch code too)
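To make that rationale concrete, an illustrative version of the clipping step (the function name is hypothetical and this mirrors torch.nn.utils.clip_grad_norm_ rather than the PR's exact code):

def clip_grads_(params, grad_norm, max_norm):
    # Scale gradients in place. Operating on p.grad.detach() keeps the scaling
    # out of the autograd graph, so we never create gradients of gradients.
    if max_norm > 0:
        clip_coef = min(max_norm / (grad_norm + 1e-6), 1.0)  # never scale grads up
        for p in params:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coef)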
return
# zero3 not supported at this time. Throw an exception
if opt['ddp_backend'] == 'zero3':
i know this is just for overkill testing but it's not even a choice in the param options so we'll already error there if calling from command line
I'm leaving it for the future
parlai/utils/fsdp.py
Outdated
return (
    self.fp16
    and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
    and self.opt['fp16_impl'] == 'safe'
but if we're using mem_efficient we don't delay?
Correct, see main comment
Patch description
Add support for Fairscale's FullyShardedDataParallel (FSDP). This is an implementation of DeepSpeed's Zero2 optimization, wherein optimizer state and gradients are sharded across different workers in order to reduce memory usage. Switching to --ddp-backend zero2 results in about a 25% speedup in UPS (without bg workers; it can probably be a bit higher) and about a 50% reduction in memory usage. It's recommended that everyone switch to this for distributed training and use the savings to increase batchsize or lower the number of GPUs.
We also carve out support for Zero3, but cannot support it at this time due to high-level design in ParlAI. See #3753 for a detailed description of why, and how we might overcome this in the future.
As a side change, this also makes our unit tests use OS-assigned free ports, instead of randomized ones, to slightly improve the reliability of running our test suites. I tried pulling this into another PR, but got tired of dealing with stacking.
Testing steps
Manual tests. New CI.
Here are some screenshots from a sweep that contained both
--ddp-backend ddp and --ddp-backend zero2.