feat(SpecEagleV2): add standalone_worker_v2(WIP) #12625
base: main
Conversation
Summary of Changes: Hello @attack204, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new, more modular version of the 'standalone' speculative decoding worker.
Code Review
This pull request introduces `standalone_worker_v2` to support the standalone speculative decoding algorithm with overlap scheduling (SpecEagleV2). The changes primarily involve extending existing logic for the 'EAGLE' algorithm to also cover the 'STANDALONE' algorithm across various manager and utility files. A new file, `standalone_worker_v2.py`, is added, which defines the v2 worker.
The implementation looks solid and follows the existing patterns. My main feedback is around improving maintainability by reducing code duplication. I've pointed out a couple of repeated conditional checks that could be refactored into helper properties or methods. Additionally, there's a `FIXME` in the new worker related to pipeline parallelism that should be addressed to ensure full feature compatibility.
```diff
 def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-    if self.spec_algo.is_eagle():
+    if self.spec_algo.is_eagle() or self.spec_algo.is_standalone():
```
The condition `self.spec_algo.is_eagle() or self.spec_algo.is_standalone()` is repeated in this file and also in scheduler_output_processor_mixin.py. To improve maintainability, consider adding a helper method to the `SpeculativeAlgorithm` class in spec_info.py. For example:

```python
def is_eagle_or_standalone(self):
    return self.is_eagle() or self.is_standalone()
```

This would centralize the logic and make the code cleaner.
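For context, a minimal self-contained sketch of how the helper could slot into an enum-style `SpeculativeAlgorithm`; the actual class in spec_info.py (its members and `from_string` parsing) may differ:

```python
from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    # Hypothetical members; the real enum in spec_info.py may define more.
    NONE = auto()
    EAGLE = auto()
    STANDALONE = auto()

    def is_eagle(self) -> bool:
        return self == SpeculativeAlgorithm.EAGLE

    def is_standalone(self) -> bool:
        return self == SpeculativeAlgorithm.STANDALONE

    def is_eagle_or_standalone(self) -> bool:
        # Single source of truth for "EAGLE-style draft/verify" checks.
        return self.is_eagle() or self.is_standalone()
```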
```diff
     bs = len(self.reqs)

-    if self.is_v2_eagle:
+    if self.is_v2_eagle or self.is_v2_standalone:
```
The condition `self.is_v2_eagle or self.is_v2_standalone` is repeated in several places (here, `maybe_wait_verify_done`, and in scheduler.py and scheduler_output_processor_mixin.py). To improve maintainability and reduce code duplication, consider adding a new property to the `ScheduleBatch` class that encapsulates this logic. For example:

```python
@property
def is_v2_spec(self):
    return self.is_v2_eagle or self.is_v2_standalone
```

Then you can simplify this condition to `if self.is_v2_spec:`. This would make the code cleaner and easier to modify in the future if more v2 speculative algorithms are added.
```python
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
pp_rank=0,  # FIXME
```
Code Review
This pull request introduces `standalone_worker_v2` to implement a new speculative decoding algorithm. The changes are mostly about integrating this new algorithm into the existing logic paths. The implementation is functional, but there are several opportunities to improve maintainability by reducing code duplication and repeated conditional logic. My review focuses on refactoring these areas for better code clarity and easier future extensions.
```python
def __init__(
    self,
    server_args: ServerArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: int,
    moe_ep_rank: int,
    nccl_port: int,
    target_worker: TpModelWorker,
):
    # copy args
    self.server_args = server_args
    self.gpu_id = gpu_id
    self.tp_rank = tp_rank
    self.dp_rank = dp_rank
    self.moe_ep_rank = moe_ep_rank
    self.nccl_port = nccl_port
    self.target_worker = target_worker

    # Args for easy access
    self.device = server_args.device
    self.topk = server_args.speculative_eagle_topk
    self.speculative_num_steps = server_args.speculative_num_steps
    self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
    self.speculative_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )

    # Set constant
    from sglang.srt.speculative.eagle_info import EagleDraftInput

    EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
        self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
    )

    # Do not capture cuda graph in `TpModelWorker` init,
    # will capture later with init_cuda_graphs()
    backup_disable_cuda_graph = server_args.disable_cuda_graph
    server_args.disable_cuda_graph = True

    # Share the allocator with a target worker.
    # Draft and target worker own their own KV cache pools.
    self.req_to_token_pool, self.token_to_kv_pool_allocator = (
        target_worker.get_memory_pool()
    )
    with empty_context():
        # Init draft worker
        self.draft_worker = TpModelWorker(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=0,  # FIXME
            dp_rank=dp_rank,
            moe_ep_rank=moe_ep_rank,
            nccl_port=nccl_port,
            is_draft_worker=True,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

    # Alias for better readability
    self.draft_runner = self.draft_worker.model_runner

    self.init_token_map()
    self.init_lm_head()

    # Init attention backend and cuda graphs
    self.draft_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
    self.draft_tp_context = (
        draft_tp_context if server_args.enable_dp_attention else empty_context
    )
    with self.draft_tp_context(self.draft_runner.tp_group):
        self.init_attention_backend()
        self.init_cuda_graphs()

    from sglang.srt.speculative.eagle_utils import TreeMaskMode

    self.tree_mask_mode = TreeMaskMode.FULL_MASK

    self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
```
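One detail in this hunk worth a sanity check: `ALLOC_LEN_PER_DECODE` is sized as the larger of the draft expansion (`speculative_num_steps * topk` candidate tokens) and the verification budget (`speculative_num_draft_tokens`). That reading is an interpretation of the code above, not something stated in the PR; a quick worked example:

```python
# Worked example with hypothetical settings: 3 draft steps, top-4 branching,
# 8 draft tokens kept for verification.
speculative_num_steps, topk, speculative_num_draft_tokens = 3, 4, 8

# Each decode slot must hold whichever phase needs more room.
alloc_len_per_decode = max(
    speculative_num_steps * topk,  # 12 candidate tokens during drafting
    speculative_num_draft_tokens,  # 8 tokens during verification
)
assert alloc_len_per_decode == 12
```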
The `__init__` method of `StandaloneDraftWorker` is an exact copy of `EagleDraftWorker.__init__`. This code duplication can be avoided by removing the `__init__` method from `StandaloneDraftWorker` and letting it inherit from `EagleDraftWorker`. Since `EagleDraftWorker.__init__` calls `self.init_lm_head()`, your override of `init_lm_head` will be correctly used, achieving the same goal with much cleaner code.
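A minimal sketch of that refactor, assuming (per the comment above) that `init_lm_head` is the only behavioral difference; `EagleDraftWorker` here is the existing class from this PR's module:

```python
class StandaloneDraftWorker(EagleDraftWorker):
    # No __init__ override: EagleDraftWorker.__init__ runs unchanged, and
    # its call to self.init_lm_head() dispatches to the override below.

    def init_lm_head(self):
        # Hypothetical body: a standalone draft model loads its own
        # embeddings/LM head rather than sharing the target model's.
        ...
```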
```diff
 def _lazy_init_buf(self, draft_input: EagleDraftInput):
-    if self.buf_initialized or not self.spec_algo.is_eagle():
+    if self.buf_initialized or (not self.spec_algo.is_eagle() and not self.spec_algo.is_standalone()):
```
```diff
 def resolve_future(self, model_worker_batch: ModelWorkerBatch):
-    if self.spec_algo.is_eagle():
+    if self.spec_algo.is_eagle() or self.spec_algo.is_standalone():
```
```diff
 ):
     intv = future_indices.interval
-    if self.spec_algo.is_eagle():
+    if self.spec_algo.is_eagle() or self.spec_algo.is_standalone():
```
```python
@property
def is_v2_standalone(self):
    return self.enable_overlap and self.spec_algorithm.is_standalone()
```
Following the suggestion to add `is_v2_algo` to `SpeculativeAlgorithm`, you can introduce a corresponding `is_v2_spec` property here. This will centralize the logic for checking v2 speculative algorithms and simplify conditions in this and other files.

```python
@property
def is_v2_spec(self):
    return self.enable_overlap and self.spec_algorithm.is_v2_algo()
```
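The `is_v2_algo` helper referenced above is not shown in this thread; presumably it mirrors the earlier `is_eagle_or_standalone` suggestion, something like:

```python
# Hypothetical definition on SpeculativeAlgorithm (the name comes from the
# review comment; the actual suggestion is not visible in this capture).
def is_v2_algo(self):
    return self.is_eagle() or self.is_standalone()
```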
```diff
 future_indices_or_next_token_ids = -future_indices.indices

-if batch.is_v2_eagle:
+if batch.is_v2_eagle or batch.is_v2_standalone:
```
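As background on the `-future_indices.indices` pattern in this hunk: under overlap scheduling, next-token ids may not be known yet when the following batch is assembled, so negated buffer indices stand in as placeholders and are resolved once sampling completes. A minimal sketch of the idea; the buffer, `store_future`, and the free-function `resolve_future` below are simplified stand-ins, not sglang's real `FutureIndices` machinery:

```python
import torch

# Hypothetical ring buffer of resolved token ids (slot 0 left unused so
# that negated indices stay distinguishable from real token ids >= 0).
token_buf = torch.zeros(4096, dtype=torch.int64)

def store_future(interval: slice, sampled: torch.Tensor) -> torch.Tensor:
    # Record sampled tokens and hand out negated indices as placeholders.
    token_buf[interval] = sampled
    return -torch.arange(interval.start, interval.stop, dtype=torch.int64)

def resolve_future(input_ids: torch.Tensor) -> torch.Tensor:
    # Replace placeholder (negative) ids with the tokens they point to.
    mask = input_ids < 0
    input_ids[mask] = token_buf[-input_ids[mask]]
    return input_ids

# Usage: tokens sampled for batch N become placeholders consumed by batch N+1.
ids = store_future(slice(1, 3), torch.tensor([101, 102]))
assert resolve_future(ids.clone()).tolist() == [101, 102]
```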
```diff
 if batch.return_logprob:
     next_token_logprobs = logits_output.next_token_logprobs.tolist()
-elif batch.is_v2_eagle:
+elif batch.is_v2_eagle or batch.is_v2_standalone:
```
```diff
 if self.enable_overlap and (req.finished() or req.is_retracted):
     indices_to_free = None
-    if batch.spec_algorithm.is_eagle():
+    if batch.spec_algorithm.is_eagle() or batch.spec_algorithm.is_standalone():
```
```diff
 if batch.spec_algorithm.is_none():
     req.output_ids.append(next_token_id)
-elif batch.is_v2_eagle:
+elif batch.is_v2_eagle or batch.is_v2_standalone:
```
```diff
 if req.finished():
-    if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
+    if (batch.is_v2_eagle or batch.is_v2_standalone) and self.cur_batch.forward_mode.is_extend():
```
Code Review
This pull request introduces `standalone_worker_v2` to enable speculative decoding with standalone draft models in an overlapped execution mode. The changes primarily involve extending existing logic for the 'EAGLE' algorithm to also support the new 'STANDALONE' algorithm. The main implementation resides in the new `standalone_worker_v2.py` file.
My review focuses on improving code maintainability by reducing duplication. I've identified several areas where repeated conditional logic can be refactored into helper properties or methods. Specifically, I've suggested creating a new property in schedule_batch.py to simplify checks for V2 speculative overlap. In the new `standalone_worker_v2.py`, I've recommended using `super().__init__` to reduce significant code duplication in the `__init__` methods of `StandaloneDraftWorker` and `StandaloneWorkerV2`. Additionally, I've included a minor style suggestion in server_args.py to improve readability. These changes will make the codebase cleaner and easier to maintain.
```diff
 def is_v2_standalone(self):
     return self.enable_overlap and self.spec_algorithm.is_standalone()

 def prepare_for_decode(self):
     self.forward_mode = ForwardMode.DECODE
     bs = len(self.reqs)

-    if self.is_v2_eagle:
+    if self.is_v2_eagle or self.is_v2_standalone:
```
To improve maintainability and reduce code duplication, you can introduce a new property `is_v2_speculative_overlap` that combines the checks for `is_v2_eagle` and `is_v2_standalone`. This new property can then be used here and in other places like `maybe_wait_verify_done` in this file, as well as in scheduler_output_processor_mixin.py, to simplify the conditional logic.
Suggested change:

```diff
 def is_v2_standalone(self):
     return self.enable_overlap and self.spec_algorithm.is_standalone()

+@property
+def is_v2_speculative_overlap(self):
+    return self.is_v2_eagle or self.is_v2_standalone
+
 def prepare_for_decode(self):
     self.forward_mode = ForwardMode.DECODE
     bs = len(self.reqs)

-    if self.is_v2_eagle or self.is_v2_standalone:
+    if self.is_v2_speculative_overlap:
```
```python
(self.speculative_algorithm == "EAGLE"
 or self.speculative_algorithm == "STANDALONE")
```
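The body of the style comment on this hunk did not survive the page capture. A common readability tweak for chained equality checks like this (my assumption about the suggestion, not its confirmed content) is a membership test:

```python
def _is_eagle_style(speculative_algorithm: str) -> bool:
    # Hypothetical helper; equivalent to the chained == checks above.
    return speculative_algorithm in ("EAGLE", "STANDALONE")
```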
(The same `StandaloneDraftWorker.__init__` hunk shown earlier is quoted again for this comment.)
The `__init__` method of `StandaloneDraftWorker` is almost a complete copy of `EagleDraftWorker.__init__`. To avoid this large code duplication and improve maintainability, you can call `super().__init__` and rely on the parent class's implementation. The only functional difference (not sharing the LM head) is already correctly handled by overriding `init_lm_head`.
Suggested change (replacing the duplicated `__init__` body with a call to the parent implementation):

```python
def __init__(
    self,
    server_args: ServerArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: int,
    moe_ep_rank: int,
    nccl_port: int,
    target_worker: TpModelWorker,
):
    super().__init__(
        server_args,
        gpu_id,
        tp_rank,
        dp_rank,
        moe_ep_rank,
        nccl_port,
        target_worker,
    )
```
```python
def __init__(
    self,
    server_args: ServerArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    moe_ep_rank: int,
    nccl_port: int,
    target_worker: TpModelWorker,
):
    # Parse arguments
    self.server_args = server_args
    self.topk = server_args.speculative_eagle_topk
    self.speculative_num_steps = server_args.speculative_num_steps
    self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
    self.enable_nan_detection = server_args.enable_nan_detection
    self.gpu_id = gpu_id
    self.device = server_args.device
    self._target_worker = target_worker
    self.page_size = server_args.page_size
    self.speculative_algorithm = SpeculativeAlgorithm.from_string(
        server_args.speculative_algorithm
    )

    self.req_to_token_pool, self.token_to_kv_pool_allocator = (
        target_worker.get_memory_pool()
    )

    # Override the context length of the draft model to be the same as the target model.
    server_args.context_length = target_worker.model_runner.model_config.context_len

    # Create our custom draft worker that doesn't share embeddings/lm_head
    self._draft_worker = StandaloneDraftWorker(
        server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
    )

    # Some dummy tensors
    self.num_new_pages_per_topk = torch.empty(
        (), dtype=torch.int64, device=self.device
    )
    self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)

    self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
```
The `__init__` method of `StandaloneWorkerV2` is a near-exact copy of `EAGLEWorkerV2.__init__`, with the only difference being the instantiation of `StandaloneDraftWorker` instead of `EagleDraftWorker`. To reduce code duplication and make the code more maintainable, you can call `super().__init__` and then just override the `self._draft_worker` attribute.
```python
def __init__(
    self,
    server_args: ServerArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    moe_ep_rank: int,
    nccl_port: int,
    target_worker: TpModelWorker,
):
    super().__init__(
        server_args,
        gpu_id,
        tp_rank,
        dp_rank,
        moe_ep_rank,
        nccl_port,
        target_worker,
    )
    # Create our custom draft worker that doesn't share embeddings/lm_head
    self._draft_worker = StandaloneDraftWorker(
        server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker
    )
```
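One caveat with this suggestion, assuming `EAGLEWorkerV2.__init__` itself constructs an `EagleDraftWorker` (as the near-exact copy above implies): calling `super().__init__` builds a full draft worker only to immediately replace it. A hypothetical factory-method variant avoids that double construction:

```python
# Sketch only: _create_draft_worker is a hypothetical hook, not an existing
# method on EAGLEWorkerV2; the draft-worker classes are stubbed for brevity.
class EagleDraftWorker: ...
class StandaloneDraftWorker(EagleDraftWorker): ...

class EAGLEWorkerV2:
    def __init__(self, *args, **kwargs):
        # ... shared setup elided ...
        self._draft_worker = self._create_draft_worker(*args, **kwargs)

    def _create_draft_worker(self, *args, **kwargs):
        return EagleDraftWorker(*args, **kwargs)

class StandaloneWorkerV2(EAGLEWorkerV2):
    def _create_draft_worker(self, *args, **kwargs):
        # Standalone variant keeps its own embeddings/LM head.
        return StandaloneDraftWorker(*args, **kwargs)
```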
Motivation
Test And Benchmark
Env: 1*H200
standaloneV2
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist