Commit
updated propensity models to use softmax
joshuaspear committed Jul 21, 2024
1 parent 94cbb86 commit f91fd0b
Showing 3 changed files with 6 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -88,6 +88,9 @@ If importance sampling based methods are evaluating to 0, consider visualising t
The different kinds of importance samples can also be visualised by querying the ```traj_is_weights``` attribute of a given ```ImportanceSampler``` object. If, for example, vanilla importance sampling is being used and the weights are not ```NaN``` or ```Inf```, then visualising the ```traj_is_weights``` may provide insight. In particular, IS weights will tend to infinity when the evaluation policy places a large density on an action relative to the behaviour policy.
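The behaviour described above can be sketched numerically. This is a minimal illustration, not the library's API; the probability arrays are invented for the example:

```python
import numpy as np

# Vanilla per-trajectory importance weight: the product over timesteps of
# pi_e(a_t|s_t) / pi_b(a_t|s_t). When the evaluation policy puts much more
# density on the observed actions than the behaviour policy, each ratio is
# large and their product explodes, driving the weight towards infinity.
eval_probs = np.array([0.90, 0.85, 0.95])   # pi_e(a_t|s_t), illustrative
behav_probs = np.array([0.10, 0.05, 0.02])  # pi_b(a_t|s_t), illustrative
traj_is_weight = float(np.prod(eval_probs / behav_probs))
print(traj_is_weight)  # a very large weight -> high-variance IS estimate
```

Visualising such weights across trajectories (as the ```traj_is_weights``` attribute allows) makes this kind of blow-up easy to spot.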

### Release log
#### 6.1.0 (forthcoming)
* Altered the discrete torch propensity model to use softmax instead of sigmoid. This requires modelling both classes for binary classification; however, it improves the generalisability of the code

#### 6.0.0
* Updated the PropensityModels structure for sklearn and added a helper class for compatibility with torch
* Full runtime typechecking with jaxtyping
4 changes: 2 additions & 2 deletions src/offline_rl_ope/PropensityModels/torch/models/Discrete.py
@@ -22,7 +22,7 @@ def __init__(

super().__init__(input_dim=input_dim, layers_dim=layers_dim,
actvton=actvton, init_bias=init_bias)

    assert all(i > 1 for i in out_dim), "If predicting a single positive class, provide an output dim of 2"
self.layers.append(actvton)
# Add the final layer
self.out_layers = nn.ModuleList()
@@ -31,7 +31,7 @@ def __init__(
in_features=layers_dim[-1],
out_features=head_dim
))
self.out_actvton = nn.Sigmoid()
self.out_actvton = nn.Softmax(dim=-1)

def forward(
self,
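The sigmoid-to-softmax change above can be illustrated with a short standalone sketch (plain torch code, not the library's actual class). With two output logits, softmax reproduces sigmoid-style binary probabilities while the same head generalises to any number of actions:

```python
import torch
import torch.nn as nn

# Illustrative sketch: a discrete propensity head outputs one logit per
# action and softmax converts them to a probability per action. Binary
# classification therefore needs two output logits rather than one.
torch.manual_seed(0)
logits = torch.randn(4, 2)                 # batch of 4, two actions
probs = nn.Softmax(dim=-1)(logits)         # one probability per action

# Each row sums to 1, and the probability of the second class equals
# sigmoid(logit_1 - logit_0): softmax over two classes reproduces the
# previous sigmoid behaviour exactly.
sigmoid_equiv = torch.sigmoid(logits[:, 1] - logits[:, 0])
assert torch.allclose(probs.sum(dim=-1), torch.ones(4), atol=1e-6)
assert torch.allclose(probs[:, 1], sigmoid_equiv, atol=1e-6)
```

Note that ```nn.Softmax``` should be constructed with an explicit ```dim``` argument; omitting it relies on deprecated implicit-dimension behaviour.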
2 changes: 1 addition & 1 deletion src/offline_rl_ope/_version.py
@@ -1 +1 @@
__version__ = "6.0.0"
__version__ = "6.1.0"
