conv.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import torch
  2. from torch import nn
  3. from torch.nn import functional as F
  4. class Conv2d(nn.Module):
  5. def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
  6. super().__init__(*args, **kwargs)
  7. self.conv_block = nn.Sequential(
  8. nn.Conv2d(cin, cout, kernel_size, stride, padding),
  9. nn.BatchNorm2d(cout)
  10. )
  11. self.act = nn.ReLU()
  12. self.residual = residual
  13. def forward(self, x):
  14. out = self.conv_block(x)
  15. if self.residual:
  16. out += x
  17. return self.act(out)
  18. class nonorm_Conv2d(nn.Module):
  19. def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
  20. super().__init__(*args, **kwargs)
  21. self.conv_block = nn.Sequential(
  22. nn.Conv2d(cin, cout, kernel_size, stride, padding),
  23. )
  24. self.act = nn.LeakyReLU(0.01, inplace=True)
  25. def forward(self, x):
  26. out = self.conv_block(x)
  27. return self.act(out)
  28. class Conv2dTranspose(nn.Module):
  29. def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
  30. super().__init__(*args, **kwargs)
  31. self.conv_block = nn.Sequential(
  32. nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
  33. nn.BatchNorm2d(cout)
  34. )
  35. self.act = nn.ReLU()
  36. def forward(self, x):
  37. out = self.conv_block(x)
  38. return self.act(out)