
Adding data loss term led to very poor performance #28

@annien094


Hello!

I have been experimenting with the Allen-Cahn example, which trained very well without any data, using only the res_loss and ics_loss terms. I wanted to try adding a data loss term to the model, which I expected would help training since it gives the model even more information. However, it actually led to very poor performance. Here is a visualisation of the result:

[Image: visualisation of the result]

and the training loss history:

[Images: training loss history]

The data loss would not decrease, and res_loss and ics_loss also remained at significantly larger values. cas_weight also did not converge to 1 in this case. I tried turning causal training and NTK weighting on and off, but neither improved the performance. Using only the data loss, with res_loss and ics_loss turned off, also gave similarly poor performance.
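For context on the cas_weight observation, here is a minimal sketch of causal weighting. It assumes the residual loss is split into time-ordered chunks with weights w_i = exp(-tol · Σ_{k<i} L_k), and that cas_weight is the smallest of these weights. The function name `causal_weights`, the `tol` value, and the loss values are illustrative and not taken from the repository.

```python
import jax.numpy as jnp
from jax import lax


def causal_weights(chunk_losses, tol):
    """Causal weights w_i = exp(-tol * sum_{k<i} L_k) for time-ordered chunks.

    chunk_losses[i] is the mean squared residual in the i-th time chunk.
    The first chunk always gets weight 1.
    """
    n = chunk_losses.shape[0]
    # Strictly lower-triangular ones: row i sums the losses of chunks 0..i-1
    M = jnp.tril(jnp.ones((n, n)), k=-1)
    # Weights are treated as constants (no gradient flows through them)
    return lax.stop_gradient(jnp.exp(-tol * (M @ chunk_losses)))


chunk_losses = jnp.array([0.1, 0.2, 0.3, 0.4])  # made-up per-chunk residual losses
w = causal_weights(chunk_losses, tol=1.0)
res_loss = jnp.mean(chunk_losses * w)  # causally weighted residual loss
cas_weight = w.min()                   # assumed to be what is logged as cas_weight
# w is approximately [1.0, 0.905, 0.741, 0.549]
```

Under this rule, cas_weight can only approach 1 once the residual losses of the earlier time chunks are close to zero, so a residual loss that stays large keeps it well below 1.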

Have you seen a data loss term degrade training like this before? Any advice or suggestions would be much appreciated. Thank you in advance, and I look forward to hearing from you.

Best wishes,
Annie
.
.
.
For reference, here are the code snippets relevant to adding the data term to the loss function.
In losses() I added:

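    def losses(self, params, batch):

        # Unpack batch
        data_coords_batch, data_batch = batch["data"]
        res_batch = batch["res"]

        ...
        
        # Data loss (assumes the columns of data_coords_batch are ordered (t, x), as u_net expects)
        u_pred_data = vmap(self.u_net, (None, 0, 0))(params, data_coords_batch[:, 0], data_coords_batch[:, 1])
        # Reshape the targets to u_pred_data's shape (batch_size,) so the subtraction cannot broadcast
        data_loss = jnp.mean((data_batch.reshape(u_pred_data.shape) - u_pred_data) ** 2)

        loss_dict = {"ics": ics_loss, "res": res_loss, "data": data_loss}
        return loss_dict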
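One way a data loss can stall without raising any error is a shape mismatch in the subtraction. This sketch uses made-up values and plain jax.numpy. It shows that if the targets carry a trailing axis of size 1, for example shape (batch_size, 1) instead of (batch_size,), the subtraction broadcasts to a (batch_size, batch_size) matrix. The "loss" then compares every target with every prediction.

```python
import jax.numpy as jnp

batch_size = 4
# Predictions as produced by vmap over a scalar-output u_net: shape (batch_size,)
u_pred_data = jnp.arange(batch_size, dtype=jnp.float32)
# Targets with a trailing axis of size 1: shape (batch_size, 1)
data_batch = u_pred_data.reshape(batch_size, 1)

# Silent broadcasting: (batch_size, 1) - (batch_size,) -> (batch_size, batch_size),
# so every target is compared with every prediction
bad_diff = data_batch - u_pred_data
bad_loss = jnp.mean(bad_diff ** 2)  # 2.5 here, although predictions match targets exactly

# Matching the shapes first gives the intended pointwise mean squared error
good_loss = jnp.mean((data_batch.reshape(u_pred_data.shape) - u_pred_data) ** 2)  # 0.0
```

Shapes are static under jit, so adding `assert data_batch.shape == u_pred_data.shape` inside losses() while debugging would catch this at trace time.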

In compute_diag_ntk() I added:

    def compute_diag_ntk(self, params, batch):
        # Unpack batch
        data_coords_batch, data_batch = batch["data"]
        res_batch = batch["res"]

        ...
        # Compute data_ntk
        data_ntk = vmap(ntk_fn, (None, None, 0, 0))(
            self.u_net, params, data_coords_batch[:, 0], data_coords_batch[:, 1]
        )

        ntk_dict = {"ics": ics_ntk, "res": res_ntk, "data": data_ntk}

        return ntk_dict
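To sanity-check the data NTK entries and the resulting weights in isolation, here is a minimal standalone sketch. It rests on two assumptions about the library. First, `ntk_fn` returns the squared norm of the parameter gradient of `u_net` at one point. Second, the NTK weights follow the rule w_k = Σ_j mean(K_j) / mean(K_k). The `u_net` here is a hypothetical two-layer MLP, and `ntk_weights` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax import jacrev, vmap
from jax.flatten_util import ravel_pytree


def u_net(params, t, x):
    # Hypothetical tiny MLP standing in for the model's u_net; returns a scalar
    h = jnp.tanh(params["W1"] @ jnp.array([t, x]) + params["b1"])
    return params["W2"] @ h + params["b2"]


def ntk_fn(apply_fn, params, *args):
    # Diagonal NTK entry at one input: squared norm of d(apply_fn)/d(params)
    J = jacrev(apply_fn)(params, *args)
    J_flat, _ = ravel_pytree(J)
    return jnp.dot(J_flat, J_flat)


def ntk_weights(ntk_dict):
    # Assumed weighting rule: w_k = (sum_j mean(K_j)) / mean(K_k)
    means = {k: jnp.mean(v) for k, v in ntk_dict.items()}
    total = sum(means.values())
    return {k: total / m for k, m in means.items()}


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
H = 8
params = {
    "W1": jax.random.normal(k1, (H, 2)),
    "b1": jnp.zeros(H),
    "W2": jax.random.normal(k2, (H,)),
    "b2": jnp.array(0.0),
}
data_coords_batch = jnp.array([[0.0, -0.5], [0.5, 0.0], [1.0, 0.5]])  # (t, x) rows

# Same pattern as vmap(ntk_fn, (None, None, 0, 0)), with u_net bound via partial
data_ntk = vmap(partial(ntk_fn, u_net), (None, 0, 0))(
    params, data_coords_batch[:, 0], data_coords_batch[:, 1]
)

# Toy example: the term with the smaller NTK trace receives the larger weight
toy_weights = ntk_weights({"ics": jnp.array([1.0, 1.0]), "data": jnp.array([3.0, 3.0])})
# toy_weights["ics"] == 4.0, toy_weights["data"] == 4/3
```

Under this rule, a data term whose NTK entries are large relative to ics and res receives a small weight. This can be compared against the weights logged during training.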

Here, data_coords_batch and data_batch are unpacked from batch["data"], which is generated with:

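from functools import partial

import jax.numpy as jnp
from jax import pmap, random
from sklearn.model_selection import train_test_split
from jaxpi.samplers import BaseSampler


class DataSampler(BaseSampler):
    def __init__(self, coords, u, test_size, batch_size, train=True, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.u = u.flatten()  # one scalar target per coordinate row
        self.coords = coords

        # Perform train-test split (coordinates and values are split together)
        coords_train, coords_test, u_train, u_test = train_test_split(
            self.coords, self.u, test_size=test_size, random_state=42)
        # Use either the training or the testing split, keeping coordinates and values paired.
        # Convert both to jnp arrays so they can be indexed with traced indices inside pmap
        self.selected_coords = jnp.array(coords_train if train else coords_test)
        self.selected_u = jnp.array(u_train if train else u_test)

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = random.choice(
            key, self.selected_coords.shape[0], shape=(self.batch_size,))
        coords_batch = self.selected_coords[idx, :]
        u_batch = self.selected_u[idx]  # targets are 1-D after flatten(), so index with idx only
        batch = (coords_batch, u_batch)
        return batch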
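The essential invariant in the sampler is that coordinates and target values are split and sampled with the same indices. The following is a simplified, hypothetical `PairedDataSampler` that keeps them paired. It drops pmap and BaseSampler and uses a random permutation in place of sklearn's train_test_split. The smooth field in the usage lines is a stand-in, not Allen-Cahn data.

```python
import jax.numpy as jnp
from jax import random


class PairedDataSampler:
    """Hypothetical simplified sampler (no pmap, no BaseSampler) that keeps
    coordinates and target values aligned through the split and the sampling."""

    def __init__(self, coords, u, test_size, batch_size, train=True, rng_key=random.PRNGKey(1234)):
        coords = jnp.asarray(coords)
        u = jnp.asarray(u).reshape(-1)  # one scalar target per coordinate row
        assert coords.shape[0] == u.shape[0]
        split_key, self.key = random.split(rng_key)
        perm = random.permutation(split_key, coords.shape[0])
        n_test = int(round(test_size * coords.shape[0]))
        idx = perm[n_test:] if train else perm[:n_test]
        # The same index array selects both coordinates and values
        self.coords, self.u = coords[idx], u[idx]
        self.batch_size = batch_size

    def sample(self):
        # Draw batch_size rows with replacement, as in the original data_generation
        self.key, sub = random.split(self.key)
        idx = random.choice(sub, self.coords.shape[0], shape=(self.batch_size,))
        return self.coords[idx], self.u[idx]


# Usage on a (t, x) grid with a stand-in field u = sin(pi x) exp(-t)
t, x = jnp.meshgrid(jnp.linspace(0, 1, 11), jnp.linspace(-1, 1, 21), indexing="ij")
coords = jnp.stack([t.ravel(), x.ravel()], axis=1)
u = jnp.sin(jnp.pi * x) * jnp.exp(-t)
sampler = PairedDataSampler(coords, u, test_size=0.2, batch_size=64)
coords_batch, u_batch = sampler.sample()
```

Because each sampled value can be recomputed from its own coordinates here, checking `u_batch` against the field evaluated at `coords_batch` is a quick way to confirm the pairing survives the split and the sampling.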
