
[QEff Finetune] : Made fixes to training script #439


Draft · quic-mamta wants to merge 4 commits into base: main

Conversation

quic-mamta (Contributor) commented on Jun 10, 2025:

Made fixes to training script.



def get_preprocessed_samsum(dataset_config, tokenizer, split, context_length=None):
-    dataset = datasets.load_dataset("Samsung/samsum", split=split, trust_remote_code=True)
+    dataset = datasets.load_dataset("knkarthick/samsum", split=split, trust_remote_code=True)
quic-mamta (Contributor, Author) commented:

Please check if this dataset can be used.
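A quick way to act on this: load the replacement dataset once and confirm it exposes the columns the preprocessing code reads. This is a minimal sanity-check sketch, not part of the PR; the column names shown are assumptions to verify, and `knkarthick/samsum` may not require `trust_remote_code` at all.

```python
# Standalone sanity check for the dataset swap (sketch, not PR code).
import datasets

ds = datasets.load_dataset("knkarthick/samsum", split="train")
print(ds.column_names)          # expecting something like ['id', 'dialogue', 'summary']
print(ds[0]["dialogue"][:80])   # spot-check that samples look like chat transcripts
```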

def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
    if train_config.enable_ddp:
        print("Length of dataset before: ", len(dataset))
        dataset = pad_dataset(dataset, batch_size, 2)
quic-mamta (Contributor, Author) commented:

Instead of the hard-coded 2, use the DDP world_size here.
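A minimal sketch of the suggested change, assuming `torch.distributed` has already been initialized by the time this runs (otherwise `dist.get_world_size()` raises), and reusing the function's local names from the snippet above:

```python
import torch.distributed as dist

if train_config.enable_ddp:
    # Pad to a multiple of batch_size * world_size instead of a hard-coded 2,
    # so the padding stays correct for any number of DDP replicas.
    num_replicas = dist.get_world_size()
    dataset = pad_dataset(dataset, batch_size, num_replicas)
```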

@@ -115,10 +115,26 @@ def generate_dataset_config(dataset_name: str) -> Any:
    return dataset_config


def pad_dataset(dataset, batch_size, num_replicas):
    reminder = len(dataset) % (batch_size * num_replicas)
quic-mamta (Contributor, Author) commented:

Please use `remainder` as the variable name here (it is currently spelled `reminder`).
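For reference, a sketch of what the renamed helper could look like. Only the signature and the modulo line come from the diff; the padding strategy (repeating the last sample) and the use of `datasets.concatenate_datasets` are assumptions, not the PR's actual implementation.

```python
import datasets


def pad_dataset(dataset, batch_size, num_replicas):
    # Pad so every replica sees the same number of full batches.
    remainder = len(dataset) % (batch_size * num_replicas)
    if remainder == 0:
        return dataset
    pad_count = batch_size * num_replicas - remainder
    # Assumption: repeat the final sample as padding; the real PR may pad differently.
    padding = dataset.select([len(dataset) - 1] * pad_count)
    return datasets.concatenate_datasets([dataset, padding])
```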

Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com>
quic-mamta changed the title from "Made fixes to training script based on recent findings." to "[QEff Finetune] : Made fixes to training script" on Jun 12, 2025
@@ -235,11 +241,23 @@ def train(
            train_step_metric.append(step_metric_val)

            if train_config.grad_scaler:
-                scaler.scale(loss).backward()  # backward pass
+                if train_config.enable_ddp:
+                    with model.no_sync():
A reviewer (Contributor) commented:

This will result in no syncing of gradients at any step.

            if train_config.enable_ddp:
                # FIXME: We can not stop transfer of gradient across devices every time.
                # In grad accumulation last step should transfer gradients across devices.
                with model.no_sync():
A reviewer (Contributor) commented:

This will result in no syncing of gradients at any step here as well.
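The usual way to keep gradient accumulation and DDP compatible is to enter `no_sync()` only on the accumulation steps that are not the last one, so the all-reduce still happens once per effective batch. A sketch of that pattern, reusing the training loop's local names (`step`, `loss`, `model`, `scaler`, `train_config`) and assuming a `gradient_accumulation_steps` field on `train_config` (the real field name may differ):

```python
from contextlib import nullcontext

# Assumption: train_config exposes a gradient accumulation count under this name.
accum_steps = train_config.gradient_accumulation_steps
is_last_accum_step = (step + 1) % accum_steps == 0

# Suppress the gradient all-reduce only on non-final accumulation steps, so
# DDP still syncs gradients once per effective batch.
sync_ctx = model.no_sync() if (train_config.enable_ddp and not is_last_accum_step) else nullcontext()

with sync_ctx:
    if train_config.grad_scaler:
        scaler.scale(loss).backward()  # backward pass
    else:
        loss.backward()  # backward pass
```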
